Browse Source

Add fast path for single value in VALUES aggregator (#130510)

This change introduces a fast path for the VALUES aggregator in the 
single-value case. For the first value seen in each group, we add it to the
new big array without touching the hash. For subsequent values, if they
are the same as the current value, we skip them; if they differ, we
trigger the slow path and add them to the hash. This optimization speeds
up VALUES when the number of groups is large and most groups have only
one value.

Before:
```
Benchmark                      (dataType)  (groups)  Mode  Cnt      Score       Error  Units
ValuesAggregatorBenchmark.run    BytesRef         1  avgt    3    177.756 ±     2.111  ms/op
ValuesAggregatorBenchmark.run    BytesRef      1000  avgt    3    126.174 ±     0.431  ms/op
ValuesAggregatorBenchmark.run    BytesRef   1000000  avgt    3  66920.144 ± 53588.490  ms/op
```

After:
```
Benchmark                      (dataType)  (groups)  Mode  Cnt      Score      Error  Units
ValuesAggregatorBenchmark.run    BytesRef         1  avgt    3    180.269 ±    4.019  ms/op
ValuesAggregatorBenchmark.run    BytesRef      1000  avgt    3    107.051 ±    3.149  ms/op
ValuesAggregatorBenchmark.run    BytesRef   1000000  avgt    3  26277.863 ± 7214.319  ms/op
```
Nhat Nguyen 2 months ago
parent
commit
71957caaa1

+ 5 - 0
docs/changelog/130510.yaml

@@ -0,0 +1,5 @@
+pr: 130510
+summary: Add fast path for single value in VALUES aggregator
+area: ES|QL
+type: enhancement
+issues: []

+ 199 - 166
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java

@@ -7,11 +7,15 @@
 
 package org.elasticsearch.compute.aggregation;
 
+// begin generated imports
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.common.util.IntArray;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.ann.Aggregator;
 import org.elasticsearch.compute.ann.GroupingAggregator;
@@ -19,13 +23,16 @@ import org.elasticsearch.compute.ann.IntermediateState;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+// end generated imports
 
 /**
  * Aggregates field values for BytesRef.
@@ -129,47 +136,146 @@ class ValuesBytesRefAggregator {
     }
 
     /**
-     * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
-     * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
-     * and then use it to iterate over the values in order.
-     *
-     * @param ids positions of the {@link GroupingState#values} to read.
+     * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value.
+     * When emitting the output, we need to iterate the hash one group at a time to build the output block,
+     * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id
+     * to an array, allowing us to build the output in O(N) instead.
      */
-    private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
+    private static class NextValues implements Releasable {
+        private final BlockFactory blockFactory;
+        /** One entry per (group, value ordinal) pair, packed into a long: group in the top 32 bits, ordinal in the bottom 32. */
+        private final LongHash hashes;
+        /** Built by {@link #prepareForEmitting}: for each selected group, the exclusive end of its run in {@link #ids}. */
+        private int[] selectedCounts = null;
+        /** Built by {@link #prepareForEmitting}: ids into {@link #hashes}, laid out group by group in selection order. */
+        private int[] ids = null;
+        /** Bytes reserved on the breaker for the two arrays above; handed back in {@link #close()}. */
+        private long extraMemoryUsed = 0;
+
+        private NextValues(BlockFactory blockFactory) {
+            this.blockFactory = blockFactory;
+            this.hashes = new LongHash(1, blockFactory.bigArrays());
+        }
+
+        /** Records that {@code groupId} contains the value ordinal {@code v}. */
+        void addValue(int groupId, int v) {
+            /*
+             * Encode the groupId and value into a single long -
+             * the top 32 bits for the group, the bottom 32 for the value.
+             */
+            hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
+        }
+
+        /**
+         * Returns the value ordinal at position {@code index} of {@link #ids}.
+         * Only meaningful after {@link #prepareForEmitting} has built {@link #ids}.
+         */
+        int getValue(int index) {
+            long both = hashes.get(ids[index]);
+            return (int) (both & 0xFFFFFFFFL);
+        }
+
+        /** Reserves breaker memory for an {@code int[]} of {@code numElements} and remembers it for {@link #close()}. */
+        private void reserveBytesForIntArray(long numElements) {
+            long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES);
+            blockFactory.adjustBreaker(adjust);
+            extraMemoryUsed += adjust;
+        }
+
+        /**
+         * Builds {@link #selectedCounts} and {@link #ids} so the values of each selected group
+         * can be read as one contiguous run. Does nothing when the hash is empty, in which
+         * case both arrays stay {@code null}.
+         *
+         * @param selected the groups that are about to be emitted
+         */
+        private void prepareForEmitting(IntVector selected) {
+            if (hashes.size() == 0) {
+                return;
+            }
+            /*
+             * Get a count of all groups less than the maximum selected group. Count
+             * *downwards* so that we can flip the sign on all of the actually selected
+             * groups. Negative values in this array are always unselected groups.
+             */
+            int selectedCountsLen = selected.max() + 1;
+            reserveBytesForIntArray(selectedCountsLen);
+            this.selectedCounts = new int[selectedCountsLen];
+            for (int id = 0; id < hashes.size(); id++) {
+                long both = hashes.get(id);
+                int group = (int) (both >>> Integer.SIZE);
+                if (group < selectedCounts.length) {
+                    selectedCounts[group]--;
+                }
+            }
+
+            /*
+             * Total the selected groups and turn the counts into the start index into a sort-of
+             * off-by-one running count. It's really the number of values that have been inserted
+             * into the results before starting on this group. Unselected groups will still
+             * have negative counts.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
+             */
+            int total = 0;
+            for (int s = 0; s < selected.getPositionCount(); s++) {
+                int group = selected.getInt(s);
+                int count = -selectedCounts[group];
+                selectedCounts[group] = total;
+                total += count;
+            }
+
+            /*
+             * Build a list of ids to insert in order *and* convert the running
+             * count in selectedCounts[group] into the end index (exclusive) in
+             * ids for each group.
+             * Here we use the negative counts to signal that a group hasn't been
+             * selected and the id containing values for that group is ignored.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
+             * The counts will end with 3, 4, -2, 5, 9.
+             */
+            reserveBytesForIntArray(total);
+
+            this.ids = new int[total];
+            for (int id = 0; id < hashes.size(); id++) {
+                long both = hashes.get(id);
+                int group = (int) (both >>> Integer.SIZE);
+                // Ignore groups above the largest selected one and unselected groups (whose count is still negative);
+                // indexing with them would fall outside selectedCounts or ids.
+                if (group < selectedCounts.length && selectedCounts[group] >= 0) {
+                    ids[selectedCounts[group]++] = id;
+                }
+            }
+        }
+
         @Override
         public void close() {
-            releasable.close();
+            Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed));
         }
     }
 
     /**
      * State for a grouped {@code VALUES} aggregation. This implementation
-     * emphasizes collect-time performance over the performance of rendering
-     * results. That's good, but it's a pretty intensive emphasis, requiring
-     * an {@code O(n^2)} operation for collection to support a {@code O(1)}
-     * collector operation. But at least it's fairly simple.
+     * emphasizes collect-time performance over result rendering performance.
+     * The first value in each group is collected in the {@code firstValues}
+     * array, and subsequent values for each group are collected in {@code nextValues}.
      */
     public static class GroupingState implements GroupingAggregatorState {
-        private int maxGroupId = -1;
         private final BlockFactory blockFactory;
-        private final LongLongHash values;
         BytesRefHash bytes;
+        private IntArray firstValues;
+        private final NextValues nextValues;
 
         private GroupingState(DriverContext driverContext) {
             this.blockFactory = driverContext.blockFactory();
-            LongLongHash _values = null;
-            BytesRefHash _bytes = null;
+            boolean success = false; // set to true only after all three allocations below have succeeded
             try {
-                _values = new LongLongHash(1, driverContext.bigArrays());
-                _bytes = new BytesRefHash(1, driverContext.bigArrays());
-
-                values = _values;
-                bytes = _bytes;
-
-                _values = null;
-                _bytes = null;
+                this.bytes = new BytesRefHash(1, driverContext.bigArrays());
+                this.firstValues = driverContext.bigArrays().newIntArray(1, true); // slots hold ordinal + 1, see addValueOrdinal
+                this.nextValues = new NextValues(driverContext.blockFactory());
+                success = true;
             } finally {
-                Releasables.closeExpectNoException(_values, _bytes);
+                if (success == false) {
+                    this.close(); // an allocation threw: release whatever was already created
+                }
             }
         }
 
@@ -178,14 +284,28 @@ class ValuesBytesRefAggregator {
             blocks[offset] = toBlock(driverContext.blockFactory(), selected);
         }
 
-        void addValueOrdinal(int groupId, long valueOrdinal) {
-            values.add(groupId, valueOrdinal);
-            maxGroupId = Math.max(maxGroupId, groupId);
+        void addValueOrdinal(int groupId, int valueOrdinal) {
+            if (groupId < firstValues.size()) {
+                int current = firstValues.get(groupId) - 1;
+                if (current < 0) {
+                    firstValues.set(groupId, valueOrdinal + 1);
+                } else if (current != valueOrdinal) {
+                    nextValues.addValue(groupId, valueOrdinal);
+                }
+            } else {
+                firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1);
+                firstValues.set(groupId, valueOrdinal + 1);
+            }
         }
 
         void addValue(int groupId, BytesRef v) {
-            values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v)));
-            maxGroupId = Math.max(maxGroupId, groupId);
+            // Intern the bytes first, then collect the resulting ordinal for the group.
+            long ordinal = BlockHash.hashOrdToGroup(bytes.add(v));
+            addValueOrdinal(groupId, Math.toIntExact(ordinal));
+        }
+
+        @Override
+        public void enableGroupIdTracking(SeenGroupIds seen) {
+            // No-op: a group was seen iff its firstValues slot is non-zero, since ordinals are non-negative and stored as ordinal + 1
         }
 
         /**
@@ -193,159 +313,81 @@ class ValuesBytesRefAggregator {
          * groups. This is the implementation of the final and intermediate results of the agg.
          */
         Block toBlock(BlockFactory blockFactory, IntVector selected) {
-            if (values.size() == 0) {
-                return blockFactory.newConstantNullBlock(selected.getPositionCount());
-            }
-
-            try (var sorted = buildSorted(selected)) {
-                if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) {
-                    return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
-                } else {
-                    return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
-                }
+            // Lay out the hashed values per group before reading them back.
+            nextValues.prepareForEmitting(selected);
+            final long collected = firstValues.size() + nextValues.hashes.size();
+            if (OrdinalBytesRefBlock.isDense(collected, bytes.size()) == false) {
+                return buildOutputBlock(blockFactory, selected);
             }
+            return buildOrdinalOutputBlock(blockFactory, selected);
         }
 
-        private Sorted buildSorted(IntVector selected) {
-            long selectedCountsSize = 0;
-            long idsSize = 0;
-            Sorted sorted = null;
-            try {
-                /*
-                 * Get a count of all groups less than the maximum selected group. Count
-                 * *downwards* so that we can flip the sign on all of the actually selected
-                 * groups. Negative values in this array are always unselected groups.
-                 */
-                int selectedCountsLen = selected.max() + 1;
-                long adjust = RamUsageEstimator.alignObjectSize(
-                    RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES
-                );
-                blockFactory.adjustBreaker(adjust);
-                selectedCountsSize = adjust;
-                int[] selectedCounts = new int[selectedCountsLen];
-                for (int id = 0; id < values.size(); id++) {
-                    int group = (int) values.getKey1(id);
-                    if (group < selectedCounts.length) {
-                        selectedCounts[group]--;
-                    }
-                }
-
-                /*
-                 * Total the selected groups and turn the counts into the start index into a sort-of
-                 * off-by-one running count. It's really the number of values that have been inserted
-                 * into the results before starting on this group. Unselected groups will still
-                 * have negative counts.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
-                 */
-                int total = 0;
-                for (int s = 0; s < selected.getPositionCount(); s++) {
-                    int group = selected.getInt(s);
-                    int count = -selectedCounts[group];
-                    selectedCounts[group] = total;
-                    total += count;
-                }
-
-                /*
-                 * Build a list of ids to insert in order *and* convert the running
-                 * count in selectedCounts[group] into the end index (exclusive) in
-                 * ids for each group.
-                 * Here we use the negative counts to signal that a group hasn't been
-                 * selected and the id containing values for that group is ignored.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
-                 * The counts will end with 3, 4, -2, 5, 9.
-                 */
-                adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES);
-                blockFactory.adjustBreaker(adjust);
-                idsSize = adjust;
-                int[] ids = new int[total];
-                for (int id = 0; id < values.size(); id++) {
-                    int group = (int) values.getKey1(id);
-                    if (group < selectedCounts.length && selectedCounts[group] >= 0) {
-                        ids[selectedCounts[group]++] = id;
-                    }
-                }
-                final long totalMemoryUsed = selectedCountsSize + idsSize;
-                sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
-                return sorted;
-            } finally {
-                if (sorted == null) {
-                    blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
-                }
-            }
-        }
-
-        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) {
             /*
              * Insert the ids in order.
              */
             BytesRef scratch = new BytesRef();
+            final int[] nextValueCounts = nextValues.selectedCounts; // still null when prepareForEmitting found the hash empty
             try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+                int nextValuesStart = 0; // start of the current group's run in nextValues
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.appendBytesRef(getValue(ids[start], scratch));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.appendBytesRef(getValue(ids[i], scratch));
-                            }
-                            builder.endPositionEntry();
+                    int firstValue = group >= firstValues.size() ? -1 : firstValues.get(group) - 1; // negative: nothing collected for this group
+                    if (firstValue < 0) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart; // exclusive end of the group's run
+                    if (nextValuesEnd == nextValuesStart) {
+                        builder.appendBytesRef(bytes.get(firstValue, scratch)); // single-valued group: only the first value
+                    } else {
+                        builder.beginPositionEntry();
+                        builder.appendBytesRef(bytes.get(firstValue, scratch));
+                        // append values from the nextValues
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            var nextValue = nextValues.getValue(i);
+                            builder.appendBytesRef(bytes.get(nextValue, scratch));
                         }
+                        builder.endPositionEntry();
+                        nextValuesStart = nextValuesEnd; // the next group's run begins where this one ended
                     }
-                    start = end;
                 }
                 return builder.build();
             }
         }
 
-        Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected) {
             BytesRefVector dict = null;
             IntBlock ordinals = null;
             BytesRefBlock result = null;
             var dictArray = bytes.takeBytesRefsOwnership();
             bytes = null; // transfer ownership to dictArray
-            try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+            int estimateSize = Math.toIntExact(firstValues.size() + nextValues.hashes.size());
+            final int[] nextValueCounts = nextValues.selectedCounts;
+            try (var builder = blockFactory.newIntBlockBuilder(estimateSize)) {
+                int nextValuesStart = 0;
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start])));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.appendInt(Math.toIntExact(values.getKey2(ids[i])));
-                            }
-                            builder.endPositionEntry();
+                    if (firstValues.size() < group) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    int firstValue = firstValues.get(group) - 1;
+                    if (firstValue < 0) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart;
+                    if (nextValuesEnd == nextValuesStart) {
+                        builder.appendInt(firstValue);
+                    } else {
+                        builder.beginPositionEntry();
+                        builder.appendInt(firstValue);
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            builder.appendInt(nextValues.getValue(i));
                         }
+                        builder.endPositionEntry();
                     }
-                    start = end;
+                    nextValuesStart = nextValuesEnd;
                 }
                 ordinals = builder.build();
                 dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size()));
@@ -359,18 +401,9 @@ class ValuesBytesRefAggregator {
             }
         }
 
-        BytesRef getValue(int valueId, BytesRef scratch) {
-            return bytes.get(values.getKey2(valueId), scratch);
-        }
-
-        @Override
-        public void enableGroupIdTracking(SeenGroupIds seen) {
-            // we figure out seen values from nulls on the values block
-        }
-
         @Override
         public void close() {
-            Releasables.closeExpectNoException(values, bytes);
+            Releasables.closeExpectNoException(bytes, firstValues, nextValues); // bytes is null once buildOrdinalOutputBlock has taken it over
         }
     }
 }

+ 192 - 135
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java

@@ -7,20 +7,32 @@
 
 package org.elasticsearch.compute.aggregation;
 
+// begin generated imports
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.common.util.DoubleArray;
+import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.ann.Aggregator;
 import org.elasticsearch.compute.ann.GroupingAggregator;
 import org.elasticsearch.compute.ann.IntermediateState;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+// end generated imports
 
 /**
  * Aggregates field values for double.
@@ -106,34 +118,140 @@ class ValuesDoubleAggregator {
     }
 
     /**
-     * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
-     * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
-     * and then use it to iterate over the values in order.
-     *
-     * @param ids positions of the {@link GroupingState#values} to read.
+     * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value.
+     * When emitting the output, we need to iterate the hash one group at a time to build the output block,
+     * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id
+     * to an array, allowing us to build the output in O(N) instead.
      */
-    private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
+    private static class NextValues implements Releasable {
+        private final BlockFactory blockFactory;
+        /** One entry per (group, value) pair: key1 is the group id, key2 the value's {@code doubleToLongBits}. */
+        private final LongLongHash hashes;
+        /** Built by {@link #prepareForEmitting}: for each selected group, the exclusive end of its run in {@link #ids}. */
+        private int[] selectedCounts = null;
+        /** Built by {@link #prepareForEmitting}: ids into {@link #hashes}, laid out group by group in selection order. */
+        private int[] ids = null;
+        /** Bytes reserved on the breaker for the two arrays above; handed back in {@link #close()}. */
+        private long extraMemoryUsed = 0;
+
+        private NextValues(BlockFactory blockFactory) {
+            this.blockFactory = blockFactory;
+            this.hashes = new LongLongHash(1, blockFactory.bigArrays());
+        }
+
+        /** Records that {@code groupId} contains the value {@code v}. */
+        void addValue(int groupId, double v) {
+            hashes.add(groupId, Double.doubleToLongBits(v));
+        }
+
+        /**
+         * Returns the value at position {@code index} of {@link #ids}.
+         * Only meaningful after {@link #prepareForEmitting} has built {@link #ids}.
+         */
+        double getValue(int index) {
+            return Double.longBitsToDouble(hashes.getKey2(ids[index]));
+        }
+
+        /** Reserves breaker memory for an {@code int[]} of {@code numElements} and remembers it for {@link #close()}. */
+        private void reserveBytesForIntArray(long numElements) {
+            long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES);
+            blockFactory.adjustBreaker(adjust);
+            extraMemoryUsed += adjust;
+        }
+
+        /**
+         * Builds {@link #selectedCounts} and {@link #ids} so the values of each selected group
+         * can be read as one contiguous run. Does nothing when the hash is empty, in which
+         * case both arrays stay {@code null}.
+         *
+         * @param selected the groups that are about to be emitted
+         */
+        private void prepareForEmitting(IntVector selected) {
+            if (hashes.size() == 0) {
+                return;
+            }
+            /*
+             * Get a count of all groups less than the maximum selected group. Count
+             * *downwards* so that we can flip the sign on all of the actually selected
+             * groups. Negative values in this array are always unselected groups.
+             */
+            int selectedCountsLen = selected.max() + 1;
+            reserveBytesForIntArray(selectedCountsLen);
+            this.selectedCounts = new int[selectedCountsLen];
+            for (int id = 0; id < hashes.size(); id++) {
+                int group = (int) hashes.getKey1(id);
+                if (group < selectedCounts.length) {
+                    selectedCounts[group]--;
+                }
+            }
+
+            /*
+             * Total the selected groups and turn the counts into the start index into a sort-of
+             * off-by-one running count. It's really the number of values that have been inserted
+             * into the results before starting on this group. Unselected groups will still
+             * have negative counts.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
+             */
+            int total = 0;
+            for (int s = 0; s < selected.getPositionCount(); s++) {
+                int group = selected.getInt(s);
+                int count = -selectedCounts[group];
+                selectedCounts[group] = total;
+                total += count;
+            }
+
+            /*
+             * Build a list of ids to insert in order *and* convert the running
+             * count in selectedCounts[group] into the end index (exclusive) in
+             * ids for each group.
+             * Here we use the negative counts to signal that a group hasn't been
+             * selected and the id containing values for that group is ignored.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
+             * The counts will end with 3, 4, -2, 5, 9.
+             */
+            reserveBytesForIntArray(total);
+
+            this.ids = new int[total];
+            for (int id = 0; id < hashes.size(); id++) {
+                int group = (int) hashes.getKey1(id);
+                // Ignore groups above the largest selected one and unselected groups (whose count is still negative);
+                // indexing with them would fall outside selectedCounts or ids.
+                if (group < selectedCounts.length && selectedCounts[group] >= 0) {
+                    ids[selectedCounts[group]++] = id;
+                }
+            }
+        }
+
         @Override
         public void close() {
-            releasable.close();
+            Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed));
         }
     }
 
     /**
      * State for a grouped {@code VALUES} aggregation. This implementation
-     * emphasizes collect-time performance over the performance of rendering
-     * results. That's good, but it's a pretty intensive emphasis, requiring
-     * an {@code O(n^2)} operation for collection to support a {@code O(1)}
-     * collector operation. But at least it's fairly simple.
+     * emphasizes collect-time performance over result rendering performance.
+     * The first value in each group is collected in the {@code firstValues}
+     * array, and subsequent values for each group are collected in {@code nextValues}.
      */
     public static class GroupingState implements GroupingAggregatorState {
-        private int maxGroupId = -1;
         private final BlockFactory blockFactory;
-        private final LongLongHash values;
+        DoubleArray firstValues;
+        private BitArray seen;
+        private int maxGroupId = -1;
+        private final NextValues nextValues;
 
         private GroupingState(DriverContext driverContext) {
             this.blockFactory = driverContext.blockFactory();
-            values = new LongLongHash(1, driverContext.bigArrays());
+            boolean success = false; // set to true only after both allocations below have succeeded
+            try {
+                this.firstValues = driverContext.bigArrays().newDoubleArray(1, false);
+                this.nextValues = new NextValues(driverContext.blockFactory());
+                success = true;
+            } finally {
+                if (success == false) {
+                    this.close(); // an allocation threw: release whatever was already created
+                }
+            }
         }
 
         @Override
@@ -142,151 +260,90 @@ class ValuesDoubleAggregator {
         }
 
         void addValue(int groupId, double v) {
-            values.add(groupId, Double.doubleToLongBits(v));
-            maxGroupId = Math.max(maxGroupId, groupId);
+            // Fast path: the first value of a group is stored directly in firstValues; only values that differ
+            // from it are forwarded to nextValues.
+            if (groupId > maxGroupId) {
+                firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1);
+                firstValues.set(groupId, v);
+                // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating
+                // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset,
+                // fill the groups that have values, and begin tracking incoming groups.
+                if (seen == null && groupId > maxGroupId + 1) {
+                    seen = new BitArray(groupId + 1, blockFactory.bigArrays());
+                    seen.fill(0, maxGroupId + 1, true);
+                }
+                trackGroupId(groupId);
+                maxGroupId = groupId;
+            } else if (hasValue(groupId) == false) {
+                // The slot exists but was never written: this is the group's first value.
+                firstValues.set(groupId, v);
+                trackGroupId(groupId);
+            } else if (Double.doubleToLongBits(firstValues.get(groupId)) != Double.doubleToLongBits(v)) {
+                // Compare bit patterns, as the previous hash-keyed implementation did. With a plain != a NaN first
+                // value never equals an incoming NaN (so NaN would be collected twice for the group), and -0.0
+                // would be silently folded into a 0.0 first value.
+                nextValues.addValue(groupId, v);
+            }
         }
 
-        /**
-         * Builds a {@link Block} with the unique values collected for the {@code #selected}
-         * groups. This is the implementation of the final and intermediate results of the agg.
-         */
-        Block toBlock(BlockFactory blockFactory, IntVector selected) {
-            if (values.size() == 0) {
-                return blockFactory.newConstantNullBlock(selected.getPositionCount());
-            }
+        @Override
+        public void enableGroupIdTracking(SeenGroupIds seen) {
+            // Intentionally a no-op: the argument is ignored because this state records which groups have
+            // values by itself, through its own seen bitset (see trackGroupId and hasValue).
+        }
 
-            try (var sorted = buildSorted(selected)) {
-                return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
+        private void trackGroupId(int groupId) {
+            // While seen is still null the state is in untracked mode and there is nothing to record.
+            if (seen != null) {
+                seen.set(groupId);
             }
         }
 
-        private Sorted buildSorted(IntVector selected) {
-            long selectedCountsSize = 0;
-            long idsSize = 0;
-            Sorted sorted = null;
-            try {
-                /*
-                 * Get a count of all groups less than the maximum selected group. Count
-                 * *downwards* so that we can flip the sign on all of the actually selected
-                 * groups. Negative values in this array are always unselected groups.
-                 */
-                int selectedCountsLen = selected.max() + 1;
-                long adjust = RamUsageEstimator.alignObjectSize(
-                    RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES
-                );
-                blockFactory.adjustBreaker(adjust);
-                selectedCountsSize = adjust;
-                int[] selectedCounts = new int[selectedCountsLen];
-                for (int id = 0; id < values.size(); id++) {
-                    int group = (int) values.getKey1(id);
-                    if (group < selectedCounts.length) {
-                        selectedCounts[group]--;
-                    }
-                }
-
-                /*
-                 * Total the selected groups and turn the counts into the start index into a sort-of
-                 * off-by-one running count. It's really the number of values that have been inserted
-                 * into the results before starting on this group. Unselected groups will still
-                 * have negative counts.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
-                 */
-                int total = 0;
-                for (int s = 0; s < selected.getPositionCount(); s++) {
-                    int group = selected.getInt(s);
-                    int count = -selectedCounts[group];
-                    selectedCounts[group] = total;
-                    total += count;
-                }
+        /**
+         * Tells whether {@code groupId} has collected at least one value, i.e. whether its slot in
+         * {@code firstValues} has been written. Values in {@code nextValues} only ever exist for such groups.
+         * While the {@code seen} bitset has not been allocated, every group is reported as having a value.
+         */
+        private boolean hasValue(int groupId) {
+            if (seen == null) {
+                return true;
+            }
+            return seen.get(groupId);
+        }
 
-                /*
-                 * Build a list of ids to insert in order *and* convert the running
-                 * count in selectedCounts[group] into the end index (exclusive) in
-                 * ids for each group.
-                 * Here we use the negative counts to signal that a group hasn't been
-                 * selected and the id containing values for that group is ignored.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
-                 * The counts will end with 3, 4, -2, 5, 9.
-                 */
-                adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES);
-                blockFactory.adjustBreaker(adjust);
-                idsSize = adjust;
-                int[] ids = new int[total];
-                for (int id = 0; id < values.size(); id++) {
-                    int group = (int) values.getKey1(id);
-                    if (group < selectedCounts.length && selectedCounts[group] >= 0) {
-                        ids[selectedCounts[group]++] = id;
-                    }
-                }
-                final long totalMemoryUsed = selectedCountsSize + idsSize;
-                sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
-                return sorted;
-            } finally {
-                if (sorted == null) {
-                    blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
-                }
-            }
+        /**
+         * Builds a {@link Block} with the unique values collected for the {@code selected}
+         * groups. This is the implementation of the final and intermediate results of the agg.
+         *
+         * @param blockFactory factory used to create the output block builder
+         * @param selected the group ids to emit, one output position per entry
+         * @return a block with one position per selected group
+         */
+        Block toBlock(BlockFactory blockFactory, IntVector selected) {
+            // Let nextValues set up its per-group bookkeeping for this selection before the block is built.
+            nextValues.prepareForEmitting(selected);
+            return buildOutputBlock(blockFactory, selected);
         }
 
-        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) {
             /*
              * Insert the ids in order.
              */
+            // When this is null no per-group counts are available and each group emits only its first value.
+            final int[] nextValueCounts = nextValues.selectedCounts;
             try (DoubleBlock.Builder builder = blockFactory.newDoubleBlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+                // Index of the next unread entry in nextValues; the runs of the selected groups are read back to back.
+                int nextValuesStart = 0;
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.appendDouble(getValue(ids[start]));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.appendDouble(getValue(ids[i]));
-                            }
-                            builder.endPositionEntry();
+                    // A group that never received a value becomes a null position.
+                    if (group > maxGroupId || hasValue(group) == false) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    double firstValue = firstValues.get(group);
+                    // nextValueCounts[group] is read as the exclusive end of this group's run of additional values.
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart;
+                    if (nextValuesEnd == nextValuesStart) {
+                        // Only the first value exists: emit a single-valued position.
+                        builder.appendDouble(firstValue);
+                    } else {
+                        // Multi-valued position: the first value followed by the values held in nextValues.
+                        builder.beginPositionEntry();
+                        builder.appendDouble(firstValue);
+                        // append values from the nextValues
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            var nextValue = nextValues.getValue(i);
+                            builder.appendDouble(nextValue);
                         }
+                        builder.endPositionEntry();
+                        nextValuesStart = nextValuesEnd;
                     }
-                    start = end;
                 }
                 return builder.build();
             }
         }
 
-        double getValue(int valueId) {
-            return Double.longBitsToDouble(values.getKey2(valueId));
-        }
-
-        @Override
-        public void enableGroupIdTracking(SeenGroupIds seen) {
-            // we figure out seen values from nulls on the values block
-        }
-
         @Override
         public void close() {
-            Releasables.closeExpectNoException(values);
+            // seen is null if group tracking never started, and this method is also called from the constructor's
+            // failure path. NOTE(review): relies on closeExpectNoException ignoring null arguments - confirm.
+            Releasables.closeExpectNoException(seen, firstValues, nextValues);
         }
     }
 }

+ 200 - 142
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java

@@ -7,19 +7,32 @@
 
 package org.elasticsearch.compute.aggregation;
 
+// begin generated imports
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.LongHash;
+import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.common.util.FloatArray;
+import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.ann.Aggregator;
 import org.elasticsearch.compute.ann.GroupingAggregator;
 import org.elasticsearch.compute.ann.IntermediateState;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+// end generated imports
 
 /**
  * Aggregates field values for float.
@@ -105,34 +118,147 @@ class ValuesFloatAggregator {
     }
 
     /**
-     * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
-     * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
-     * and then use it to iterate over the values in order.
-     *
-     * @param ids positions of the {@link GroupingState#values} to read.
+     * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value.
+     * When emitting the output, we need to iterate the hash one group at a time to build the output block,
+     * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id
+     * to an array, allowing us to build the output in O(N) instead.
      */
-    private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
+    private static class NextValues implements Releasable {
+        private final BlockFactory blockFactory;
+        /** Hash of packed (groupId, value) keys; see {@link #addValue}. */
+        private final LongHash hashes;
+        /** Per-group bookkeeping built by {@link #prepareForEmitting}; null until counts have been built. */
+        private int[] selectedCounts = null;
+        /** Hash ids ordered group by group, built by {@link #prepareForEmitting}. */
+        private int[] ids = null;
+        /** Bytes reserved on the breaker for the int arrays above; given back in {@link #close}. */
+        private long extraMemoryUsed = 0;
+
+        private NextValues(BlockFactory blockFactory) {
+            this.blockFactory = blockFactory;
+            this.hashes = new LongHash(1, blockFactory.bigArrays());
+        }
+
+        void addValue(int groupId, float v) {
+            /*
+             * Encode the groupId and value into a single long -
+             * the top 32 bits for the group, the bottom 32 for the value.
+             */
+            hashes.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
+        }
+
+        /**
+         * Returns the value stored at position {@code index} of the ordering built by {@link #prepareForEmitting}.
+         */
+        float getValue(int index) {
+            long both = hashes.get(ids[index]);
+            return Float.intBitsToFloat((int) (both & 0xFFFFFFFFL));
+        }
+
+        /** Reserves breaker memory for an {@code int[]} of {@code numElements} and remembers it for {@link #close}. */
+        private void reserveBytesForIntArray(long numElements) {
+            long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES);
+            blockFactory.adjustBreaker(adjust);
+            extraMemoryUsed += adjust;
+        }
+
+        /**
+         * Builds {@code selectedCounts} and {@code ids} for the {@code selected} groups so that the values of each
+         * selected group can be read as one contiguous run. Does nothing when no value has been collected.
+         */
+        private void prepareForEmitting(IntVector selected) {
+            if (hashes.size() == 0) {
+                return;
+            }
+            /*
+             * Get a count of all groups less than the maximum selected group. Count
+             * *downwards* so that we can flip the sign on all of the actually selected
+             * groups. Negative values in this array are always unselected groups.
+             */
+            int selectedCountsLen = selected.max() + 1;
+            reserveBytesForIntArray(selectedCountsLen);
+            this.selectedCounts = new int[selectedCountsLen];
+            for (int id = 0; id < hashes.size(); id++) {
+                long both = hashes.get(id);
+                int group = (int) (both >>> Float.SIZE);
+                if (group < selectedCounts.length) {
+                    selectedCounts[group]--;
+                }
+            }
+
+            /*
+             * Total the selected groups and turn the counts into the start index into a sort-of
+             * off-by-one running count. It's really the number of values that have been inserted
+             * into the results before starting on this group. Unselected groups will still
+             * have negative counts.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
+             */
+            int total = 0;
+            for (int s = 0; s < selected.getPositionCount(); s++) {
+                int group = selected.getInt(s);
+                int count = -selectedCounts[group];
+                selectedCounts[group] = total;
+                total += count;
+            }
+
+            /*
+             * Build a list of ids to insert in order *and* convert the running
+             * count in selectedCounts[group] into the end index (exclusive) in
+             * ids for each group.
+             * Here we use the negative counts to signal that a group hasn't been
+             * selected and the id containing values for that group is ignored.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
+             * The counts will end with 3, 4, -2, 5, 9.
+             */
+            reserveBytesForIntArray(total);
+
+            this.ids = new int[total];
+            for (int id = 0; id < hashes.size(); id++) {
+                long both = hashes.get(id);
+                int group = (int) (both >>> Float.SIZE);
+                // Skip groups beyond the largest selected group and groups that were not selected (negative
+                // count); without this guard either case indexes outside selectedCounts or ids.
+                if (group < selectedCounts.length && selectedCounts[group] >= 0) {
+                    ids[selectedCounts[group]++] = id;
+                }
+            }
+        }
+
         @Override
         public void close() {
-            releasable.close();
+            // Release the hash and give back the breaker bytes reserved for selectedCounts and ids.
+            Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed));
         }
     }
 
     /**
      * State for a grouped {@code VALUES} aggregation. This implementation
-     * emphasizes collect-time performance over the performance of rendering
-     * results. That's good, but it's a pretty intensive emphasis, requiring
-     * an {@code O(n^2)} operation for collection to support a {@code O(1)}
-     * collector operation. But at least it's fairly simple.
+     * emphasizes collect-time performance over result rendering performance.
+     * The first value in each group is collected in the {@code firstValues}
+     * array, and subsequent values for each group are collected in {@code nextValues}.
      */
     public static class GroupingState implements GroupingAggregatorState {
-        private int maxGroupId = -1;
         private final BlockFactory blockFactory;
-        private final LongHash values;
+        FloatArray firstValues;
+        private BitArray seen;
+        private int maxGroupId = -1;
+        private final NextValues nextValues;
 
         private GroupingState(DriverContext driverContext) {
             this.blockFactory = driverContext.blockFactory();
-            values = new LongHash(1, driverContext.bigArrays());
+            // Allocate both stores under a success flag so that a failure part-way through (for example while
+            // creating nextValues) closes this state instead of leaving firstValues allocated.
+            boolean success = false;
+            try {
+                this.firstValues = driverContext.bigArrays().newFloatArray(1, false);
+                this.nextValues = new NextValues(driverContext.blockFactory());
+                success = true;
+            } finally {
+                if (success == false) {
+                    // NOTE(review): close() hands over fields that may still be null here; this relies on
+                    // Releasables.closeExpectNoException ignoring nulls - confirm.
+                    this.close();
+                }
+            }
         }
 
         @Override
@@ -141,158 +267,90 @@ class ValuesFloatAggregator {
         }
 
         void addValue(int groupId, float v) {
-            /*
-             * Encode the groupId and value into a single long -
-             * the top 32 bits for the group, the bottom 32 for the value.
-             */
-            values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
-            maxGroupId = Math.max(maxGroupId, groupId);
+            // Fast path: the first value of a group is stored directly in firstValues; only values that differ
+            // from it are forwarded to nextValues.
+            if (groupId > maxGroupId) {
+                firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1);
+                firstValues.set(groupId, v);
+                // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating
+                // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset,
+                // fill the groups that have values, and begin tracking incoming groups.
+                if (seen == null && groupId > maxGroupId + 1) {
+                    seen = new BitArray(groupId + 1, blockFactory.bigArrays());
+                    seen.fill(0, maxGroupId + 1, true);
+                }
+                trackGroupId(groupId);
+                maxGroupId = groupId;
+            } else if (hasValue(groupId) == false) {
+                // The slot exists but was never written: this is the group's first value.
+                firstValues.set(groupId, v);
+                trackGroupId(groupId);
+            } else if (Float.floatToIntBits(firstValues.get(groupId)) != Float.floatToIntBits(v)) {
+                // Compare bit patterns, the same encoding NextValues uses for its hash keys. With a plain != a NaN
+                // first value never equals an incoming NaN (so NaN would be collected twice for the group), and
+                // -0.0 would be silently folded into a 0.0 first value.
+                nextValues.addValue(groupId, v);
+            }
         }
 
-        /**
-         * Builds a {@link Block} with the unique values collected for the {@code #selected}
-         * groups. This is the implementation of the final and intermediate results of the agg.
-         */
-        Block toBlock(BlockFactory blockFactory, IntVector selected) {
-            if (values.size() == 0) {
-                return blockFactory.newConstantNullBlock(selected.getPositionCount());
-            }
+        @Override
+        public void enableGroupIdTracking(SeenGroupIds seen) {
+            // Intentionally a no-op: the argument is ignored because this state records which groups have
+            // values by itself, through its own seen bitset (see trackGroupId and hasValue).
+        }
 
-            try (var sorted = buildSorted(selected)) {
-                return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
+        private void trackGroupId(int groupId) {
+            // While seen is still null the state is in untracked mode and there is nothing to record.
+            if (seen != null) {
+                seen.set(groupId);
             }
         }
 
-        private Sorted buildSorted(IntVector selected) {
-            long selectedCountsSize = 0;
-            long idsSize = 0;
-            Sorted sorted = null;
-            try {
-                /*
-                 * Get a count of all groups less than the maximum selected group. Count
-                 * *downwards* so that we can flip the sign on all of the actually selected
-                 * groups. Negative values in this array are always unselected groups.
-                 */
-                int selectedCountsLen = selected.max() + 1;
-                long adjust = RamUsageEstimator.alignObjectSize(
-                    RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES
-                );
-                blockFactory.adjustBreaker(adjust);
-                selectedCountsSize = adjust;
-                int[] selectedCounts = new int[selectedCountsLen];
-                for (int id = 0; id < values.size(); id++) {
-                    long both = values.get(id);
-                    int group = (int) (both >>> Float.SIZE);
-                    if (group < selectedCounts.length) {
-                        selectedCounts[group]--;
-                    }
-                }
-
-                /*
-                 * Total the selected groups and turn the counts into the start index into a sort-of
-                 * off-by-one running count. It's really the number of values that have been inserted
-                 * into the results before starting on this group. Unselected groups will still
-                 * have negative counts.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
-                 */
-                int total = 0;
-                for (int s = 0; s < selected.getPositionCount(); s++) {
-                    int group = selected.getInt(s);
-                    int count = -selectedCounts[group];
-                    selectedCounts[group] = total;
-                    total += count;
-                }
+        /**
+         * Tells whether {@code groupId} has collected at least one value, i.e. whether its slot in
+         * {@code firstValues} has been written. Values in {@code nextValues} only ever exist for such groups.
+         * While the {@code seen} bitset has not been allocated, every group is reported as having a value.
+         */
+        private boolean hasValue(int groupId) {
+            if (seen == null) {
+                return true;
+            }
+            return seen.get(groupId);
+        }
 
-                /*
-                 * Build a list of ids to insert in order *and* convert the running
-                 * count in selectedCounts[group] into the end index (exclusive) in
-                 * ids for each group.
-                 * Here we use the negative counts to signal that a group hasn't been
-                 * selected and the id containing values for that group is ignored.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
-                 * The counts will end with 3, 4, -2, 5, 9.
-                 */
-                adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES);
-                blockFactory.adjustBreaker(adjust);
-                idsSize = adjust;
-                int[] ids = new int[total];
-                for (int id = 0; id < values.size(); id++) {
-                    long both = values.get(id);
-                    int group = (int) (both >>> Float.SIZE);
-                    if (group < selectedCounts.length && selectedCounts[group] >= 0) {
-                        ids[selectedCounts[group]++] = id;
-                    }
-                }
-                final long totalMemoryUsed = selectedCountsSize + idsSize;
-                sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
-                return sorted;
-            } finally {
-                if (sorted == null) {
-                    blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
-                }
-            }
+        /**
+         * Builds a {@link Block} with the unique values collected for the {@code selected}
+         * groups. This is the implementation of the final and intermediate results of the agg.
+         *
+         * @param blockFactory factory used to create the output block builder
+         * @param selected the group ids to emit, one output position per entry
+         * @return a block with one position per selected group
+         */
+        Block toBlock(BlockFactory blockFactory, IntVector selected) {
+            // Builds the per-group counts and ordered ids in nextValues (a no-op when its hash is empty).
+            nextValues.prepareForEmitting(selected);
+            return buildOutputBlock(blockFactory, selected);
         }
 
-        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) {
             /*
              * Insert the ids in order.
              */
+            // When this is null no per-group counts are available and each group emits only its first value.
+            final int[] nextValueCounts = nextValues.selectedCounts;
             try (FloatBlock.Builder builder = blockFactory.newFloatBlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+                // Index of the next unread entry in nextValues; the runs of the selected groups are read back to back.
+                int nextValuesStart = 0;
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.appendFloat(getValue(ids[start]));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.appendFloat(getValue(ids[i]));
-                            }
-                            builder.endPositionEntry();
+                    // A group that never received a value becomes a null position.
+                    if (group > maxGroupId || hasValue(group) == false) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    float firstValue = firstValues.get(group);
+                    // nextValueCounts[group] is read as the exclusive end of this group's run of additional values.
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart;
+                    if (nextValuesEnd == nextValuesStart) {
+                        // Only the first value exists: emit a single-valued position.
+                        builder.appendFloat(firstValue);
+                    } else {
+                        // Multi-valued position: the first value followed by the values held in nextValues.
+                        builder.beginPositionEntry();
+                        builder.appendFloat(firstValue);
+                        // append values from the nextValues
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            var nextValue = nextValues.getValue(i);
+                            builder.appendFloat(nextValue);
                         }
+                        builder.endPositionEntry();
+                        nextValuesStart = nextValuesEnd;
                     }
-                    start = end;
                 }
                 return builder.build();
             }
         }
 
-        float getValue(int valueId) {
-            long both = values.get(valueId);
-            return Float.intBitsToFloat((int) both);
-        }
-
-        @Override
-        public void enableGroupIdTracking(SeenGroupIds seen) {
-            // we figure out seen values from nulls on the values block
-        }
-
         @Override
         public void close() {
-            Releasables.closeExpectNoException(values);
+            // seen is null if group tracking never started, and this method is also called from the constructor's
+            // failure path. NOTE(review): relies on closeExpectNoException ignoring null arguments - confirm.
+            Releasables.closeExpectNoException(seen, firstValues, nextValues);
         }
     }
 }

+ 200 - 142
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java

@@ -7,19 +7,32 @@
 
 package org.elasticsearch.compute.aggregation;
 
+// begin generated imports
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.LongHash;
+import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.common.util.IntArray;
+import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.ann.Aggregator;
 import org.elasticsearch.compute.ann.GroupingAggregator;
 import org.elasticsearch.compute.ann.IntermediateState;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+// end generated imports
 
 /**
  * Aggregates field values for int.
@@ -105,34 +118,147 @@ class ValuesIntAggregator {
     }
 
     /**
-     * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
-     * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
-     * and then use it to iterate over the values in order.
-     *
-     * @param ids positions of the {@link GroupingState#values} to read.
+     * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value.
+     * When emitting the output, we need to iterate the hash one group at a time to build the output block,
+     * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id
+     * to an array, allowing us to build the output in O(N) instead.
      */
-    private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
+    private static class NextValues implements Releasable {
+        private final BlockFactory blockFactory;
+        private final LongHash hashes;
+        private int[] selectedCounts = null;
+        private int[] ids = null;
+        private long extraMemoryUsed = 0;
+
+        private NextValues(BlockFactory blockFactory) {
+            this.blockFactory = blockFactory;
+            this.hashes = new LongHash(1, blockFactory.bigArrays());
+        }
+
+        void addValue(int groupId, int v) {
+            /*
+             * Encode the groupId and value into a single long -
+             * the top 32 bits for the group, the bottom 32 for the value.
+             */
+            hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
+        }
+
+        int getValue(int index) {
+            long both = hashes.get(ids[index]);
+            return (int) (both & 0xFFFFFFFFL);
+        }
+
+        private void reserveBytesForIntArray(long numElements) {
+            long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES);
+            blockFactory.adjustBreaker(adjust);
+            extraMemoryUsed += adjust;
+        }
+
+        /**
+         * Prepares {@link #selectedCounts} and {@link #ids} so that the values stored in
+         * {@code hashes} can be read back one selected group at a time, in the order the
+         * groups appear in {@code selected}. Does nothing when the hash is empty, leaving
+         * both arrays {@code null}.
+         * <p>
+         * The memory for both arrays is reserved on the breaker through
+         * {@link #reserveBytesForIntArray} and is accounted in {@code extraMemoryUsed}.
+         *
+         * @param selected the group ids that are going to be emitted
+         */
+        private void prepareForEmitting(IntVector selected) {
+            if (hashes.size() == 0) {
+                return;
+            }
+            /*
+             * Get a count of all groups less than the maximum selected group. Count
+             * *downwards* so that we can flip the sign on all of the actually selected
+             * groups. Negative values in this array are always unselected groups.
+             */
+            int selectedCountsLen = selected.max() + 1;
+            reserveBytesForIntArray(selectedCountsLen);
+            this.selectedCounts = new int[selectedCountsLen];
+            for (int id = 0; id < hashes.size(); id++) {
+                long both = hashes.get(id);
+                int group = (int) (both >>> Float.SIZE);
+                if (group < selectedCounts.length) {
+                    selectedCounts[group]--;
+                }
+            }
+
+            /*
+             * Total the selected groups and turn the counts into the start index into a sort-of
+             * off-by-one running count. It's really the number of values that have been inserted
+             * into the results before starting on this group. Unselected groups will still
+             * have negative counts.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
+             */
+            int total = 0;
+            for (int s = 0; s < selected.getPositionCount(); s++) {
+                int group = selected.getInt(s);
+                int count = -selectedCounts[group];
+                selectedCounts[group] = total;
+                total += count;
+            }
+
+            /*
+             * Build a list of ids to insert in order *and* convert the running
+             * count in selectedCounts[group] into the end index (exclusive) in
+             * ids for each group.
+             * Here we use the negative counts to signal that a group hasn't been
+             * selected and the id containing values for that group is ignored.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
+             * The counts will end with 3, 4, -2, 5, 9.
+             */
+            reserveBytesForIntArray(total);
+
+            this.ids = new int[total];
+            for (int id = 0; id < hashes.size(); id++) {
+                long both = hashes.get(id);
+                int group = (int) (both >>> Float.SIZE);
+                /*
+                 * Skip ids that belong to a group above the largest selected group or to
+                 * an unselected group (negative count). Using either as an index would
+                 * throw ArrayIndexOutOfBoundsException.
+                 */
+                if (group < selectedCounts.length && selectedCounts[group] >= 0) {
+                    ids[selectedCounts[group]++] = id;
+                }
+            }
+        }
+
         @Override
         public void close() {
-            releasable.close();
+            Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed));
         }
     }
 
     /**
      * State for a grouped {@code VALUES} aggregation. This implementation
-     * emphasizes collect-time performance over the performance of rendering
-     * results. That's good, but it's a pretty intensive emphasis, requiring
-     * an {@code O(n^2)} operation for collection to support a {@code O(1)}
-     * collector operation. But at least it's fairly simple.
+     * emphasizes collect-time performance over result rendering performance.
+     * The first value in each group is collected in the {@code firstValues}
+     * array, and subsequent values for each group are collected in {@code nextValues}.
      */
     public static class GroupingState implements GroupingAggregatorState {
-        private int maxGroupId = -1;
         private final BlockFactory blockFactory;
-        private final LongHash values;
+        IntArray firstValues;
+        private BitArray seen;
+        private int maxGroupId = -1;
+        private final NextValues nextValues;
 
         private GroupingState(DriverContext driverContext) {
             this.blockFactory = driverContext.blockFactory();
-            values = new LongHash(1, driverContext.bigArrays());
+            boolean success = false;
+            try {
+                this.firstValues = driverContext.bigArrays().newIntArray(1, false);
+                this.nextValues = new NextValues(driverContext.blockFactory());
+                success = true;
+            } finally {
+                if (success == false) {
+                    this.close();
+                }
+            }
         }
 
         @Override
@@ -141,158 +267,90 @@ class ValuesIntAggregator {
         }
 
         void addValue(int groupId, int v) {
-            /*
-             * Encode the groupId and value into a single long -
-             * the top 32 bits for the group, the bottom 32 for the value.
-             */
-            values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
-            maxGroupId = Math.max(maxGroupId, groupId);
+            if (groupId > maxGroupId) {
+                firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1);
+                firstValues.set(groupId, v);
+                // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating
+                // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset,
+                // fill the groups that have values, and begin tracking incoming groups.
+                if (seen == null && groupId > maxGroupId + 1) {
+                    seen = new BitArray(groupId + 1, blockFactory.bigArrays());
+                    seen.fill(0, maxGroupId + 1, true);
+                }
+                trackGroupId(groupId);
+                maxGroupId = groupId;
+            } else if (hasValue(groupId) == false) {
+                firstValues.set(groupId, v);
+                trackGroupId(groupId);
+            } else if (firstValues.get(groupId) != v) {
+                nextValues.addValue(groupId, v);
+            }
         }
 
-        /**
-         * Builds a {@link Block} with the unique values collected for the {@code #selected}
-         * groups. This is the implementation of the final and intermediate results of the agg.
-         */
-        Block toBlock(BlockFactory blockFactory, IntVector selected) {
-            if (values.size() == 0) {
-                return blockFactory.newConstantNullBlock(selected.getPositionCount());
-            }
+        @Override
+        public void enableGroupIdTracking(SeenGroupIds seen) {
+            // we track the seen values manually
+        }
 
-            try (var sorted = buildSorted(selected)) {
-                return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
+        private void trackGroupId(int groupId) {
+            if (seen != null) {
+                seen.set(groupId);
             }
         }
 
-        private Sorted buildSorted(IntVector selected) {
-            long selectedCountsSize = 0;
-            long idsSize = 0;
-            Sorted sorted = null;
-            try {
-                /*
-                 * Get a count of all groups less than the maximum selected group. Count
-                 * *downwards* so that we can flip the sign on all of the actually selected
-                 * groups. Negative values in this array are always unselected groups.
-                 */
-                int selectedCountsLen = selected.max() + 1;
-                long adjust = RamUsageEstimator.alignObjectSize(
-                    RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES
-                );
-                blockFactory.adjustBreaker(adjust);
-                selectedCountsSize = adjust;
-                int[] selectedCounts = new int[selectedCountsLen];
-                for (int id = 0; id < values.size(); id++) {
-                    long both = values.get(id);
-                    int group = (int) (both >>> Float.SIZE);
-                    if (group < selectedCounts.length) {
-                        selectedCounts[group]--;
-                    }
-                }
-
-                /*
-                 * Total the selected groups and turn the counts into the start index into a sort-of
-                 * off-by-one running count. It's really the number of values that have been inserted
-                 * into the results before starting on this group. Unselected groups will still
-                 * have negative counts.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
-                 */
-                int total = 0;
-                for (int s = 0; s < selected.getPositionCount(); s++) {
-                    int group = selected.getInt(s);
-                    int count = -selectedCounts[group];
-                    selectedCounts[group] = total;
-                    total += count;
-                }
+        /**
+         * Returns true if the group has a value in firstValues; having a value in nextValues is optional.
+         * Returns false if the group does not have values in either firstValues or nextValues.
+         */
+        private boolean hasValue(int groupId) {
+            return seen == null || seen.get(groupId);
+        }
 
-                /*
-                 * Build a list of ids to insert in order *and* convert the running
-                 * count in selectedCounts[group] into the end index (exclusive) in
-                 * ids for each group.
-                 * Here we use the negative counts to signal that a group hasn't been
-                 * selected and the id containing values for that group is ignored.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
-                 * The counts will end with 3, 4, -2, 5, 9.
-                 */
-                adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES);
-                blockFactory.adjustBreaker(adjust);
-                idsSize = adjust;
-                int[] ids = new int[total];
-                for (int id = 0; id < values.size(); id++) {
-                    long both = values.get(id);
-                    int group = (int) (both >>> Float.SIZE);
-                    if (group < selectedCounts.length && selectedCounts[group] >= 0) {
-                        ids[selectedCounts[group]++] = id;
-                    }
-                }
-                final long totalMemoryUsed = selectedCountsSize + idsSize;
-                sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
-                return sorted;
-            } finally {
-                if (sorted == null) {
-                    blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
-                }
-            }
+        /**
+         * Builds a {@link Block} with the unique values collected for the {@code #selected}
+         * groups. This is the implementation of the final and intermediate results of the agg.
+         */
+        Block toBlock(BlockFactory blockFactory, IntVector selected) {
+            nextValues.prepareForEmitting(selected);
+            return buildOutputBlock(blockFactory, selected);
         }
 
-        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) {
             /*
              * Insert the ids in order.
              */
+            final int[] nextValueCounts = nextValues.selectedCounts;
             try (IntBlock.Builder builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+                int nextValuesStart = 0;
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.appendInt(getValue(ids[start]));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.appendInt(getValue(ids[i]));
-                            }
-                            builder.endPositionEntry();
+                    if (group > maxGroupId || hasValue(group) == false) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    int firstValue = firstValues.get(group);
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart;
+                    if (nextValuesEnd == nextValuesStart) {
+                        builder.appendInt(firstValue);
+                    } else {
+                        builder.beginPositionEntry();
+                        builder.appendInt(firstValue);
+                        // append values from the nextValues
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            var nextValue = nextValues.getValue(i);
+                            builder.appendInt(nextValue);
                         }
+                        builder.endPositionEntry();
+                        nextValuesStart = nextValuesEnd;
                     }
-                    start = end;
                 }
                 return builder.build();
             }
         }
 
-        int getValue(int valueId) {
-            long both = values.get(valueId);
-            return (int) both;
-        }
-
-        @Override
-        public void enableGroupIdTracking(SeenGroupIds seen) {
-            // we figure out seen values from nulls on the values block
-        }
-
         @Override
         public void close() {
-            Releasables.closeExpectNoException(values);
+            Releasables.closeExpectNoException(seen, firstValues, nextValues);
         }
     }
 }

+ 192 - 135
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java

@@ -7,20 +7,32 @@
 
 package org.elasticsearch.compute.aggregation;
 
+// begin generated imports
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.common.util.LongArray;
+import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
 import org.elasticsearch.compute.ann.Aggregator;
 import org.elasticsearch.compute.ann.GroupingAggregator;
 import org.elasticsearch.compute.ann.IntermediateState;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+// end generated imports
 
 /**
  * Aggregates field values for long.
@@ -106,34 +118,140 @@ class ValuesLongAggregator {
     }
 
     /**
-     * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
-     * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
-     * and then use it to iterate over the values in order.
-     *
-     * @param ids positions of the {@link GroupingState#values} to read.
+     * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value.
+     * When emitting the output, we need to iterate the hash one group at a time to build the output block,
+     * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id
+     * to an array, allowing us to build the output in O(N) instead.
      */
-    private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
+    private static class NextValues implements Releasable {
+        private final BlockFactory blockFactory;
+        private final LongLongHash hashes;
+        private int[] selectedCounts = null;
+        private int[] ids = null;
+        private long extraMemoryUsed = 0;
+
+        private NextValues(BlockFactory blockFactory) {
+            this.blockFactory = blockFactory;
+            this.hashes = new LongLongHash(1, blockFactory.bigArrays());
+        }
+
+        void addValue(int groupId, long v) {
+            hashes.add(groupId, v);
+        }
+
+        long getValue(int index) {
+            return hashes.getKey2(ids[index]);
+        }
+
+        private void reserveBytesForIntArray(long numElements) {
+            long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES);
+            blockFactory.adjustBreaker(adjust);
+            extraMemoryUsed += adjust;
+        }
+
+        /**
+         * Prepares {@link #selectedCounts} and {@link #ids} so that the values stored in
+         * {@code hashes} can be read back one selected group at a time, in the order the
+         * groups appear in {@code selected}. Does nothing when the hash is empty, leaving
+         * both arrays {@code null}.
+         * <p>
+         * The memory for both arrays is reserved on the breaker through
+         * {@link #reserveBytesForIntArray} and is accounted in {@code extraMemoryUsed}.
+         *
+         * @param selected the group ids that are going to be emitted
+         */
+        private void prepareForEmitting(IntVector selected) {
+            if (hashes.size() == 0) {
+                return;
+            }
+            /*
+             * Get a count of all groups less than the maximum selected group. Count
+             * *downwards* so that we can flip the sign on all of the actually selected
+             * groups. Negative values in this array are always unselected groups.
+             */
+            int selectedCountsLen = selected.max() + 1;
+            reserveBytesForIntArray(selectedCountsLen);
+            this.selectedCounts = new int[selectedCountsLen];
+            for (int id = 0; id < hashes.size(); id++) {
+                int group = (int) hashes.getKey1(id);
+                if (group < selectedCounts.length) {
+                    selectedCounts[group]--;
+                }
+            }
+
+            /*
+             * Total the selected groups and turn the counts into the start index into a sort-of
+             * off-by-one running count. It's really the number of values that have been inserted
+             * into the results before starting on this group. Unselected groups will still
+             * have negative counts.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
+             */
+            int total = 0;
+            for (int s = 0; s < selected.getPositionCount(); s++) {
+                int group = selected.getInt(s);
+                int count = -selectedCounts[group];
+                selectedCounts[group] = total;
+                total += count;
+            }
+
+            /*
+             * Build a list of ids to insert in order *and* convert the running
+             * count in selectedCounts[group] into the end index (exclusive) in
+             * ids for each group.
+             * Here we use the negative counts to signal that a group hasn't been
+             * selected and the id containing values for that group is ignored.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
+             * The counts will end with 3, 4, -2, 5, 9.
+             */
+            reserveBytesForIntArray(total);
+
+            this.ids = new int[total];
+            for (int id = 0; id < hashes.size(); id++) {
+                int group = (int) hashes.getKey1(id);
+                /*
+                 * Skip ids that belong to a group above the largest selected group or to
+                 * an unselected group (negative count). Using either as an index would
+                 * throw ArrayIndexOutOfBoundsException.
+                 */
+                if (group < selectedCounts.length && selectedCounts[group] >= 0) {
+                    ids[selectedCounts[group]++] = id;
+                }
+            }
+        }
+
         @Override
         public void close() {
-            releasable.close();
+            Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed));
         }
     }
 
     /**
      * State for a grouped {@code VALUES} aggregation. This implementation
-     * emphasizes collect-time performance over the performance of rendering
-     * results. That's good, but it's a pretty intensive emphasis, requiring
-     * an {@code O(n^2)} operation for collection to support a {@code O(1)}
-     * collector operation. But at least it's fairly simple.
+     * emphasizes collect-time performance over result rendering performance.
+     * The first value in each group is collected in the {@code firstValues}
+     * array, and subsequent values for each group are collected in {@code nextValues}.
      */
     public static class GroupingState implements GroupingAggregatorState {
-        private int maxGroupId = -1;
         private final BlockFactory blockFactory;
-        private final LongLongHash values;
+        LongArray firstValues;
+        private BitArray seen;
+        private int maxGroupId = -1;
+        private final NextValues nextValues;
 
         private GroupingState(DriverContext driverContext) {
             this.blockFactory = driverContext.blockFactory();
-            values = new LongLongHash(1, driverContext.bigArrays());
+            boolean success = false;
+            try {
+                this.firstValues = driverContext.bigArrays().newLongArray(1, false);
+                this.nextValues = new NextValues(driverContext.blockFactory());
+                success = true;
+            } finally {
+                if (success == false) {
+                    this.close();
+                }
+            }
         }
 
         @Override
@@ -142,151 +260,90 @@ class ValuesLongAggregator {
         }
 
         void addValue(int groupId, long v) {
-            values.add(groupId, v);
-            maxGroupId = Math.max(maxGroupId, groupId);
+            if (groupId > maxGroupId) {
+                firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1);
+                firstValues.set(groupId, v);
+                // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating
+                // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset,
+                // fill the groups that have values, and begin tracking incoming groups.
+                if (seen == null && groupId > maxGroupId + 1) {
+                    seen = new BitArray(groupId + 1, blockFactory.bigArrays());
+                    seen.fill(0, maxGroupId + 1, true);
+                }
+                trackGroupId(groupId);
+                maxGroupId = groupId;
+            } else if (hasValue(groupId) == false) {
+                firstValues.set(groupId, v);
+                trackGroupId(groupId);
+            } else if (firstValues.get(groupId) != v) {
+                nextValues.addValue(groupId, v);
+            }
         }
 
-        /**
-         * Builds a {@link Block} with the unique values collected for the {@code #selected}
-         * groups. This is the implementation of the final and intermediate results of the agg.
-         */
-        Block toBlock(BlockFactory blockFactory, IntVector selected) {
-            if (values.size() == 0) {
-                return blockFactory.newConstantNullBlock(selected.getPositionCount());
-            }
+        @Override
+        public void enableGroupIdTracking(SeenGroupIds seen) {
+            // we track the seen values manually
+        }
 
-            try (var sorted = buildSorted(selected)) {
-                return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
+        private void trackGroupId(int groupId) {
+            if (seen != null) {
+                seen.set(groupId);
             }
         }
 
-        private Sorted buildSorted(IntVector selected) {
-            long selectedCountsSize = 0;
-            long idsSize = 0;
-            Sorted sorted = null;
-            try {
-                /*
-                 * Get a count of all groups less than the maximum selected group. Count
-                 * *downwards* so that we can flip the sign on all of the actually selected
-                 * groups. Negative values in this array are always unselected groups.
-                 */
-                int selectedCountsLen = selected.max() + 1;
-                long adjust = RamUsageEstimator.alignObjectSize(
-                    RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES
-                );
-                blockFactory.adjustBreaker(adjust);
-                selectedCountsSize = adjust;
-                int[] selectedCounts = new int[selectedCountsLen];
-                for (int id = 0; id < values.size(); id++) {
-                    int group = (int) values.getKey1(id);
-                    if (group < selectedCounts.length) {
-                        selectedCounts[group]--;
-                    }
-                }
-
-                /*
-                 * Total the selected groups and turn the counts into the start index into a sort-of
-                 * off-by-one running count. It's really the number of values that have been inserted
-                 * into the results before starting on this group. Unselected groups will still
-                 * have negative counts.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
-                 */
-                int total = 0;
-                for (int s = 0; s < selected.getPositionCount(); s++) {
-                    int group = selected.getInt(s);
-                    int count = -selectedCounts[group];
-                    selectedCounts[group] = total;
-                    total += count;
-                }
+        /**
+         * Returns true if the group has a value in firstValues; having a value in nextValues is optional.
+         * Returns false if the group does not have values in either firstValues or nextValues.
+         */
+        private boolean hasValue(int groupId) {
+            return seen == null || seen.get(groupId);
+        }
 
-                /*
-                 * Build a list of ids to insert in order *and* convert the running
-                 * count in selectedCounts[group] into the end index (exclusive) in
-                 * ids for each group.
-                 * Here we use the negative counts to signal that a group hasn't been
-                 * selected and the id containing values for that group is ignored.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
-                 * The counts will end with 3, 4, -2, 5, 9.
-                 */
-                adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES);
-                blockFactory.adjustBreaker(adjust);
-                idsSize = adjust;
-                int[] ids = new int[total];
-                for (int id = 0; id < values.size(); id++) {
-                    int group = (int) values.getKey1(id);
-                    if (group < selectedCounts.length && selectedCounts[group] >= 0) {
-                        ids[selectedCounts[group]++] = id;
-                    }
-                }
-                final long totalMemoryUsed = selectedCountsSize + idsSize;
-                sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
-                return sorted;
-            } finally {
-                if (sorted == null) {
-                    blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
-                }
-            }
+        /**
+         * Builds a {@link Block} with the unique values collected for the {@code #selected}
+         * groups. This is the implementation of the final and intermediate results of the agg.
+         */
+        Block toBlock(BlockFactory blockFactory, IntVector selected) {
+            nextValues.prepareForEmitting(selected);
+            return buildOutputBlock(blockFactory, selected);
         }
 
-        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) {
             /*
              * Insert the ids in order.
              */
+            final int[] nextValueCounts = nextValues.selectedCounts;
             try (LongBlock.Builder builder = blockFactory.newLongBlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+                int nextValuesStart = 0;
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.appendLong(getValue(ids[start]));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.appendLong(getValue(ids[i]));
-                            }
-                            builder.endPositionEntry();
+                    if (group > maxGroupId || hasValue(group) == false) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    long firstValue = firstValues.get(group);
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart;
+                    if (nextValuesEnd == nextValuesStart) {
+                        builder.appendLong(firstValue);
+                    } else {
+                        builder.beginPositionEntry();
+                        builder.appendLong(firstValue);
+                        // append values from the nextValues
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            var nextValue = nextValues.getValue(i);
+                            builder.appendLong(nextValue);
                         }
+                        builder.endPositionEntry();
+                        nextValuesStart = nextValuesEnd;
                     }
-                    start = end;
                 }
                 return builder.build();
             }
         }
 
-        long getValue(int valueId) {
-            return values.getKey2(valueId);
-        }
-
-        @Override
-        public void enableGroupIdTracking(SeenGroupIds seen) {
-            // we figure out seen values from nulls on the values block
-        }
-
         @Override
         public void close() {
-            Releasables.closeExpectNoException(values);
+            Releasables.closeExpectNoException(seen, firstValues, nextValues);
         }
     }
 }

+ 291 - 241
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st

@@ -7,44 +7,32 @@
 
 package org.elasticsearch.compute.aggregation;
 
-$if(BytesRef)$
+// begin generated imports
 import org.apache.lucene.util.BytesRef;
-$endif$
 import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.util.BigArrays;
-$if(BytesRef)$
+import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefHash;
-$else$
 import org.elasticsearch.common.util.LongHash;
-$endif$
-$if(long||double||BytesRef)$
 import org.elasticsearch.common.util.LongLongHash;
-$endif$
-$if(BytesRef)$
+import org.elasticsearch.common.util.$if(BytesRef)$Int$else$$Type$$endif$Array;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
-$endif$
 import org.elasticsearch.compute.ann.Aggregator;
 import org.elasticsearch.compute.ann.GroupingAggregator;
 import org.elasticsearch.compute.ann.IntermediateState;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
-$if(int||double||float)$
 import org.elasticsearch.compute.data.$Type$Block;
-$elseif(BytesRef)$
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.IntBlock;
-$endif$
 import org.elasticsearch.compute.data.IntVector;
-$if(long)$
 import org.elasticsearch.compute.data.LongBlock;
-$endif$
-$if(BytesRef)$
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
-$endif$
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+// end generated imports
 
 /**
  * Aggregates field values for $type$.
@@ -204,62 +192,190 @@ $endif$
     }
 
     /**
-     * Values are collected in a hash. Iterating over them in order (row by row) to build the output,
-     * or merging with other state, can be expensive. To optimize this, we build a sorted structure once,
-     * and then use it to iterate over the values in order.
-     *
-     * @param ids positions of the {@link GroupingState#values} to read.
+     * Values after the first in each group are collected in a hash, keyed by the pair of groupId and value.
+     * When emitting the output, we need to iterate the hash one group at a time to build the output block,
+     * which would require O(N^2). To avoid this, we compute the counts for each group and remap the hash id
+     * to an array, allowing us to build the output in O(N) instead.
      */
-    private record Sorted(Releasable releasable, int[] counts, int[] ids) implements Releasable {
+    private static class NextValues implements Releasable {
+        private final BlockFactory blockFactory;
+$if(long||double)$
+        private final LongLongHash hashes;
+$else$
+        private final LongHash hashes;
+$endif$
+        private int[] selectedCounts = null;
+        private int[] ids = null;
+        private long extraMemoryUsed = 0;
+
+        private NextValues(BlockFactory blockFactory) {
+            this.blockFactory = blockFactory;
+            this.hashes = new Long$if(long||double)$Long$endif$Hash(1, blockFactory.bigArrays());
+        }
+
+        void addValue(int groupId, $if(BytesRef)$int$else$$type$$endif$ v) {
+$if(long)$
+            hashes.add(groupId, v);
+$elseif(double)$
+            hashes.add(groupId, Double.doubleToLongBits(v));
+$elseif(int||BytesRef)$
+            /*
+             * Encode the groupId and value into a single long -
+             * the top 32 bits for the group, the bottom 32 for the value.
+             */
+            hashes.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
+$elseif(float)$
+            /*
+             * Encode the groupId and value into a single long -
+             * the top 32 bits for the group, the bottom 32 for the value.
+             */
+            hashes.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
+$endif$
+        }
+
+        $if(BytesRef)$int$else$$type$$endif$ getValue(int index) {
+$if(long)$
+            return hashes.getKey2(ids[index]);
+$elseif(double)$
+            return Double.longBitsToDouble(hashes.getKey2(ids[index]));
+$elseif(float)$
+            long both = hashes.get(ids[index]);
+            return Float.intBitsToFloat((int) (both & 0xFFFFFFFFL));
+$elseif(BytesRef||int)$
+            long both = hashes.get(ids[index]);
+            return (int) (both & 0xFFFFFFFFL);
+$endif$
+        }
+
+        /**
+         * Reserves the aligned size of an {@code int[]} holding {@code numElements} entries
+         * through {@code blockFactory.adjustBreaker} and adds it to {@code extraMemoryUsed},
+         * which {@link #close()} hands back to the breaker.
+         */
+        private void reserveBytesForIntArray(long numElements) {
+            long adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + numElements * Integer.BYTES);
+            blockFactory.adjustBreaker(adjust);
+            extraMemoryUsed += adjust;
+        }
+
+        /**
+         * Builds {@code selectedCounts} and {@code ids} so that the hash ids of each selected
+         * group form one contiguous run in {@code ids}. Does nothing when the hash is empty,
+         * which leaves both arrays {@code null}.
+         *
+         * @param selected the groups to emit; runs are laid out in the order the groups appear here
+         */
+        private void prepareForEmitting(IntVector selected) {
+            if (hashes.size() == 0) {
+                return;
+            }
+            /*
+             * Get a count of all groups up to and including the maximum selected group. Count
+             * *downwards* so that we can flip the sign on all of the actually selected
+             * groups. Negative values in this array are always unselected groups.
+             */
+            int selectedCountsLen = selected.max() + 1;
+            reserveBytesForIntArray(selectedCountsLen);
+            this.selectedCounts = new int[selectedCountsLen];
+            for (int id = 0; id < hashes.size(); id++) {
+$if(long||double)$
+                int group = (int) hashes.getKey1(id);
+$elseif(float||int||BytesRef)$
+                long both = hashes.get(id);
+                int group = (int) (both >>> Float.SIZE);
+$endif$
+                if (group < selectedCounts.length) {
+                    selectedCounts[group]--;
+                }
+            }
+
+            /*
+             * Total the selected groups and turn the counts into the start index into a sort-of
+             * off-by-one running count. It's really the number of values that have been inserted
+             * into the results before starting on this group. Unselected groups will still
+             * have negative counts.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
+             */
+            int total = 0;
+            for (int s = 0; s < selected.getPositionCount(); s++) {
+                int group = selected.getInt(s);
+                int count = -selectedCounts[group];
+                selectedCounts[group] = total;
+                total += count;
+            }
+
+            /*
+             * Build a list of ids to insert in order *and* convert the running
+             * count in selectedCounts[group] into the end index (exclusive) in
+             * ids for each group.
+             * Here we use the negative counts to signal that a group hasn't been
+             * selected and the id containing values for that group is ignored.
+             *
+             * For example, if
+             * | Group | Value Count | Selected |
+             * |-------|-------------|----------|
+             * |     0 | 3           | <-       |
+             * |     1 | 1           | <-       |
+             * |     2 | 2           |          |
+             * |     3 | 1           | <-       |
+             * |     4 | 4           | <-       |
+             *
+             * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
+             * The counts will end with 3, 4, -2, 5, 9.
+             */
+            reserveBytesForIntArray(total);
+
+            this.ids = new int[total];
+            for (int id = 0; id < hashes.size(); id++) {
+$if(long||double)$
+                int group = (int) hashes.getKey1(id);
+$elseif(float||int||BytesRef)$
+                long both = hashes.get(id);
+                int group = (int) (both >>> Float.SIZE);
+$endif$
+                /*
+                 * Skip ids whose group lies beyond the largest selected group or was not
+                 * selected (its count is still negative); either would index out of bounds.
+                 */
+                if (group < selectedCounts.length && selectedCounts[group] >= 0) {
+                    ids[selectedCounts[group]++] = id;
+                }
+            }
+        }
+
         @Override
         public void close() {
-            releasable.close();
+            Releasables.closeExpectNoException(hashes, () -> blockFactory.adjustBreaker(-extraMemoryUsed));
         }
     }
 
     /**
      * State for a grouped {@code VALUES} aggregation. This implementation
-     * emphasizes collect-time performance over the performance of rendering
-     * results. That's good, but it's a pretty intensive emphasis, requiring
-     * an {@code O(n^2)} operation for collection to support a {@code O(1)}
-     * collector operation. But at least it's fairly simple.
+     * emphasizes collect-time performance over result rendering performance.
+     * The first value in each group is collected in the {@code firstValues}
+     * array, and subsequent values for each group are collected in {@code nextValues}.
      */
     public static class GroupingState implements GroupingAggregatorState {
-        private int maxGroupId = -1;
         private final BlockFactory blockFactory;
-$if(long||double)$
-        private final LongLongHash values;
-
-$elseif(BytesRef)$
-        private final LongLongHash values;
+$if(BytesRef)$
         BytesRefHash bytes;
-
-$elseif(int||float)$
-        private final LongHash values;
-
+        private IntArray firstValues;
+$else$
+        $Type$Array firstValues;
+        private BitArray seen;
+        private int maxGroupId = -1;
 $endif$
+        private final NextValues nextValues;
+
         private GroupingState(DriverContext driverContext) {
             this.blockFactory = driverContext.blockFactory();
-$if(long||double)$
-            values = new LongLongHash(1, driverContext.bigArrays());
-$elseif(BytesRef)$
-            LongLongHash _values = null;
-            BytesRefHash _bytes = null;
+            boolean success = false;
             try {
-                _values = new LongLongHash(1, driverContext.bigArrays());
-                _bytes = new BytesRefHash(1, driverContext.bigArrays());
-
-                values = _values;
-                bytes = _bytes;
-
-                _values = null;
-                _bytes = null;
+$if(BytesRef)$
+                this.bytes = new BytesRefHash(1, driverContext.bigArrays());
+                this.firstValues = driverContext.bigArrays().newIntArray(1, true);
+$else$
+                this.firstValues = driverContext.bigArrays().new$Type$Array(1, false);
+$endif$
+                this.nextValues = new NextValues(driverContext.blockFactory());
+                success = true;
             } finally {
-                Releasables.closeExpectNoException(_values, _bytes);
+                if (success == false) {
+                    this.close();
+                }
             }
-$elseif(int||float)$
-            values = new LongHash(1, driverContext.bigArrays());
-$endif$
         }
 
         @Override
@@ -268,210 +384,169 @@ $endif$
         }
 
 $if(BytesRef)$
-        void addValueOrdinal(int groupId, long valueOrdinal) {
-            values.add(groupId, valueOrdinal);
-            maxGroupId = Math.max(maxGroupId, groupId);
+        /**
+         * Records {@code valueOrdinal} for {@code groupId}. {@code firstValues} stores the
+         * ordinal plus one, so a stored value below one means the group has no first value yet.
+         * A later ordinal that differs from the first one is added to {@code nextValues}.
+         * NOTE(review): assumes slots added by {@code grow} read as 0 — confirm.
+         */
+        void addValueOrdinal(int groupId, int valueOrdinal) {
+            if (groupId < firstValues.size()) {
+                int current = firstValues.get(groupId) - 1;
+                if (current < 0) {
+                    firstValues.set(groupId, valueOrdinal + 1);
+                } else if (current != valueOrdinal) {
+                    nextValues.addValue(groupId, valueOrdinal);
+                }
+            } else {
+                firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1);
+                firstValues.set(groupId, valueOrdinal + 1);
+            }
         }
 
 $endif$
+        /**
+         * Adds {@code v} to the values collected for {@code groupId}. The first value of a group
+         * is kept in {@code firstValues}; a later value that differs from it goes to {@code nextValues}.
+         * NOTE(review): in the floating point variants the {@code !=} check treats NaN as different
+         * from itself and -0.0 as equal to 0.0, unlike the bit-based hash — confirm this is intended.
+         */
         void addValue(int groupId, $type$ v) {
-$if(long)$
-            values.add(groupId, v);
-$elseif(double)$
-            values.add(groupId, Double.doubleToLongBits(v));
-$elseif(BytesRef)$
-            values.add(groupId, BlockHash.hashOrdToGroup(bytes.add(v)));
-$elseif(int)$
-            /*
-             * Encode the groupId and value into a single long -
-             * the top 32 bits for the group, the bottom 32 for the value.
-             */
-            values.add((((long) groupId) << Integer.SIZE) | (v & 0xFFFFFFFFL));
-$elseif(float)$
-            /*
-             * Encode the groupId and value into a single long -
-             * the top 32 bits for the group, the bottom 32 for the value.
-             */
-            values.add((((long) groupId) << Float.SIZE) | (Float.floatToIntBits(v) & 0xFFFFFFFFL));
+$if(BytesRef)$
+            int valueOrdinal = Math.toIntExact(BlockHash.hashOrdToGroup(bytes.add(v)));
+            addValueOrdinal(groupId, valueOrdinal);
+$else$
+            if (groupId > maxGroupId) {
+                firstValues = blockFactory.bigArrays().grow(firstValues, groupId + 1);
+                firstValues.set(groupId, v);
+                // We start in untracked mode, assuming every group has a value as an optimization to avoid allocating
+                // and updating the seen bitset. However, once some groups don't have values, we initialize the seen bitset,
+                // fill the groups that have values, and begin tracking incoming groups.
+                if (seen == null && groupId > maxGroupId + 1) {
+                    seen = new BitArray(groupId + 1, blockFactory.bigArrays());
+                    seen.fill(0, maxGroupId + 1, true);
+                }
+                trackGroupId(groupId);
+                maxGroupId = groupId;
+            } else if (hasValue(groupId) == false) {
+                firstValues.set(groupId, v);
+                trackGroupId(groupId);
+            } else if (firstValues.get(groupId) != v) {
+                nextValues.addValue(groupId, v);
+            }
 $endif$
-            maxGroupId = Math.max(maxGroupId, groupId);
         }
 
+$if(BytesRef)$
+        @Override
+        public void enableGroupIdTracking(SeenGroupIds seen) {
+            // we figure out seen values from firstValues since ordinals are non-negative
+        }
+
+$else$
+        @Override
+        public void enableGroupIdTracking(SeenGroupIds seen) {
+            // we track the seen values manually
+        }
+
+        private void trackGroupId(int groupId) {
+            if (seen != null) {
+                seen.set(groupId);
+            }
+        }
+
+        /**
+         * Returns true if the group has a value in firstValues; having a value in nextValues is optional.
+         * Returns false if the group does not have values in either firstValues or nextValues.
+         * When {@code seen} is {@code null} no bitset is consulted and this returns true for every group.
+         */
+        private boolean hasValue(int groupId) {
+            return seen == null || seen.get(groupId);
+        }
+
+$endif$
         /**
          * Builds a {@link Block} with the unique values collected for the {@code #selected}
          * groups. This is the implementation of the final and intermediate results of the agg.
          */
         Block toBlock(BlockFactory blockFactory, IntVector selected) {
-            if (values.size() == 0) {
-                return blockFactory.newConstantNullBlock(selected.getPositionCount());
-            }
-
-            try (var sorted = buildSorted(selected)) {
+            nextValues.prepareForEmitting(selected);
 $if(BytesRef)$
-                if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) {
-                    return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
-                } else {
-                    return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
-                }
+            if (OrdinalBytesRefBlock.isDense(firstValues.size() + nextValues.hashes.size(), bytes.size())) {
+                return buildOrdinalOutputBlock(blockFactory, selected);
+            } else {
+                return buildOutputBlock(blockFactory, selected);
+            }
 $else$
-                return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
+            return buildOutputBlock(blockFactory, selected);
 $endif$
-            }
         }
 
-        private Sorted buildSorted(IntVector selected) {
-            long selectedCountsSize = 0;
-            long idsSize = 0;
-            Sorted sorted = null;
-            try {
-                /*
-                 * Get a count of all groups less than the maximum selected group. Count
-                 * *downwards* so that we can flip the sign on all of the actually selected
-                 * groups. Negative values in this array are always unselected groups.
-                 */
-                int selectedCountsLen = selected.max() + 1;
-                long adjust = RamUsageEstimator.alignObjectSize(
-                    RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + selectedCountsLen * Integer.BYTES
-                );
-                blockFactory.adjustBreaker(adjust);
-                selectedCountsSize = adjust;
-                int[] selectedCounts = new int[selectedCountsLen];
-                for (int id = 0; id < values.size(); id++) {
-$if(long||BytesRef||double)$
-                    int group = (int) values.getKey1(id);
-$elseif(float||int)$
-                    long both = values.get(id);
-                    int group = (int) (both >>> Float.SIZE);
-$endif$
-                    if (group < selectedCounts.length) {
-                        selectedCounts[group]--;
-                    }
-                }
-
-                /*
-                 * Total the selected groups and turn the counts into the start index into a sort-of
-                 * off-by-one running count. It's really the number of values that have been inserted
-                 * into the results before starting on this group. Unselected groups will still
-                 * have negative counts.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5
-                 */
-                int total = 0;
-                for (int s = 0; s < selected.getPositionCount(); s++) {
-                    int group = selected.getInt(s);
-                    int count = -selectedCounts[group];
-                    selectedCounts[group] = total;
-                    total += count;
-                }
-
-                /*
-                 * Build a list of ids to insert in order *and* convert the running
-                 * count in selectedCounts[group] into the end index (exclusive) in
-                 * ids for each group.
-                 * Here we use the negative counts to signal that a group hasn't been
-                 * selected and the id containing values for that group is ignored.
-                 *
-                 * For example, if
-                 * | Group | Value Count | Selected |
-                 * |-------|-------------|----------|
-                 * |     0 | 3           | <-       |
-                 * |     1 | 1           | <-       |
-                 * |     2 | 2           |          |
-                 * |     3 | 1           | <-       |
-                 * |     4 | 4           | <-       |
-                 *
-                 * Then the total is 9 and the counts array will start with 0, 3, -2, 4, 5.
-                 * The counts will end with 3, 4, -2, 5, 9.
-                 */
-                adjust = RamUsageEstimator.alignObjectSize(RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + total * Integer.BYTES);
-                blockFactory.adjustBreaker(adjust);
-                idsSize = adjust;
-                int[] ids = new int[total];
-                for (int id = 0; id < values.size(); id++) {
-        $if(long||BytesRef||double)$
-                    int group = (int) values.getKey1(id);
-        $elseif(float||int)$
-                    long both = values.get(id);
-                    int group = (int) (both >>> Float.SIZE);
-        $endif$
-                    if (group < selectedCounts.length && selectedCounts[group] >= 0) {
-                        ids[selectedCounts[group]++] = id;
-                    }
-                }
-                final long totalMemoryUsed = selectedCountsSize + idsSize;
-                sorted = new Sorted(() -> blockFactory.adjustBreaker(-totalMemoryUsed), selectedCounts, ids);
-                return sorted;
-            } finally {
-                if (sorted == null) {
-                    blockFactory.adjustBreaker(-selectedCountsSize - idsSize);
-                }
-            }
-        }
-
-        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        Block buildOutputBlock(BlockFactory blockFactory, IntVector selected) {
             /*
              * Insert the ids in order.
              */
 $if(BytesRef)$
             BytesRef scratch = new BytesRef();
 $endif$
+            final int[] nextValueCounts = nextValues.selectedCounts;
             try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+                int nextValuesStart = 0;
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.append$Type$(getValue(ids[start]$if(BytesRef)$, scratch$endif$));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.append$Type$(getValue(ids[i]$if(BytesRef)$, scratch$endif$));
-                            }
-                            builder.endPositionEntry();
+$if(BytesRef)$
+                    int firstValue = group >= firstValues.size() ? -1 : firstValues.get(group) - 1;
+                    if (firstValue < 0) {
+                        builder.appendNull();
+                        continue;
+                    }
+$else$
+                    if (group > maxGroupId || hasValue(group) == false) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    $type$ firstValue = firstValues.get(group);
+$endif$
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart;
+                    if (nextValuesEnd == nextValuesStart) {
+                        builder.append$Type$($if(BytesRef)$bytes.get(firstValue, scratch)$else$firstValue$endif$);
+                    } else {
+                        builder.beginPositionEntry();
+                        builder.append$Type$($if(BytesRef)$bytes.get(firstValue, scratch)$else$firstValue$endif$);
+                        // append values from the nextValues
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            var nextValue = nextValues.getValue(i);
+                            builder.append$Type$($if(BytesRef)$bytes.get(nextValue, scratch)$else$nextValue$endif$);
                         }
+                        builder.endPositionEntry();
+                        nextValuesStart = nextValuesEnd;
                     }
-                    start = end;
                 }
                 return builder.build();
             }
         }
 
 $if(BytesRef)$
-        Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) {
+        /**
+         * Builds the output for the {@code selected} groups as ordinals into the dictionary
+         * taken from {@code bytes}. Groups without a first value are emitted as null.
+         */
+        Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected) {
             BytesRefVector dict = null;
             IntBlock ordinals = null;
             BytesRefBlock result = null;
             var dictArray = bytes.takeBytesRefsOwnership();
             bytes = null; // transfer ownership to dictArray
-            try (var builder = blockFactory.newIntBlockBuilder(selected.getPositionCount())) {
-                int start = 0;
+            int estimateSize = Math.toIntExact(firstValues.size() + nextValues.hashes.size());
+            final int[] nextValueCounts = nextValues.selectedCounts;
+            try (var builder = blockFactory.newIntBlockBuilder(estimateSize)) {
+                int nextValuesStart = 0;
                 for (int s = 0; s < selected.getPositionCount(); s++) {
                     int group = selected.getInt(s);
-                    int end = selectedCounts[group];
-                    int count = end - start;
-                    switch (count) {
-                        case 0 -> builder.appendNull();
-                        case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start])));
-                        default -> {
-                            builder.beginPositionEntry();
-                            for (int i = start; i < end; i++) {
-                                builder.appendInt(Math.toIntExact(values.getKey2(ids[i])));
-                            }
-                            builder.endPositionEntry();
+                    // a group at or past the end of firstValues was never collected; reading it would be out of range
+                    if (group >= firstValues.size()) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    int firstValue = firstValues.get(group) - 1;
+                    if (firstValue < 0) {
+                        builder.appendNull();
+                        continue;
+                    }
+                    final int nextValuesEnd = nextValueCounts != null ? nextValueCounts[group] : nextValuesStart;
+                    if (nextValuesEnd == nextValuesStart) {
+                        builder.appendInt(firstValue);
+                    } else {
+                        builder.beginPositionEntry();
+                        builder.appendInt(firstValue);
+                        for (int i = nextValuesStart; i < nextValuesEnd; i++) {
+                            builder.appendInt(nextValues.getValue(i));
                         }
+                        builder.endPositionEntry();
                     }
-                    start = end;
+                    nextValuesStart = nextValuesEnd;
                 }
                 ordinals = builder.build();
                 dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size()));
@@ -486,34 +561,9 @@ $if(BytesRef)$
         }
 $endif$
 
-        $type$ getValue(int valueId$if(BytesRef)$, BytesRef scratch$endif$) {
-$if(BytesRef)$
-            return bytes.get(values.getKey2(valueId), scratch);
-$elseif(long)$
-            return values.getKey2(valueId);
-$elseif(double)$
-            return Double.longBitsToDouble(values.getKey2(valueId));
-$elseif(float)$
-            long both = values.get(valueId);
-            return Float.intBitsToFloat((int) both);
-$elseif(int)$
-            long both = values.get(valueId);
-            return (int) both;
-$endif$
-        }
-
-        @Override
-        public void enableGroupIdTracking(SeenGroupIds seen) {
-            // we figure out seen values from nulls on the values block
-        }
-
         @Override
         public void close() {
-$if(BytesRef)$
-            Releasables.closeExpectNoException(values, bytes);
-$else$
-            Releasables.closeExpectNoException(values);
-$endif$
+            Releasables.closeExpectNoException($if(BytesRef)$bytes$else$seen$endif$, firstValues, nextValues);
         }
     }
 }