Browse Source

ES|QL: Fix BytesRef2BlockHash (#130705)

Luigi Dell'Aquila 3 months ago
parent
commit
2868e41079

+ 5 - 0
docs/changelog/130705.yaml

@@ -0,0 +1,5 @@
+pr: 130705
+summary: Fix `BytesRef2BlockHash`
+area: ES|QL
+type: bug
+issues: []

+ 3 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BytesRef2BlockHash.java

@@ -145,7 +145,9 @@ final class BytesRef2BlockHash extends BlockHash {
         try {
             try (BytesRefBlock.Builder b1 = blockFactory.newBytesRefBlockBuilder(positions)) {
                 for (int i = 0; i < positions; i++) {
-                    int k1 = (int) (finalHash.get(i) & 0xffffL);
+                    int k1 = (int) (finalHash.get(i) & 0xffffffffL);
+                    // k1 is always positive, it's how hash values are generated, see BytesRefBlockHash.
+                    // For now, we only manage at most 2^31 hash entries
                     if (k1 == 0) {
                         b1.appendNull();
                     } else {

+ 192 - 1
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java

@@ -11,6 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name;
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
 
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.lucene.BytesRefs;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.Block;
@@ -35,8 +36,10 @@ import org.elasticsearch.core.Releasables;
 import org.elasticsearch.xpack.esql.core.util.Holder;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Set;
 import java.util.function.Consumer;
 import java.util.stream.IntStream;
 import java.util.stream.LongStream;
@@ -1232,6 +1235,194 @@ public class BlockHashTests extends BlockHashTestCase {
         }, blockFactory.newLongArrayVector(values, values.length).asBlock(), blockFactory.newConstantNullBlock(values.length));
     }
 
+    public void test2BytesRefsHighCardinalityKey() { // hashes 1_000_000 unique key pairs, expects getKeys() to return them all
+        final Page page;
+        int positions1 = 10;
+        int positions2 = 100_000; // NOTE(review): 100_000 > 0xffff, presumably to exceed the old 16-bit k1 mask - confirm
+        if (randomBoolean()) { // randomly swap which column carries the 100_000 distinct values
+            positions1 = 100_000;
+            positions2 = 10;
+        }
+        final int totalPositions = positions1 * positions2; // 1_000_000 in either arrangement
+        try (
+            BytesRefBlock.Builder builder1 = blockFactory.newBytesRefBlockBuilder(totalPositions);
+            BytesRefBlock.Builder builder2 = blockFactory.newBytesRefBlockBuilder(totalPositions);
+        ) {
+            for (int i = 0; i < positions1; i++) {
+                for (int p = 0; p < positions2; p++) { // each (i, p) pair is appended exactly once, so every row is unique
+                    builder1.appendBytesRef(new BytesRef("abcdef" + i)); // column 0 value depends only on i
+                    builder2.appendBytesRef(new BytesRef("abcdef" + p)); // column 1 value depends only on p
+                }
+            }
+            page = new Page(builder1.build(), builder2.build());
+        }
+        record Output(int offset, IntBlock block, IntVector vector) implements Releasable { // one callback result
+            @Override
+            public void close() {
+                Releasables.close(block, vector);
+            }
+        }
+        List<Output> output = new ArrayList<>(); // collected only so the retained ids can be released in the finally below
+
+        try (BlockHash hash1 = new BytesRef2BlockHash(blockFactory, 0, 1, totalPositions);) {
+            hash1.add(page, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntArrayBlock groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output.add(new Output(positionOffset, groupIds, null));
+                }
+
+                @Override
+                public void add(int positionOffset, IntBigArrayBlock groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output.add(new Output(positionOffset, groupIds, null));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output.add(new Output(positionOffset, null, groupIds));
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+
+            Block[] keys = hash1.getKeys();
+            try {
+                Set<String> distinctKeys = new HashSet<>();
+                BytesRefBlock block0 = (BytesRefBlock) keys[0];
+                BytesRefBlock block1 = (BytesRefBlock) keys[1];
+                BytesRef scratch = new BytesRef(); // shared scratch; each read is turned into a String before the next read
+                StringBuilder builder = new StringBuilder();
+                for (int i = 0; i < totalPositions; i++) { // reads positions 0..totalPositions-1 of both key blocks
+                    builder.setLength(0);
+                    builder.append(BytesRefs.toString(block0.getBytesRef(i, scratch)));
+                    builder.append("#");
+                    builder.append(BytesRefs.toString(block1.getBytesRef(i, scratch)));
+                    distinctKeys.add(builder.toString()); // "key0#key1" identifies the pair at position i
+                }
+                assertThat(distinctKeys.size(), equalTo(totalPositions)); // fails if any two positions yielded the same pair
+            } finally {
+                Releasables.close(keys);
+            }
+        } finally {
+            Releasables.close(output);
+            page.releaseBlocks();
+        }
+    }
+
+    public void test2BytesRefs() { // BytesRef2BlockHash and PackedValuesBlockHash must emit equal group ids for one page
+        final Page page;
+        final int positions = randomIntBetween(1, 1000);
+        final boolean generateVector = randomBoolean(); // true: every position gets exactly one value (no nulls, no multivalues)
+        try (
+            BytesRefBlock.Builder builder1 = blockFactory.newBytesRefBlockBuilder(positions);
+            BytesRefBlock.Builder builder2 = blockFactory.newBytesRefBlockBuilder(positions);
+        ) {
+            List<BytesRefBlock.Builder> builders = List.of(builder1, builder2);
+            for (int p = 0; p < positions; p++) {
+                for (BytesRefBlock.Builder builder : builders) { // both columns are filled independently for position p
+                    int valueCount = generateVector ? 1 : between(0, 3);
+                    switch (valueCount) {
+                        case 0 -> builder.appendNull();
+                        case 1 -> builder.appendBytesRef(new BytesRef(Integer.toString(between(1, 100))));
+                        default -> { // more than one value: written as a single multi-valued position entry
+                            builder.beginPositionEntry();
+                            for (int v = 0; v < valueCount; v++) {
+                                builder.appendBytesRef(new BytesRef(Integer.toString(between(1, 100))));
+                            }
+                            builder.endPositionEntry();
+                        }
+                    }
+                }
+            }
+            page = new Page(builder1.build(), builder2.build());
+        }
+        final int emitBatchSize = between(positions, 10 * 1024); // NOTE(review): presumably never below positions - confirm
+        var groupSpecs = List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF), new BlockHash.GroupSpec(1, ElementType.BYTES_REF));
+        record Output(int offset, IntBlock block, IntVector vector) implements Releasable { // one callback result
+            @Override
+            public void close() {
+                Releasables.close(block, vector);
+            }
+        }
+        List<Output> output1 = new ArrayList<>(); // callbacks received from hash1 (BytesRef2BlockHash)
+        List<Output> output2 = new ArrayList<>(); // callbacks received from hash2 (PackedValuesBlockHash)
+        try (
+            BlockHash hash1 = new BytesRef2BlockHash(blockFactory, 0, 1, emitBatchSize);
+            BlockHash hash2 = new PackedValuesBlockHash(groupSpecs, blockFactory, emitBatchSize)
+        ) {
+            hash1.add(page, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntArrayBlock groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output1.add(new Output(positionOffset, groupIds, null));
+                }
+
+                @Override
+                public void add(int positionOffset, IntBigArrayBlock groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output1.add(new Output(positionOffset, groupIds, null));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output1.add(new Output(positionOffset, null, groupIds));
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+            hash2.add(page, new GroupingAggregatorFunction.AddInput() { // same page, same recording logic, second hash
+                @Override
+                public void add(int positionOffset, IntArrayBlock groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output2.add(new Output(positionOffset, groupIds, null));
+                }
+
+                @Override
+                public void add(int positionOffset, IntBigArrayBlock groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output2.add(new Output(positionOffset, groupIds, null));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    groupIds.incRef(); // reference taken here is dropped by Output.close()
+                    output2.add(new Output(positionOffset, null, groupIds));
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+            assertThat(output1.size(), equalTo(output2.size())); // both hashes must have made the same number of callbacks
+            for (int i = 0; i < output1.size(); i++) {
+                Output o1 = output1.get(i);
+                Output o2 = output2.get(i);
+                assertThat(o1.offset, equalTo(o2.offset));
+                if (o1.vector != null) { // hash1 emitted a vector: hash2 may have emitted a vector or a block
+                    assertNull(o1.block);
+                    assertThat(o1.vector, equalTo(o2.vector != null ? o2.vector : o2.block.asVector()));
+                } else { // hash1 emitted a block: hash2 must have emitted a block as well
+                    assertNull(o2.vector);
+                    assertThat(o1.block, equalTo(o2.block));
+                }
+            }
+        } finally {
+            Releasables.close(output1);
+            Releasables.close(output2);
+            page.releaseBlocks();
+        }
+    }
+
     public void test3BytesRefs() {
         final Page page;
         final int positions = randomIntBetween(1, 1000);
@@ -1326,7 +1517,7 @@ public class BlockHashTests extends BlockHashTestCase {
                     fail("hashes should not close AddInput");
                 }
             });
-            assertThat(output1.size(), equalTo(output1.size()));
+            assertThat(output1.size(), equalTo(output2.size()));
             for (int i = 0; i < output1.size(); i++) {
                 Output o1 = output1.get(i);
                 Output o2 = output2.get(i);