Browse Source

Add optimized path for intermediate values aggregator (#131390)

Similar to #127849, this change adds an optimized path for leveraging 
ordinal blocks of intermediate input pages in the Values aggregator.
Below are the micro-benchmark results.

Before:
```
// 1 raw input page + 1000 intermediate input pages
Benchmark                      (dataType)  (groups)  Mode  Cnt       Score   Error  Units
ValuesAggregatorBenchmark.run    BytesRef         1  avgt    2       0.382          ms/op
ValuesAggregatorBenchmark.run    BytesRef      1000  avgt    2     112.293          ms/op
ValuesAggregatorBenchmark.run    BytesRef   1000000  avgt    2  113182.908          ms/op
```

```
After:
// 1 raw input page + 1000 intermediate input pages
Benchmark                      (dataType)  (groups)  Mode  Cnt      Score   Error  Units
ValuesAggregatorBenchmark.run    BytesRef         1  avgt    2      0.378          ms/op
ValuesAggregatorBenchmark.run    BytesRef      1000  avgt    2     34.410          ms/op
ValuesAggregatorBenchmark.run    BytesRef   1000000  avgt    2  64654.830          ms/op
```
1K groups:  112 ms -> 34.4ms
1M groups:     113s -> 64s

More to come with #130510

Relates #127849
Nhat Nguyen 2 months ago
parent
commit
256437902b

+ 20 - 9
benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java

@@ -113,16 +113,16 @@ public class ValuesAggregatorBenchmark {
     @Param({ BYTES_REF, INT, LONG })
     public String dataType;
 
-    private static Operator operator(DriverContext driverContext, int groups, String dataType) {
+    private static Operator operator(DriverContext driverContext, int groups, String dataType, AggregatorMode mode) {
         if (groups == 1) {
             return new AggregationOperator(
-                List.of(supplier(dataType).aggregatorFactory(AggregatorMode.SINGLE, List.of(0)).apply(driverContext)),
+                List.of(supplier(dataType).aggregatorFactory(mode, List.of(0)).apply(driverContext)),
                 driverContext
             );
         }
         List<BlockHash.GroupSpec> groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG));
         return new HashAggregationOperator(
-            List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))),
+            List.of(supplier(dataType).groupingAggregatorFactory(mode, List.of(1))),
             () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false),
             driverContext
         ) {
@@ -177,6 +177,9 @@ public class ValuesAggregatorBenchmark {
 
                 // Check them
                 BytesRefBlock values = page.getBlock(1);
+                if (values.asOrdinals() == null) {
+                    throw new AssertionError(" expected ordinals; but got " + values);
+                }
                 for (int p = 0; p < groups; p++) {
                     checkExpectedBytesRef(prefix, values, p, expected.get(p));
                 }
@@ -341,13 +344,21 @@ public class ValuesAggregatorBenchmark {
 
     private static void run(int groups, String dataType, int opCount) {
         DriverContext driverContext = driverContext();
-        try (Operator operator = operator(driverContext, groups, dataType)) {
-            Page page = page(groups, dataType);
-            for (int i = 0; i < opCount; i++) {
-                operator.addInput(page.shallowCopy());
+        try (Operator finalAggregator = operator(driverContext, groups, dataType, AggregatorMode.FINAL)) {
+            try (Operator initialAggregator = operator(driverContext, groups, dataType, AggregatorMode.INITIAL)) {
+                Page rawPage = page(groups, dataType);
+                for (int i = 0; i < opCount; i++) {
+                    initialAggregator.addInput(rawPage.shallowCopy());
+                }
+                initialAggregator.finish();
+                Page intermediatePage = initialAggregator.getOutput();
+                for (int i = 0; i < opCount; i++) {
+                    finalAggregator.addInput(intermediatePage.shallowCopy());
+                }
             }
-            operator.finish();
-            checkExpected(groups, dataType, operator.getOutput());
+            finalAggregator.finish();
+            Page outputPage = finalAggregator.getOutput();
+            checkExpected(groups, dataType, outputPage);
         }
     }
 

+ 5 - 0
docs/changelog/131390.yaml

@@ -0,0 +1,5 @@
+pr: 131390
+summary: Add optimized path for intermediate values aggregator
+area: ES|QL
+type: enhancement
+issues: []

+ 84 - 62
x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

@@ -58,6 +58,7 @@ import static org.elasticsearch.compute.gen.Types.GROUPING_AGGREGATOR_FUNCTION_A
 import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
 import static org.elasticsearch.compute.gen.Types.INT_ARRAY_BLOCK;
 import static org.elasticsearch.compute.gen.Types.INT_BIG_ARRAY_BLOCK;
+import static org.elasticsearch.compute.gen.Types.INT_BLOCK;
 import static org.elasticsearch.compute.gen.Types.INT_VECTOR;
 import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
 import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
@@ -609,77 +610,98 @@ public class GroupingAggregatorImplementer {
                         .collect(joining(" && "))
             );
         }
-        if (intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
-            builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
-        }
-        builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
-        {
-            if (groupsIsBlock) {
-                builder.beginControlFlow("if (groups.isNull(groupPosition))");
-                builder.addStatement("continue");
-                builder.endControlFlow();
-                builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)");
-                builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)");
-                builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)");
-                builder.addStatement("int groupId = groups.getInt(g)");
-            } else {
-                builder.addStatement("int groupId = groups.getInt(groupPosition)");
+        var bulkCombineIntermediateMethod = optionalStaticMethod(
+            declarationType,
+            requireVoidType(),
+            requireName("combineIntermediate"),
+            requireArgs(
+                Stream.concat(
+                    // aggState, positionOffset, groupIds
+                    Stream.of(aggState.declaredType(), TypeName.INT, groupsIsBlock ? INT_BLOCK : INT_VECTOR),
+                    intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType)
+                ).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
+            )
+        );
+        if (bulkCombineIntermediateMethod != null) {
+            var states = intermediateState.stream()
+                .map(AggregatorImplementer.IntermediateStateDesc::name)
+                .collect(Collectors.joining(", "));
+            builder.addStatement("$T.combineIntermediate(state, positionOffset, groups, " + states + ")", declarationType);
+        } else {
+            if (intermediateState.stream()
+                .map(AggregatorImplementer.IntermediateStateDesc::elementType)
+                .anyMatch(n -> n.equals("BYTES_REF"))) {
+                builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
             }
-
-            if (aggState.declaredType().isPrimitive()) {
-                if (warnExceptions.isEmpty()) {
-                    assert intermediateState.size() == 2;
-                    assert intermediateState.get(1).name().equals("seen");
-                    builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
+            builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
+            {
+                if (groupsIsBlock) {
+                    builder.beginControlFlow("if (groups.isNull(groupPosition))");
+                    builder.addStatement("continue");
+                    builder.endControlFlow();
+                    builder.addStatement("int groupStart = groups.getFirstValueIndex(groupPosition)");
+                    builder.addStatement("int groupEnd = groupStart + groups.getValueCount(groupPosition)");
+                    builder.beginControlFlow("for (int g = groupStart; g < groupEnd; g++)");
+                    builder.addStatement("int groupId = groups.getInt(g)");
                 } else {
-                    assert intermediateState.size() == 3;
-                    assert intermediateState.get(1).name().equals("seen");
-                    assert intermediateState.get(2).name().equals("failed");
-                    builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
-                    {
-                        builder.addStatement("state.setFailed(groupId)");
-                    }
-                    builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
+                    builder.addStatement("int groupId = groups.getInt(groupPosition)");
                 }
 
-                warningsBlock(builder, () -> {
-                    var name = intermediateState.get(0).name();
-                    var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
-                    builder.addStatement(
-                        "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
+                if (aggState.declaredType().isPrimitive()) {
+                    if (warnExceptions.isEmpty()) {
+                        assert intermediateState.size() == 2;
+                        assert intermediateState.get(1).name().equals("seen");
+                        builder.beginControlFlow("if (seen.getBoolean(groupPosition + positionOffset))");
+                    } else {
+                        assert intermediateState.size() == 3;
+                        assert intermediateState.get(1).name().equals("seen");
+                        assert intermediateState.get(2).name().equals("failed");
+                        builder.beginControlFlow("if (failed.getBoolean(groupPosition + positionOffset))");
+                        {
+                            builder.addStatement("state.setFailed(groupId)");
+                        }
+                        builder.nextControlFlow("else if (seen.getBoolean(groupPosition + positionOffset))");
+                    }
+
+                    warningsBlock(builder, () -> {
+                        var name = intermediateState.get(0).name();
+                        var vectorAccessor = vectorAccessorName(intermediateState.get(0).elementType());
+                        builder.addStatement(
+                            "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)))",
+                            declarationType,
+                            name,
+                            vectorAccessor
+                        );
+                    });
+                    builder.endControlFlow();
+                } else {
+                    var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
+                    requireStaticMethod(
                         declarationType,
-                        name,
-                        vectorAccessor
+                        requireVoidType(),
+                        requireName("combineIntermediate"),
+                        requireArgs(
+                            Stream.of(
+                                Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
+                                intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
+                                Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
+                            ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
+                        )
                     );
-                });
-                builder.endControlFlow();
-            } else {
-                var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
-                requireStaticMethod(
-                    declarationType,
-                    requireVoidType(),
-                    requireName("combineIntermediate"),
-                    requireArgs(
-                        Stream.of(
-                            Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
-                            intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
-                            Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
-                        ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
-                    )
-                );
 
-                builder.addStatement(
-                    "$T.combineIntermediate(state, groupId, "
-                        + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
-                        + (stateHasBlock ? ", groupPosition + positionOffset" : "")
-                        + ")",
-                    declarationType
-                );
-            }
-            if (groupsIsBlock) {
+                    builder.addStatement(
+                        "$T.combineIntermediate(state, groupId, "
+                            + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
+                            + (stateHasBlock ? ", groupPosition + positionOffset" : "")
+                            + ")",
+                        declarationType
+                    );
+                }
+                if (groupsIsBlock) {
+                    builder.endControlFlow();
+                }
                 builder.endControlFlow();
             }
-            builder.endControlFlow();
         }
         return builder.build();
     }

+ 7 - 8
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java

@@ -80,13 +80,12 @@ class ValuesBytesRefAggregator {
         state.addValue(groupId, v);
     }
 
-    public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) {
-        BytesRef scratch = new BytesRef();
-        int start = values.getFirstValueIndex(valuesPosition);
-        int end = start + values.getValueCount(valuesPosition);
-        for (int i = start; i < end; i++) {
-            state.addValue(groupId, values.getBytesRef(i, scratch));
-        }
+    public static void combineIntermediate(GroupingState state, int positionOffset, IntVector groups, BytesRefBlock values) {
+        ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values);
+    }
+
+    public static void combineIntermediate(GroupingState state, int positionOffset, IntBlock groups, BytesRefBlock values) {
+        ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values);
     }
 
     public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
@@ -199,7 +198,7 @@ class ValuesBytesRefAggregator {
             }
 
             try (var sorted = buildSorted(selected)) {
-                if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
+                if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) {
                     return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
                 } else {
                     return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);

+ 3 - 29
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java

@@ -152,18 +152,7 @@ public final class ValuesBytesRefGroupingAggregatorFunction implements GroupingA
       return;
     }
     BytesRefBlock values = (BytesRefBlock) valuesUncast;
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset);
-      }
-    }
+    ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values);
   }
 
   private void addRawInput(int positionOffset, IntBigArrayBlock groups, BytesRefBlock values) {
@@ -209,18 +198,7 @@ public final class ValuesBytesRefGroupingAggregatorFunction implements GroupingA
       return;
     }
     BytesRefBlock values = (BytesRefBlock) valuesUncast;
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      if (groups.isNull(groupPosition)) {
-        continue;
-      }
-      int groupStart = groups.getFirstValueIndex(groupPosition);
-      int groupEnd = groupStart + groups.getValueCount(groupPosition);
-      for (int g = groupStart; g < groupEnd; g++) {
-        int groupId = groups.getInt(g);
-        ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset);
-      }
-    }
+    ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values);
   }
 
   private void addRawInput(int positionOffset, IntVector groups, BytesRefBlock values) {
@@ -255,11 +233,7 @@ public final class ValuesBytesRefGroupingAggregatorFunction implements GroupingA
       return;
     }
     BytesRefBlock values = (BytesRefBlock) valuesUncast;
-    BytesRef scratch = new BytesRef();
-    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
-      int groupId = groups.getInt(groupPosition);
-      ValuesBytesRefAggregator.combineIntermediate(state, groupId, values, groupPosition + positionOffset);
-    }
+    ValuesBytesRefAggregator.combineIntermediate(state, positionOffset, groups, values);
   }
 
   @Override

+ 114 - 33
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java

@@ -28,15 +28,7 @@ final class ValuesBytesRefAggregators {
         if (valuesOrdinal == null) {
             return delegate;
         }
-        BytesRefVector dict = valuesOrdinal.getDictionaryVector();
-        final IntVector hashIds;
-        BytesRef spare = new BytesRef();
-        try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
-            for (int p = 0; p < dict.getPositionCount(); p++) {
-                hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
-            }
-            hashIds = hashIdsBuilder.build();
-        }
+        final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector());
         IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock();
         return new GroupingAggregatorFunction.AddInput() {
             @Override
@@ -85,17 +77,7 @@ final class ValuesBytesRefAggregators {
 
             @Override
             public void add(int positionOffset, IntVector groupIds) {
-                for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
-                    int groupId = groupIds.getInt(groupPosition);
-                    if (ordinalIds.isNull(groupPosition + positionOffset)) {
-                        continue;
-                    }
-                    int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
-                    int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
-                    for (int v = valuesStart; v < valuesEnd; v++) {
-                        state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(v)));
-                    }
-                }
+                addOrdinalInputBlock(state, positionOffset, groupIds, ordinalIds, hashIds);
             }
 
             @Override
@@ -114,15 +96,7 @@ final class ValuesBytesRefAggregators {
         if (valuesOrdinal == null) {
             return delegate;
         }
-        BytesRefVector dict = valuesOrdinal.getDictionaryVector();
-        final IntVector hashIds;
-        BytesRef spare = new BytesRef();
-        try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
-            for (int p = 0; p < dict.getPositionCount(); p++) {
-                hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
-            }
-            hashIds = hashIdsBuilder.build();
-        }
+        final IntVector hashIds = hashDict(state, valuesOrdinal.getDictionaryVector());
         var ordinalIds = valuesOrdinal.getOrdinalsVector();
         return new GroupingAggregatorFunction.AddInput() {
             @Override
@@ -157,10 +131,7 @@ final class ValuesBytesRefAggregators {
 
             @Override
             public void add(int positionOffset, IntVector groupIds) {
-                for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
-                    int groupId = groupIds.getInt(groupPosition);
-                    state.addValueOrdinal(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
-                }
+                addOrdinalInputVector(state, positionOffset, groupIds, ordinalIds, hashIds);
             }
 
             @Override
@@ -169,4 +140,114 @@ final class ValuesBytesRefAggregators {
             }
         };
     }
+
+    static IntVector hashDict(ValuesBytesRefAggregator.GroupingState state, BytesRefVector dict) {
+        BytesRef scratch = new BytesRef();
+        try (var hashIdsBuilder = dict.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
+            for (int p = 0; p < dict.getPositionCount(); p++) {
+                final long hashId = BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, scratch)));
+                hashIdsBuilder.appendInt(Math.toIntExact(hashId));
+            }
+            return hashIdsBuilder.build();
+        }
+    }
+
+    static void addOrdinalInputBlock(
+        ValuesBytesRefAggregator.GroupingState state,
+        int positionOffset,
+        IntVector groupIds,
+        IntBlock ordinalIds,
+        IntVector hashIds
+    ) {
+        for (int p = 0; p < groupIds.getPositionCount(); p++) {
+            final int valuePosition = p + positionOffset;
+            final int groupId = groupIds.getInt(valuePosition);
+            final int start = ordinalIds.getFirstValueIndex(valuePosition);
+            final int end = start + ordinalIds.getValueCount(valuePosition);
+            for (int i = start; i < end; i++) {
+                int ord = ordinalIds.getInt(i);
+                state.addValueOrdinal(groupId, hashIds.getInt(ord));
+            }
+        }
+    }
+
+    static void addOrdinalInputVector(
+        ValuesBytesRefAggregator.GroupingState state,
+        int positionOffset,
+        IntVector groupIds,
+        IntVector ordinalIds,
+        IntVector hashIds
+    ) {
+        for (int p = 0; p < groupIds.getPositionCount(); p++) {
+            int groupId = groupIds.getInt(p);
+            int ord = ordinalIds.getInt(p + positionOffset);
+            state.addValueOrdinal(groupId, hashIds.getInt(ord));
+        }
+    }
+
+    static void combineIntermediateInputValues(
+        ValuesBytesRefAggregator.GroupingState state,
+        int positionOffset,
+        IntVector groupIds,
+        BytesRefBlock values
+    ) {
+        BytesRefVector dict = null;
+        IntBlock ordinals = null;
+        {
+            final OrdinalBytesRefBlock asOrdinals = values.asOrdinals();
+            if (asOrdinals != null) {
+                dict = asOrdinals.getDictionaryVector();
+                ordinals = asOrdinals.getOrdinalsBlock();
+            }
+        }
+        if (dict != null && dict.getPositionCount() < groupIds.getPositionCount()) {
+            try (var hashIds = hashDict(state, dict)) {
+                IntVector ordinalsVector = ordinals.asVector();
+                if (ordinalsVector != null) {
+                    addOrdinalInputVector(state, positionOffset, groupIds, ordinalsVector, hashIds);
+                } else {
+                    addOrdinalInputBlock(state, positionOffset, groupIds, ordinals, hashIds);
+                }
+            }
+        } else {
+            final BytesRef scratch = new BytesRef();
+            for (int p = 0; p < groupIds.getPositionCount(); p++) {
+                final int valuePosition = p + positionOffset;
+                final int groupId = groupIds.getInt(valuePosition);
+                final int start = values.getFirstValueIndex(valuePosition);
+                final int end = start + values.getValueCount(valuePosition);
+                for (int i = start; i < end; i++) {
+                    state.addValue(groupId, values.getBytesRef(i, scratch));
+                }
+            }
+        }
+    }
+
+    static void combineIntermediateInputValues(
+        ValuesBytesRefAggregator.GroupingState state,
+        int positionOffset,
+        IntBlock groupIds,
+        BytesRefBlock values
+    ) {
+        final BytesRef scratch = new BytesRef();
+        for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
+            if (groupIds.isNull(groupPosition)) {
+                continue;
+            }
+            int groupStart = groupIds.getFirstValueIndex(groupPosition);
+            int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
+            for (int g = groupStart; g < groupEnd; g++) {
+                if (values.isNull(groupPosition + positionOffset)) {
+                    continue;
+                }
+                int groupId = groupIds.getInt(g);
+                int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+                int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+                for (int v = valuesStart; v < valuesEnd; v++) {
+                    var bytes = values.getBytesRef(v, scratch);
+                    state.addValue(groupId, bytes);
+                }
+            }
+        }
+    }
 }

+ 12 - 8
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st

@@ -113,20 +113,24 @@ $endif$
         state.addValue(groupId, v);
     }
 
-    public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
 $if(BytesRef)$
-        BytesRef scratch = new BytesRef();
-$endif$
+    public static void combineIntermediate(GroupingState state, int positionOffset, IntVector groups, $Type$Block values) {
+        ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values);
+    }
+
+    public static void combineIntermediate(GroupingState state, int positionOffset, IntBlock groups, $Type$Block values) {
+        ValuesBytesRefAggregators.combineIntermediateInputValues(state, positionOffset, groups, values);
+    }
+
+$else$
+    public static void combineIntermediate(GroupingState state, int groupId, $Type$Block values, int valuesPosition) {
         int start = values.getFirstValueIndex(valuesPosition);
         int end = start + values.getValueCount(valuesPosition);
         for (int i = start; i < end; i++) {
-$if(BytesRef)$
-            state.addValue(groupId, values.getBytesRef(i, scratch));
-$else$
             state.addValue(groupId, values.get$Type$(i));
-$endif$
         }
     }
+$endif$
 
     public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
         return state.toBlock(driverContext.blockFactory(), selected);
@@ -304,7 +308,7 @@ $endif$
 
             try (var sorted = buildSorted(selected)) {
 $if(BytesRef)$
-                if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) {
+                if (OrdinalBytesRefBlock.isDense(values.size(), bytes.size())) {
                     return buildOrdinalOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);
                 } else {
                     return buildOutputBlock(blockFactory, selected, sorted.counts, sorted.ids);