
ESQL: Speed up grouping by bytes (#114021) (#114652)

This speeds up grouping by bytes-valued fields (keyword, text, ip, and
wildcard) when the input is an ordinal block:
```
    bytes_refs 22.213 ± 0.322 -> 19.848 ± 0.205 ns/op (*maybe* real, maybe noise. still good)
       ordinal didn't exist   ->  2.988 ± 0.011 ns/op
```
I see this as 20ns -> 3ns, an 85% speed up. We never had the ordinals
branch before, so I'm expecting the same performance there - about 20ns
per op.

This also speeds up grouping by a pair of bytes-valued fields:
```
two_bytes_refs 83.112 ± 42.348  -> 46.521 ± 0.386 ns/op
  two_ordinals 83.531 ± 23.473  ->  8.617 ± 0.105 ns/op
```
The speed up is much better when the fields are ordinals because hashing
bytes is comparatively slow.

I believe the ordinals case is quite common. I've run into it in quite a
few profiles.
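
For intuition, here is a minimal sketch of the ordinals fast path, with illustrative names and plain arrays rather than the real Block API: the dictionary is hashed once, and each row then costs only an array read.
```
// Sketch: grouping an ordinal-encoded column. Instead of hashing every
// row's bytes, hash each distinct dictionary entry once, then remap the
// per-row ordinals through that small table. Illustrative only.
import java.util.HashMap;
import java.util.Map;

class OrdinalGroupingSketch {
    static int[] groupIds(String[] dictionary, int[] ordinals) {
        Map<String, Integer> hash = new HashMap<>();
        int[] dictToGroup = new int[dictionary.length];
        for (int i = 0; i < dictionary.length; i++) {
            // One hash insert per distinct value, not per row.
            dictToGroup[i] = hash.computeIfAbsent(dictionary[i], k -> hash.size());
        }
        int[] groups = new int[ordinals.length];
        for (int p = 0; p < ordinals.length; p++) {
            groups[p] = dictToGroup[ordinals[p]]; // no bytes hashed per row
        }
        return groups;
    }
}
```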
Nik Everett 1 year ago
commit 1212dee8b4
13 changed files with 632 additions and 66 deletions
  1. + 46 - 3
      benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java
  2. + 5 - 0
      docs/changelog/114021.yaml
  3. + 32 - 8
      x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java
  4. + 2 - 0
      x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java
  5. + 2 - 0
      x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java
  6. + 2 - 0
      x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java
  7. + 43 - 8
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
  8. + 3 - 2
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java
  9. + 196 - 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRef2BlockHash.java
  10. + 1 - 1
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRef3BlockHash.java
  11. + 34 - 8
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/X-BlockHash.java.st
  12. + 135 - 36
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java
  13. + 131 - 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java

+ 46 - 3
benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java

@@ -30,10 +30,13 @@ import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.BooleanVector;
 import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.AggregationOperator;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -78,7 +81,10 @@ public class AggregatorBenchmark {
     private static final String DOUBLES = "doubles";
     private static final String BOOLEANS = "booleans";
     private static final String BYTES_REFS = "bytes_refs";
+    private static final String ORDINALS = "ordinals";
     private static final String TWO_LONGS = "two_" + LONGS;
+    private static final String TWO_BYTES_REFS = "two_" + BYTES_REFS;
+    private static final String TWO_ORDINALS = "two_" + ORDINALS;
     private static final String LONGS_AND_BYTES_REFS = LONGS + "_and_" + BYTES_REFS;
     private static final String TWO_LONGS_AND_BYTES_REFS = "two_" + LONGS + "_and_" + BYTES_REFS;
 
@@ -119,7 +125,21 @@ public class AggregatorBenchmark {
         }
     }
 
-    @Param({ NONE, LONGS, INTS, DOUBLES, BOOLEANS, BYTES_REFS, TWO_LONGS, LONGS_AND_BYTES_REFS, TWO_LONGS_AND_BYTES_REFS })
+    @Param(
+        {
+            NONE,
+            LONGS,
+            INTS,
+            DOUBLES,
+            BOOLEANS,
+            BYTES_REFS,
+            ORDINALS,
+            TWO_LONGS,
+            TWO_BYTES_REFS,
+            TWO_ORDINALS,
+            LONGS_AND_BYTES_REFS,
+            TWO_LONGS_AND_BYTES_REFS }
+    )
     public String grouping;
 
     @Param({ COUNT, COUNT_DISTINCT, MIN, MAX, SUM })
@@ -144,8 +164,12 @@ public class AggregatorBenchmark {
             case INTS -> List.of(new BlockHash.GroupSpec(0, ElementType.INT));
             case DOUBLES -> List.of(new BlockHash.GroupSpec(0, ElementType.DOUBLE));
             case BOOLEANS -> List.of(new BlockHash.GroupSpec(0, ElementType.BOOLEAN));
-            case BYTES_REFS -> List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF));
+            case BYTES_REFS, ORDINALS -> List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF));
             case TWO_LONGS -> List.of(new BlockHash.GroupSpec(0, ElementType.LONG), new BlockHash.GroupSpec(1, ElementType.LONG));
+            case TWO_BYTES_REFS, TWO_ORDINALS -> List.of(
+                new BlockHash.GroupSpec(0, ElementType.BYTES_REF),
+                new BlockHash.GroupSpec(1, ElementType.BYTES_REF)
+            );
             case LONGS_AND_BYTES_REFS -> List.of(
                 new BlockHash.GroupSpec(0, ElementType.LONG),
                 new BlockHash.GroupSpec(1, ElementType.BYTES_REF)
@@ -218,6 +242,10 @@ public class AggregatorBenchmark {
                 checkGroupingBlock(prefix, LONGS, page.getBlock(0));
                 checkGroupingBlock(prefix, LONGS, page.getBlock(1));
             }
+            case TWO_BYTES_REFS, TWO_ORDINALS -> {
+                checkGroupingBlock(prefix, BYTES_REFS, page.getBlock(0));
+                checkGroupingBlock(prefix, BYTES_REFS, page.getBlock(1));
+            }
             case LONGS_AND_BYTES_REFS -> {
                 checkGroupingBlock(prefix, LONGS, page.getBlock(0));
                 checkGroupingBlock(prefix, BYTES_REFS, page.getBlock(1));
@@ -379,7 +407,7 @@ public class AggregatorBenchmark {
                     throw new AssertionError(prefix + "bad group expected [true] but was [" + groups.getBoolean(1) + "]");
                 }
             }
-            case BYTES_REFS -> {
+            case BYTES_REFS, ORDINALS -> {
                 BytesRefBlock groups = (BytesRefBlock) block;
                 for (int g = 0; g < GROUPS; g++) {
                     if (false == groups.getBytesRef(g, new BytesRef()).equals(bytesGroup(g))) {
@@ -508,6 +536,8 @@ public class AggregatorBenchmark {
     private static List<Block> groupingBlocks(String grouping, String blockType) {
         return switch (grouping) {
             case TWO_LONGS -> List.of(groupingBlock(LONGS, blockType), groupingBlock(LONGS, blockType));
+            case TWO_BYTES_REFS -> List.of(groupingBlock(BYTES_REFS, blockType), groupingBlock(BYTES_REFS, blockType));
+            case TWO_ORDINALS -> List.of(groupingBlock(ORDINALS, blockType), groupingBlock(ORDINALS, blockType));
             case LONGS_AND_BYTES_REFS -> List.of(groupingBlock(LONGS, blockType), groupingBlock(BYTES_REFS, blockType));
             case TWO_LONGS_AND_BYTES_REFS -> List.of(
                 groupingBlock(LONGS, blockType),
@@ -570,6 +600,19 @@ public class AggregatorBenchmark {
                 }
                 yield builder.build();
             }
+            case ORDINALS -> {
+                IntVector.Builder ordinals = blockFactory.newIntVectorBuilder(BLOCK_LENGTH * valuesPerGroup);
+                for (int i = 0; i < BLOCK_LENGTH; i++) {
+                    for (int v = 0; v < valuesPerGroup; v++) {
+                        ordinals.appendInt(i % GROUPS);
+                    }
+                }
+                BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(BLOCK_LENGTH * valuesPerGroup);
+                for (int i = 0; i < GROUPS; i++) {
+                    bytes.appendBytesRef(bytesGroup(i));
+                }
+                yield new OrdinalBytesRefVector(ordinals.build(), bytes.build()).asBlock();
+            }
             default -> throw new UnsupportedOperationException("unsupported grouping [" + grouping + "]");
         };
     }

+ 5 - 0
docs/changelog/114021.yaml

@@ -0,0 +1,5 @@
+pr: 114021
+summary: "ESQL: Speed up grouping by bytes"
+area: ES|QL
+type: enhancement
+issues: []

+ 32 - 8
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/BytesRefBlockHash.java

@@ -23,15 +23,18 @@ import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
 import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBytesRef;
+import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeInt;
 import org.elasticsearch.core.ReleasableIterator;
 
 import java.io.IOException;
 
 /**
  * Maps a {@link BytesRefBlock} column to group ids.
+ * This class is generated. Do not edit it.
  */
 final class BytesRefBlockHash extends BlockHash {
     private final int channel;
@@ -54,6 +57,7 @@ final class BytesRefBlockHash extends BlockHash {
 
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        // TODO track raw counts and which implementation we pick for the profiler - #114008
         var block = page.getBlock(channel);
         if (block.areAllValuesNull()) {
             seenNull = true;
@@ -76,6 +80,10 @@ final class BytesRefBlockHash extends BlockHash {
     }
 
     IntVector add(BytesRefVector vector) {
+        var ordinals = vector.asOrdinals();
+        if (ordinals != null) {
+            return addOrdinalsVector(ordinals);
+        }
         BytesRef scratch = new BytesRef();
         int positions = vector.getPositionCount();
         try (var builder = blockFactory.newIntVectorFixedBuilder(positions)) {
@@ -113,15 +121,29 @@ final class BytesRefBlockHash extends BlockHash {
         return ReleasableIterator.single(lookup(vector));
     }
 
-    private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
-        var inputOrds = inputBlock.getOrdinalsBlock();
+    private IntVector addOrdinalsVector(OrdinalBytesRefVector inputBlock) {
+        IntVector inputOrds = inputBlock.getOrdinalsVector();
         try (
-            var builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
+            var builder = blockFactory.newIntVectorBuilder(inputOrds.getPositionCount());
             var hashOrds = add(inputBlock.getDictionaryVector())
         ) {
-            for (int i = 0; i < inputOrds.getPositionCount(); i++) {
-                int valueCount = inputOrds.getValueCount(i);
-                int firstIndex = inputOrds.getFirstValueIndex(i);
+            for (int p = 0; p < inputOrds.getPositionCount(); p++) {
+                int ord = hashOrds.getInt(inputOrds.getInt(p));
+                builder.appendInt(ord);
+            }
+            return builder.build();
+        }
+    }
+
+    private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
+        try (
+            IntBlock inputOrds = new MultivalueDedupeInt(inputBlock.getOrdinalsBlock()).dedupeToBlockAdaptive(blockFactory);
+            IntBlock.Builder builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
+            IntVector hashOrds = add(inputBlock.getDictionaryVector())
+        ) {
+            for (int p = 0; p < inputOrds.getPositionCount(); p++) {
+                int valueCount = inputOrds.getValueCount(p);
+                int firstIndex = inputOrds.getFirstValueIndex(p);
                 switch (valueCount) {
                     case 0 -> {
                         builder.appendInt(0);
@@ -132,9 +154,11 @@ final class BytesRefBlockHash extends BlockHash {
                         builder.appendInt(ord);
                     }
                     default -> {
+                        int start = firstIndex;
+                        int end = firstIndex + valueCount;
                         builder.beginPositionEntry();
-                        for (int v = 0; v < valueCount; v++) {
-                            int ord = hashOrds.getInt(inputOrds.getInt(firstIndex + i));
+                        for (int i = start; i < end; i++) {
+                            int ord = hashOrds.getInt(inputOrds.getInt(i));
                             builder.appendInt(ord);
                         }
                         builder.endPositionEntry();
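
Restating the block path above in isolation: a rough sketch of addOrdinalsBlock's dedupe-then-remap shape, with plain int arrays standing in for the real Block types.
```
// Sketch of addOrdinalsBlock's remapping step: each position's (already
// deduplicated) input ordinals are mapped through hashOrds, the group ids
// that add() assigned to the dictionary entries. Illustrative only.
class OrdinalsBlockRemapSketch {
    static int[][] remap(int[][] dedupedOrdsPerPosition, int[] hashOrds) {
        int[][] groupIds = new int[dedupedOrdsPerPosition.length][];
        for (int p = 0; p < dedupedOrdsPerPosition.length; p++) {
            int[] ords = dedupedOrdsPerPosition[p];
            groupIds[p] = new int[ords.length];
            for (int i = 0; i < ords.length; i++) {
                groupIds[p][i] = hashOrds[ords[i]];
            }
        }
        return groupIds;
    }
}
```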

+ 2 - 0
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java

@@ -28,6 +28,7 @@ import java.util.BitSet;
 
 /**
  * Maps a {@link DoubleBlock} column to group ids.
+ * This class is generated. Do not edit it.
  */
 final class DoubleBlockHash extends BlockHash {
     private final int channel;
@@ -50,6 +51,7 @@ final class DoubleBlockHash extends BlockHash {
 
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        // TODO track raw counts and which implementation we pick for the profiler - #114008
         var block = page.getBlock(channel);
         if (block.areAllValuesNull()) {
             seenNull = true;

+ 2 - 0
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java

@@ -26,6 +26,7 @@ import java.util.BitSet;
 
 /**
  * Maps a {@link IntBlock} column to group ids.
+ * This class is generated. Do not edit it.
  */
 final class IntBlockHash extends BlockHash {
     private final int channel;
@@ -48,6 +49,7 @@ final class IntBlockHash extends BlockHash {
 
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        // TODO track raw counts and which implementation we pick for the profiler - #114008
         var block = page.getBlock(channel);
         if (block.areAllValuesNull()) {
             seenNull = true;

+ 2 - 0
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java

@@ -28,6 +28,7 @@ import java.util.BitSet;
 
 /**
  * Maps a {@link LongBlock} column to group ids.
+ * This class is generated. Do not edit it.
  */
 final class LongBlockHash extends BlockHash {
     private final int channel;
@@ -50,6 +51,7 @@ final class LongBlockHash extends BlockHash {
 
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        // TODO track raw counts and which implementation we pick for the profiler - #114008
         var block = page.getBlock(channel);
         if (block.areAllValuesNull()) {
             seenNull = true;

+ 43 - 8
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

@@ -11,6 +11,7 @@ import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.common.util.Int3Hash;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.common.util.LongLongHash;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
@@ -28,14 +29,37 @@ import java.util.Iterator;
 import java.util.List;
 
 /**
- * A specialized hash table implementation maps values of a {@link Block} to ids (in longs).
- * This class delegates to {@link LongHash} or {@link BytesRefHash}.
- *
- * @see LongHash
- * @see BytesRefHash
+ * Specialized hash table implementations that map rows to a <strong>set</strong>
+ * of bucket IDs to which they belong, to implement {@code GROUP BY} expressions.
+ * <p>
+ *     A row is always in at least one bucket so the results are never {@code null}.
+ *     {@code null} valued key columns will map to some integer bucket id.
+ *     If none of the key columns are multivalued then the output is always an
+ *     {@link IntVector}. If any of the keys are multivalued then a row is
+ *     in a bucket for each value. If more than one key is multivalued then
+ *     the row is in the combinatorial explosion of all value combinations.
+ *     Luckily, a row can only be in each bucket once, regardless of the number
+ *     of values. Unluckily, it's the responsibility of {@link BlockHash} to
+ *     remove those duplicates.
+ * </p>
+ * <p>
+ *     These classes typically delegate to some combination of {@link BytesRefHash},
+ *     {@link LongHash}, {@link LongLongHash}, and {@link Int3Hash}. They don't
+ *     <strong>technically</strong> have to be hash tables, so long as they
+ *     implement the deduplication semantics above and vend integer ids.
+ * </p>
+ * <p>
+ *     The integer ids are assigned to offsets into arrays of aggregation states
+ *     so it's permissible to have gaps in the ints. But large gaps are a bad
+ *     idea because they'll waste space in the aggregations that use these
+ *     positions. For example, {@link BooleanBlockHash} assigns {@code 0} to
+ *     {@code null}, {@code 1} to {@code false}, and {@code 2} to {@code true},
+ *     and that's <strong>fine</strong> and simple and good because it'll never
+ *     leave a big gap, even if we never see {@code null}.
+ * </p>
  */
 public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
-    permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef3BlockHash, //
+    permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
     NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
 
     protected final BlockFactory blockFactory;
@@ -98,8 +122,19 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
         if (groups.size() == 1) {
             return newForElementType(groups.get(0).channel(), groups.get(0).elementType(), blockFactory);
         }
-        if (groups.size() == 3 && groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
-            return new BytesRef3BlockHash(blockFactory, groups.get(0).channel, groups.get(1).channel, groups.get(2).channel, emitBatchSize);
+        if (groups.stream().allMatch(g -> g.elementType == ElementType.BYTES_REF)) {
+            switch (groups.size()) {
+                case 2:
+                    return new BytesRef2BlockHash(blockFactory, groups.get(0).channel, groups.get(1).channel, emitBatchSize);
+                case 3:
+                    return new BytesRef3BlockHash(
+                        blockFactory,
+                        groups.get(0).channel,
+                        groups.get(1).channel,
+                        groups.get(2).channel,
+                        emitBatchSize
+                    );
+            }
         }
         if (allowBrokenOptimizations && groups.size() == 2) {
             var g1 = groups.get(0);
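
To make the javadoc's multivalue semantics concrete, here is a hedged sketch (Lists standing in for Blocks, names illustrative) of the buckets a single row lands in when both of its keys are multivalued:
```
// Sketch of the "combinatorial explosion" described in the javadoc: a row
// with two multivalued keys lands in one bucket per value combination, and
// duplicates must be removed. Illustrative only.
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

class MultivalueBucketsSketch {
    static Set<List<String>> bucketsForRow(List<String> key1, List<String> key2) {
        Set<List<String>> buckets = new LinkedHashSet<>(); // the Set drops duplicates
        for (String k1 : key1) {
            for (String k2 : key2) {
                buckets.add(List.of(k1, k2));
            }
        }
        return buckets;
    }

    public static void main(String[] args) {
        // ["a", "b"] x ["x", "x"] -> [[a, x], [b, x]]: four combinations,
        // two after the dedupe that BlockHash is responsible for.
        System.out.println(bucketsForRow(List.of("a", "b"), List.of("x", "x")));
    }
}
```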

+ 3 - 2
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BooleanBlockHash.java

@@ -25,8 +25,9 @@ import static org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBoolea
 import static org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeBoolean.TRUE_ORD;
 
 /**
- * Maps a {@link BooleanBlock} column to group ids. Assigns group
- * {@code 0} to {@code false} and group {@code 1} to {@code true}.
+ * Maps a {@link BooleanBlock} column to group ids. Assigns
+ * {@code 0} to {@code null}, {@code 1} to {@code false}, and
+ * {@code 2} to {@code true}.
  */
 final class BooleanBlockHash extends BlockHash {
     private final int channel;

+ 196 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRef2BlockHash.java

@@ -0,0 +1,196 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.LongHash;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.core.Releasables;
+
+import java.util.Locale;
+
+/**
+ * Maps two {@link BytesRefBlock}s to group ids.
+ */
+final class BytesRef2BlockHash extends BlockHash {
+    private final int emitBatchSize;
+    private final int channel1;
+    private final int channel2;
+    private final BytesRefBlockHash hash1;
+    private final BytesRefBlockHash hash2;
+    private final LongHash finalHash;
+
+    BytesRef2BlockHash(BlockFactory blockFactory, int channel1, int channel2, int emitBatchSize) {
+        super(blockFactory);
+        this.emitBatchSize = emitBatchSize;
+        this.channel1 = channel1;
+        this.channel2 = channel2;
+        boolean success = false;
+        try {
+            this.hash1 = new BytesRefBlockHash(channel1, blockFactory);
+            this.hash2 = new BytesRefBlockHash(channel2, blockFactory);
+            this.finalHash = new LongHash(1, blockFactory.bigArrays());
+            success = true;
+        } finally {
+            if (success == false) {
+                close();
+            }
+        }
+    }
+
+    @Override
+    public void close() {
+        Releasables.close(hash1, hash2, finalHash);
+    }
+
+    @Override
+    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        BytesRefBlock b1 = page.getBlock(channel1);
+        BytesRefBlock b2 = page.getBlock(channel2);
+        BytesRefVector v1 = b1.asVector();
+        BytesRefVector v2 = b2.asVector();
+        if (v1 != null && v2 != null) {
+            addVectors(v1, v2, addInput);
+        } else {
+            try (IntBlock k1 = hash1.add(b1); IntBlock k2 = hash2.add(b2)) {
+                try (AddWork work = new AddWork(k1, k2, addInput)) {
+                    work.add();
+                }
+            }
+        }
+    }
+
+    private void addVectors(BytesRefVector v1, BytesRefVector v2, GroupingAggregatorFunction.AddInput addInput) {
+        final int positionCount = v1.getPositionCount();
+        try (IntVector.FixedBuilder ordsBuilder = blockFactory.newIntVectorFixedBuilder(positionCount)) {
+            try (IntVector k1 = hash1.add(v1); IntVector k2 = hash2.add(v2)) {
+                for (int p = 0; p < positionCount; p++) {
+                    long ord = ord(k1.getInt(p), k2.getInt(p));
+                    ordsBuilder.appendInt(p, Math.toIntExact(ord));
+                }
+            }
+            try (IntVector ords = ordsBuilder.build()) {
+                addInput.add(0, ords);
+            }
+        }
+    }
+
+    private class AddWork extends AddPage {
+        final IntBlock b1;
+        final IntBlock b2;
+
+        AddWork(IntBlock b1, IntBlock b2, GroupingAggregatorFunction.AddInput addInput) {
+            super(blockFactory, emitBatchSize, addInput);
+            this.b1 = b1;
+            this.b2 = b2;
+        }
+
+        void add() {
+            int positionCount = b1.getPositionCount();
+            for (int i = 0; i < positionCount; i++) {
+                int v1 = b1.getValueCount(i);
+                int v2 = b2.getValueCount(i);
+                int first1 = b1.getFirstValueIndex(i);
+                int first2 = b2.getFirstValueIndex(i);
+                if (v1 == 1 && v2 == 1) {
+                    long ord = ord(b1.getInt(first1), b2.getInt(first2));
+                    appendOrdSv(i, Math.toIntExact(ord));
+                    continue;
+                }
+                for (int i1 = 0; i1 < v1; i1++) {
+                    int k1 = b1.getInt(first1 + i1);
+                    for (int i2 = 0; i2 < v2; i2++) {
+                        int k2 = b2.getInt(first2 + i2);
+                        long ord = ord(k1, k2);
+                        appendOrdInMv(i, Math.toIntExact(ord));
+                    }
+                }
+                finishMv();
+            }
+            flushRemaining();
+        }
+    }
+
+    private long ord(int k1, int k2) {
+        return hashOrdToGroup(finalHash.add((long) k2 << 32 | k1));
+    }
+
+    @Override
+    public ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+        throw new UnsupportedOperationException("TODO");
+    }
+
+    @Override
+    public Block[] getKeys() {
+        // TODO Build Ordinals blocks #114010
+        final int positions = (int) finalHash.size();
+        final BytesRef scratch = new BytesRef();
+        final BytesRefBlock[] outputBlocks = new BytesRefBlock[2];
+        try {
+            try (BytesRefBlock.Builder b1 = blockFactory.newBytesRefBlockBuilder(positions)) {
+                for (int i = 0; i < positions; i++) {
+                    int k1 = (int) (finalHash.get(i) & 0xffffffffL);
+                    if (k1 == 0) {
+                        b1.appendNull();
+                    } else {
+                        b1.appendBytesRef(hash1.hash.get(k1 - 1, scratch));
+                    }
+                }
+                outputBlocks[0] = b1.build();
+            }
+            try (BytesRefBlock.Builder b2 = blockFactory.newBytesRefBlockBuilder(positions)) {
+                for (int i = 0; i < positions; i++) {
+                    int k2 = (int) (finalHash.get(i) >>> 32);
+                    if (k2 == 0) {
+                        b2.appendNull();
+                    } else {
+                        b2.appendBytesRef(hash2.hash.get(k2 - 1, scratch));
+                    }
+                }
+                outputBlocks[1] = b2.build();
+            }
+            return outputBlocks;
+        } finally {
+            if (outputBlocks[outputBlocks.length - 1] == null) {
+                Releasables.close(outputBlocks);
+            }
+        }
+    }
+
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return new Range(0, Math.toIntExact(finalHash.size())).seenGroupIds(bigArrays);
+    }
+
+    @Override
+    public IntVector nonEmpty() {
+        return IntVector.range(0, Math.toIntExact(finalHash.size()), blockFactory);
+    }
+
+    @Override
+    public String toString() {
+        return String.format(
+            Locale.ROOT,
+            "BytesRef2BlockHash{keys=[channel1=%d, channel2=%d], entries=%d}",
+            channel1,
+            channel2,
+            finalHash.size()
+        );
+    }
+}
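
As a standalone cross-check of the packing shared by ord() and getKeys() above: k2 lives in the high 32 bits and k1 in the low 32 bits, so unpacking k1 requires the full 32-bit mask 0xffffffffL and k2 an unsigned shift. A minimal demo (hypothetical class name):
```
// Demo of the key packing in BytesRef2BlockHash: k2 in the high 32 bits,
// k1 in the low 32 bits. Ords are non-negative, so sign extension of k1
// in the pack expression is not a concern.
class PackDemo {
    static long pack(int k1, int k2) {
        return (long) k2 << 32 | k1; // same expression as ord()
    }

    public static void main(String[] args) {
        long packed = pack(70_000, 3); // k1 deliberately wider than 16 bits
        int k1 = (int) (packed & 0xffffffffL);
        int k2 = (int) (packed >>> 32);
        System.out.println(k1 + " " + k2); // prints: 70000 3
    }
}
```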

+ 1 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRef3BlockHash.java

@@ -85,7 +85,6 @@ final class BytesRef3BlockHash extends BlockHash {
     private void addVectors(BytesRefVector v1, BytesRefVector v2, BytesRefVector v3, GroupingAggregatorFunction.AddInput addInput) {
         final int positionCount = v1.getPositionCount();
         try (IntVector.FixedBuilder ordsBuilder = blockFactory.newIntVectorFixedBuilder(positionCount)) {
-            // TODO: enable ordinal vectors in BytesRefBlockHash
             try (IntVector k1 = hash1.add(v1); IntVector k2 = hash2.add(v2); IntVector k3 = hash3.add(v3)) {
                 for (int p = 0; p < positionCount; p++) {
                     long ord = hashOrdToGroup(finalHash.add(k1.getInt(p), k2.getInt(p), k3.getInt(p)));
@@ -148,6 +147,7 @@ final class BytesRef3BlockHash extends BlockHash {
 
     @Override
     public Block[] getKeys() {
+        // TODO Build Ordinals blocks #114010
         final int positions = (int) finalHash.size();
         final BytesRef scratch = new BytesRef();
         final BytesRefBlock[] outputBlocks = new BytesRefBlock[3];

+ 34 - 8
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/X-BlockHash.java.st

@@ -28,6 +28,7 @@ import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefVector;
 $elseif(double)$
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
@@ -51,6 +52,9 @@ $endif$
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe;
 import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupe$Type$;
+$if(BytesRef)$
+import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeInt;
+$endif$
 import org.elasticsearch.core.ReleasableIterator;
 
 $if(BytesRef)$
@@ -62,6 +66,7 @@ import java.util.BitSet;
 $endif$
 /**
  * Maps a {@link $Type$Block} column to group ids.
+ * This class is generated. Do not edit it.
  */
 final class $Type$BlockHash extends BlockHash {
     private final int channel;
@@ -84,6 +89,7 @@ final class $Type$BlockHash extends BlockHash {
 
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        // TODO track raw counts and which implementation we pick for the profiler - #114008
         var block = page.getBlock(channel);
         if (block.areAllValuesNull()) {
             seenNull = true;
@@ -107,6 +113,10 @@ final class $Type$BlockHash extends BlockHash {
 
     IntVector add($Type$Vector vector) {
 $if(BytesRef)$
+        var ordinals = vector.asOrdinals();
+        if (ordinals != null) {
+            return addOrdinalsVector(ordinals);
+        }
         BytesRef scratch = new BytesRef();
 $endif$
         int positions = vector.getPositionCount();
@@ -154,15 +164,29 @@ $endif$
     }
 
 $if(BytesRef)$
-    private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
-        var inputOrds = inputBlock.getOrdinalsBlock();
+    private IntVector addOrdinalsVector(OrdinalBytesRefVector inputBlock) {
+        IntVector inputOrds = inputBlock.getOrdinalsVector();
         try (
-            var builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
+            var builder = blockFactory.newIntVectorBuilder(inputOrds.getPositionCount());
             var hashOrds = add(inputBlock.getDictionaryVector())
         ) {
-            for (int i = 0; i < inputOrds.getPositionCount(); i++) {
-                int valueCount = inputOrds.getValueCount(i);
-                int firstIndex = inputOrds.getFirstValueIndex(i);
+            for (int p = 0; p < inputOrds.getPositionCount(); p++) {
+                int ord = hashOrds.getInt(inputOrds.getInt(p));
+                builder.appendInt(ord);
+            }
+            return builder.build();
+        }
+    }
+
+    private IntBlock addOrdinalsBlock(OrdinalBytesRefBlock inputBlock) {
+        try (
+            IntBlock inputOrds = new MultivalueDedupeInt(inputBlock.getOrdinalsBlock()).dedupeToBlockAdaptive(blockFactory);
+            IntBlock.Builder builder = blockFactory.newIntBlockBuilder(inputOrds.getPositionCount());
+            IntVector hashOrds = add(inputBlock.getDictionaryVector())
+        ) {
+            for (int p = 0; p < inputOrds.getPositionCount(); p++) {
+                int valueCount = inputOrds.getValueCount(p);
+                int firstIndex = inputOrds.getFirstValueIndex(p);
                 switch (valueCount) {
                     case 0 -> {
                         builder.appendInt(0);
@@ -173,9 +197,11 @@ $if(BytesRef)$
                         builder.appendInt(ord);
                     }
                     default -> {
+                        int start = firstIndex;
+                        int end = firstIndex + valueCount;
                         builder.beginPositionEntry();
-                        for (int v = 0; v < valueCount; v++) {
-                            int ord = hashOrds.getInt(inputOrds.getInt(firstIndex + i));
+                        for (int i = start; i < end; i++) {
+                            int ord = hashOrds.getInt(inputOrds.getInt(i));
                             builder.appendInt(ord);
                         }
                         builder.endPositionEntry();

+ 135 - 36
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java

@@ -21,10 +21,13 @@ import org.elasticsearch.compute.data.BasicBlockTests;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BlockTestUtils;
+import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.MockBlockFactory;
+import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.data.TestBlockFactory;
 import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeTests;
 import org.elasticsearch.core.ReleasableIterator;
 import org.elasticsearch.core.Releasables;
@@ -38,11 +41,13 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.NavigableSet;
 import java.util.Set;
 import java.util.TreeSet;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.test.ListMatcher.matchesList;
 import static org.elasticsearch.test.MapMatcher.assertMap;
@@ -58,26 +63,40 @@ import static org.mockito.Mockito.when;
 public class BlockHashRandomizedTests extends ESTestCase {
     @ParametersFactory
     public static List<Object[]> params() {
-        List<Object[]> params = new ArrayList<>();
+        List<List<? extends Type>> allowedTypesChoices = List.of(
+            /*
+             * Run with only `LONG` elements because we have some
+             * optimizations that hit if you only have those.
+             */
+            List.of(new Basic(ElementType.LONG)),
+            /*
+             * Run with only `BYTES_REF` elements because we have some
+             * optimizations that hit if you only have those.
+             */
+            List.of(new Basic(ElementType.BYTES_REF)),
+            /*
+             * Run with only `BYTES_REF` elements in an OrdinalBytesRefBlock
+             * because we have a few optimizations that use it.
+             */
+            List.of(new Ordinals(10)),
+            /*
+             * Run with only `LONG` and `BYTES_REF` elements because
+             * we have some optimizations that hit if you only have
+             * those.
+             */
+            List.of(new Basic(ElementType.LONG), new Basic(ElementType.BYTES_REF)),
+            /*
+             * Any random source.
+             */
+            Stream.concat(Stream.of(new Ordinals(10)), MultivalueDedupeTests.supportedTypes().stream().map(Basic::new)).toList()
+        );
 
+        List<Object[]> params = new ArrayList<>();
         for (boolean forcePackedHash : new boolean[] { false, true }) {
             for (int groups : new int[] { 1, 2, 3, 4, 5, 10 }) {
                 for (int maxValuesPerPosition : new int[] { 1, 3 }) {
                     for (int dups : new int[] { 0, 2 }) {
-                        for (List<ElementType> allowedTypes : List.of(
-                            /*
-                             * Run with only `LONG` elements because we have some
-                             * optimizations that hit if you only have those.
-                             */
-                            List.of(ElementType.LONG),
-                            /*
-                             * Run with only `LONG` and `BYTES_REF` elements because
-                             * we have some optimizations that hit if you only have
-                             * those.
-                             */
-                            List.of(ElementType.LONG, ElementType.BYTES_REF),
-                            MultivalueDedupeTests.supportedTypes()
-                        )) {
+                        for (List<? extends Type> allowedTypes : allowedTypesChoices) {
                             params.add(new Object[] { forcePackedHash, groups, maxValuesPerPosition, dups, allowedTypes });
                         }
                     }
@@ -87,18 +106,33 @@ public class BlockHashRandomizedTests extends ESTestCase {
         return params;
     }
 
+    /**
+     * The type of {@link Block} being tested.
+     */
+    interface Type {
+        /**
+         * The type of the {@link ElementType elements} in the {@link Block}.
+         */
+        ElementType elementType();
+
+        /**
+         * Build a random {@link Block}.
+         */
+        BasicBlockTests.RandomBlock randomBlock(int positionCount, int maxValuesPerPosition, int dups);
+    }
+
     private final boolean forcePackedHash;
     private final int groups;
     private final int maxValuesPerPosition;
     private final int dups;
-    private final List<ElementType> allowedTypes;
+    private final List<? extends Type> allowedTypes;
 
     public BlockHashRandomizedTests(
         @Name("forcePackedHash") boolean forcePackedHash,
         @Name("groups") int groups,
         @Name("maxValuesPerPosition") int maxValuesPerPosition,
         @Name("dups") int dups,
-        @Name("allowedTypes") List<ElementType> allowedTypes
+        @Name("allowedTypes") List<Type> allowedTypes
     ) {
         this.forcePackedHash = forcePackedHash;
         this.groups = groups;
@@ -127,21 +161,22 @@ public class BlockHashRandomizedTests extends ESTestCase {
     }
 
     private void test(MockBlockFactory blockFactory) {
-        List<ElementType> types = randomList(groups, groups, () -> randomFrom(allowedTypes));
+        List<Type> types = randomList(groups, groups, () -> randomFrom(allowedTypes));
+        List<ElementType> elementTypes = types.stream().map(Type::elementType).toList();
         BasicBlockTests.RandomBlock[] randomBlocks = new BasicBlockTests.RandomBlock[types.size()];
         Block[] blocks = new Block[types.size()];
-        int pageCount = between(1, 10);
+        int pageCount = between(1, groups < 10 ? 10 : 5);
         int positionCount = 100;
         int emitBatchSize = 100;
-        try (BlockHash blockHash = newBlockHash(blockFactory, emitBatchSize, types)) {
+        try (BlockHash blockHash = newBlockHash(blockFactory, emitBatchSize, elementTypes)) {
             /*
              * Only the long/long, long/bytes_ref, and bytes_ref/long implementations don't collect nulls.
              */
             Oracle oracle = new Oracle(
                 forcePackedHash
-                    || false == (types.equals(List.of(ElementType.LONG, ElementType.LONG))
-                        || types.equals(List.of(ElementType.LONG, ElementType.BYTES_REF))
-                        || types.equals(List.of(ElementType.BYTES_REF, ElementType.LONG)))
+                    || false == (elementTypes.equals(List.of(ElementType.LONG, ElementType.LONG))
+                        || elementTypes.equals(List.of(ElementType.LONG, ElementType.BYTES_REF))
+                        || elementTypes.equals(List.of(ElementType.BYTES_REF, ElementType.LONG)))
             );
             /*
              * Expected ordinals for checking lookup. Skipped if we have more than 5 groups because
@@ -151,15 +186,7 @@ public class BlockHashRandomizedTests extends ESTestCase {
 
             for (int p = 0; p < pageCount; p++) {
                 for (int g = 0; g < blocks.length; g++) {
-                    randomBlocks[g] = BasicBlockTests.randomBlock(
-                        types.get(g),
-                        positionCount,
-                        types.get(g) == ElementType.NULL ? true : randomBoolean(),
-                        1,
-                        maxValuesPerPosition,
-                        0,
-                        dups
-                    );
+                    randomBlocks[g] = types.get(g).randomBlock(positionCount, maxValuesPerPosition, dups);
                     blocks[g] = randomBlocks[g].block();
                 }
                 oracle.add(randomBlocks);
@@ -209,6 +236,7 @@ public class BlockHashRandomizedTests extends ESTestCase {
 
                 if (blockHash instanceof LongLongBlockHash == false
                     && blockHash instanceof BytesRefLongBlockHash == false
+                    && blockHash instanceof BytesRef2BlockHash == false
                     && blockHash instanceof BytesRef3BlockHash == false) {
                     assertLookup(blockFactory, expectedOrds, types, blockHash, oracle);
                 }
@@ -235,14 +263,14 @@ public class BlockHashRandomizedTests extends ESTestCase {
     private void assertLookup(
         BlockFactory blockFactory,
         Map<List<Object>, Set<Integer>> expectedOrds,
-        List<ElementType> types,
+        List<Type> types,
         BlockHash blockHash,
         Oracle oracle
     ) {
         Block.Builder[] builders = new Block.Builder[types.size()];
         try {
             for (int b = 0; b < builders.length; b++) {
-                builders[b] = types.get(b).newBlockBuilder(LOOKUP_POSITIONS, blockFactory);
+                builders[b] = types.get(b).elementType().newBlockBuilder(LOOKUP_POSITIONS, blockFactory);
             }
             for (int p = 0; p < LOOKUP_POSITIONS; p++) {
                 /*
@@ -408,8 +436,8 @@ public class BlockHashRandomizedTests extends ESTestCase {
         return breakerService;
     }
 
-    private static List<Object> randomKey(List<ElementType> types) {
-        return types.stream().map(BlockHashRandomizedTests::randomKeyElement).toList();
+    private static List<Object> randomKey(List<Type> types) {
+        return types.stream().map(t -> randomKeyElement(t.elementType())).toList();
     }
 
     public static Object randomKeyElement(ElementType type) {
@@ -423,4 +451,75 @@ public class BlockHashRandomizedTests extends ESTestCase {
             default -> throw new IllegalArgumentException("unsupported element type [" + type + "]");
         };
     }
+
+    private record Basic(ElementType elementType) implements Type {
+        @Override
+        public BasicBlockTests.RandomBlock randomBlock(int positionCount, int maxValuesPerPosition, int dups) {
+            return BasicBlockTests.randomBlock(
+                elementType,
+                positionCount,
+                elementType == ElementType.NULL | randomBoolean(),
+                1,
+                maxValuesPerPosition,
+                0,
+                dups
+            );
+        }
+    }
+
+    private record Ordinals(int dictionarySize) implements Type {
+        @Override
+        public ElementType elementType() {
+            return ElementType.BYTES_REF;
+        }
+
+        @Override
+        public BasicBlockTests.RandomBlock randomBlock(int positionCount, int maxValuesPerPosition, int dups) {
+            List<Map.Entry<String, Integer>> dictionary = new ArrayList<>();
+            List<List<Object>> values = new ArrayList<>(positionCount);
+            try (
+                IntBlock.Builder ordinals = TestBlockFactory.getNonBreakingInstance()
+                    .newIntBlockBuilder(positionCount * maxValuesPerPosition);
+                BytesRefVector.Builder bytes = TestBlockFactory.getNonBreakingInstance().newBytesRefVectorBuilder(maxValuesPerPosition);
+            ) {
+                for (String value : dictionary(maxValuesPerPosition)) {
+                    bytes.appendBytesRef(new BytesRef(value));
+                    dictionary.add(Map.entry(value, dictionary.size()));
+                }
+                for (int p = 0; p < positionCount; p++) {
+                    int valueCount = between(1, maxValuesPerPosition);
+                    int dupCount = between(0, dups);
+
+                    List<Integer> ordsAtPosition = new ArrayList<>();
+                    List<Object> valuesAtPosition = new ArrayList<>();
+                    values.add(valuesAtPosition);
+                    if (valueCount != 1 || dupCount != 0) {
+                        ordinals.beginPositionEntry();
+                    }
+                    for (int v = 0; v < valueCount; v++) {
+                        Map.Entry<String, Integer> value = randomFrom(dictionary);
+                        valuesAtPosition.add(new BytesRef(value.getKey()));
+                        ordinals.appendInt(value.getValue());
+                        ordsAtPosition.add(value.getValue());
+                    }
+                    for (int v = 0; v < dupCount; v++) {
+                        ordinals.appendInt(randomFrom(ordsAtPosition));
+                    }
+                    if (valueCount != 1 || dupCount != 0) {
+                        ordinals.endPositionEntry();
+                    }
+                }
+                return new BasicBlockTests.RandomBlock(values, new OrdinalBytesRefBlock(ordinals.build(), bytes.build()));
+            }
+        }
+
+        private Set<String> dictionary(int maxValuesPerPosition) {
+            int count = Math.max(dictionarySize, maxValuesPerPosition);
+            Set<String> values = new HashSet<>();
+            while (values.size() < count) {
+                values.add(randomAlphaOfLength(5));
+            }
+            return values;
+        }
+    }
 }

+ 131 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java

@@ -20,12 +20,15 @@ import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.MockBlockFactory;
+import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
+import org.elasticsearch.compute.data.OrdinalBytesRefVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.data.TestBlockFactory;
 import org.elasticsearch.core.Releasable;
@@ -460,6 +463,133 @@ public class BlockHashTests extends ESTestCase {
         }
     }
 
+    public void testBasicOrdinals() {
+        try (
+            IntVector.Builder ords = blockFactory.newIntVectorFixedBuilder(8);
+            BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(8)
+        ) {
+            ords.appendInt(1);
+            ords.appendInt(0);
+            ords.appendInt(3);
+            ords.appendInt(1);
+            ords.appendInt(3);
+            ords.appendInt(0);
+            ords.appendInt(2);
+            ords.appendInt(3);
+            bytes.appendBytesRef(new BytesRef("item-1"));
+            bytes.appendBytesRef(new BytesRef("item-2"));
+            bytes.appendBytesRef(new BytesRef("item-3"));
+            bytes.appendBytesRef(new BytesRef("item-4"));
+
+            hash(ordsAndKeys -> {
+                if (forcePackedHash) {
+                    assertThat(ordsAndKeys.description, startsWith("PackedValuesBlockHash{groups=[0:BYTES_REF], entries=4, size="));
+                    assertThat(ordsAndKeys.description, endsWith("b}"));
+                    assertOrds(ordsAndKeys.ords, 0, 1, 2, 0, 2, 1, 3, 2);
+                    assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(0, 4)));
+                    assertKeys(ordsAndKeys.keys, "item-2", "item-1", "item-4", "item-3");
+                } else {
+                    assertThat(ordsAndKeys.description, startsWith("BytesRefBlockHash{channel=0, entries=4, size="));
+                    assertThat(ordsAndKeys.description, endsWith("b, seenNull=false}"));
+                    assertOrds(ordsAndKeys.ords, 2, 1, 4, 2, 4, 1, 3, 4);
+                    assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(1, 5)));
+                    assertKeys(ordsAndKeys.keys, "item-1", "item-2", "item-3", "item-4");
+                }
+            }, new OrdinalBytesRefVector(ords.build(), bytes.build()).asBlock());
+        }
+    }
+
+    public void testOrdinalsWithNulls() {
+        try (
+            IntBlock.Builder ords = blockFactory.newIntBlockBuilder(4);
+            BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(2)
+        ) {
+            ords.appendInt(0);
+            ords.appendNull();
+            ords.appendInt(1);
+            ords.appendNull();
+            bytes.appendBytesRef(new BytesRef("cat"));
+            bytes.appendBytesRef(new BytesRef("dog"));
+
+            hash(ordsAndKeys -> {
+                if (forcePackedHash) {
+                    assertThat(ordsAndKeys.description, startsWith("PackedValuesBlockHash{groups=[0:BYTES_REF], entries=3, size="));
+                    assertThat(ordsAndKeys.description, endsWith("b}"));
+                    assertOrds(ordsAndKeys.ords, 0, 1, 2, 1);
+                    assertKeys(ordsAndKeys.keys, "cat", null, "dog");
+                } else {
+                    assertThat(ordsAndKeys.description, startsWith("BytesRefBlockHash{channel=0, entries=2, size="));
+                    assertThat(ordsAndKeys.description, endsWith("b, seenNull=true}"));
+                    assertOrds(ordsAndKeys.ords, 1, 0, 2, 0);
+                    assertKeys(ordsAndKeys.keys, null, "cat", "dog");
+                }
+                assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(0, 3)));
+            }, new OrdinalBytesRefBlock(ords.build(), bytes.build()));
+        }
+    }
+
+    public void testOrdinalsWithMultiValuedFields() {
+        try (
+            IntBlock.Builder ords = blockFactory.newIntBlockBuilder(4);
+            BytesRefVector.Builder bytes = blockFactory.newBytesRefVectorBuilder(2)
+        ) {
+            ords.appendInt(0);
+            ords.beginPositionEntry();
+            ords.appendInt(0);
+            ords.appendInt(1);
+            ords.endPositionEntry();
+            ords.beginPositionEntry();
+            ords.appendInt(1);
+            ords.appendInt(2);
+            ords.endPositionEntry();
+            ords.beginPositionEntry();
+            ords.appendInt(2);
+            ords.appendInt(1);
+            ords.endPositionEntry();
+            ords.appendNull();
+            ords.beginPositionEntry();
+            ords.appendInt(2);
+            ords.appendInt(2);
+            ords.appendInt(1);
+            ords.endPositionEntry();
+
+            bytes.appendBytesRef(new BytesRef("foo"));
+            bytes.appendBytesRef(new BytesRef("bar"));
+            bytes.appendBytesRef(new BytesRef("bort"));
+
+            hash(ordsAndKeys -> {
+                if (forcePackedHash) {
+                    assertThat(ordsAndKeys.description, startsWith("PackedValuesBlockHash{groups=[0:BYTES_REF], entries=4, size="));
+                    assertThat(ordsAndKeys.description, endsWith("b}"));
+                    assertOrds(
+                        ordsAndKeys.ords,
+                        new int[] { 0 },
+                        new int[] { 0, 1 },
+                        new int[] { 1, 2 },
+                        new int[] { 2, 1 },
+                        new int[] { 3 },
+                        new int[] { 2, 1 }
+                    );
+                    assertKeys(ordsAndKeys.keys, "foo", "bar", "bort", null);
+                } else {
+                    assertThat(ordsAndKeys.description, startsWith("BytesRefBlockHash{channel=0, entries=3, size="));
+                    assertThat(ordsAndKeys.description, endsWith("b, seenNull=true}"));
+                    assertOrds(
+                        ordsAndKeys.ords,
+                        new int[] { 1 },
+                        new int[] { 1, 2 },
+                        new int[] { 2, 3 },
+                        new int[] { 3, 2 },
+                        new int[] { 0 },
+                        new int[] { 3, 2 }
+                    );
+                    assertKeys(ordsAndKeys.keys, null, "foo", "bar", "bort");
+                }
+                assertThat(ordsAndKeys.nonEmpty, equalTo(intRange(0, 4)));
+            }, new OrdinalBytesRefBlock(ords.build(), bytes.build()));
+        }
+    }
+
     public void testBooleanHashFalseFirst() {
         boolean[] values = new boolean[] { false, true, true, true, true };
         hash(ordsAndKeys -> {
@@ -1315,6 +1445,7 @@ public class BlockHashTests extends ESTestCase {
         });
         if (blockHash instanceof LongLongBlockHash == false
             && blockHash instanceof BytesRefLongBlockHash == false
+            && blockHash instanceof BytesRef2BlockHash == false
             && blockHash instanceof BytesRef3BlockHash == false) {
             Block[] keys = blockHash.getKeys();
             try (ReleasableIterator<IntBlock> lookup = blockHash.lookup(new Page(keys), ByteSizeValue.ofKb(between(1, 100)))) {