Bläddra i källkod

Add BufferedMurmur3Hasher to reduce allocations when hashing Strings (#133226)

Felix Barnsteiner 1 månad sedan
förälder
incheckning
11ac7def18

+ 7 - 36
server/src/main/java/org/elasticsearch/cluster/routing/TsidBuilder.java

@@ -10,7 +10,7 @@
 package org.elasticsearch.cluster.routing;
 
 import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.hash.Murmur3Hasher;
+import org.elasticsearch.common.hash.BufferedMurmur3Hasher;
 import org.elasticsearch.common.hash.MurmurHash3;
 import org.elasticsearch.common.util.ByteUtils;
 import org.elasticsearch.index.mapper.RoutingPathFields;
@@ -32,7 +32,7 @@ import java.util.List;
 public class TsidBuilder {
 
     private static final int MAX_TSID_VALUE_FIELDS = 16;
-    private final Murmur3Hasher murmur3Hasher = new Murmur3Hasher(0L);
+    private final BufferedMurmur3Hasher murmur3Hasher = new BufferedMurmur3Hasher(0L);
 
     private final List<Dimension> dimensions = new ArrayList<>();
 
@@ -166,7 +166,7 @@ public class TsidBuilder {
 
     private void addDimension(String path, MurmurHash3.Hash128 valueHash) {
         murmur3Hasher.reset();
-        addString(murmur3Hasher, path);
+        murmur3Hasher.addString(path);
         MurmurHash3.Hash128 pathHash = murmur3Hasher.digestHash();
         dimensions.add(new Dimension(path, pathHash, valueHash, dimensions.size()));
     }
@@ -198,7 +198,7 @@ public class TsidBuilder {
         Collections.sort(dimensions);
         murmur3Hasher.reset();
         for (Dimension dim : dimensions) {
-            addLongs(murmur3Hasher, dim.pathHash.h1, dim.pathHash.h2, dim.valueHash.h1, dim.valueHash.h2);
+            murmur3Hasher.addLongs(dim.pathHash.h1, dim.pathHash.h2, dim.valueHash.h1, dim.valueHash.h2);
         }
         return murmur3Hasher.digestHash();
     }
@@ -237,7 +237,7 @@ public class TsidBuilder {
         murmur3Hasher.reset();
         for (int i = 0; i < dimensions.size(); i++) {
             Dimension dim = dimensions.get(i);
-            addLong(murmur3Hasher, dim.pathHash.h1 ^ dim.pathHash.h2);
+            murmur3Hasher.addLong(dim.pathHash.h1 ^ dim.pathHash.h2);
         }
         ByteUtils.writeIntLE((int) murmur3Hasher.digestHash(hashBuffer).h1, hash, index);
         index += 4;
@@ -253,7 +253,7 @@ public class TsidBuilder {
             }
             MurmurHash3.Hash128 valueHash = dim.valueHash();
             murmur3Hasher.reset();
-            addLong(murmur3Hasher, valueHash.h1 ^ valueHash.h2);
+            murmur3Hasher.addLong(valueHash.h1 ^ valueHash.h2);
             hash[index++] = (byte) murmur3Hasher.digestHash(hashBuffer).h1;
             previousPath = path;
         }
@@ -261,7 +261,7 @@ public class TsidBuilder {
         murmur3Hasher.reset();
         for (int i = 0; i < dimensions.size(); i++) {
             Dimension dim = dimensions.get(i);
-            addLongs(murmur3Hasher, dim.pathHash.h1, dim.pathHash.h2, dim.valueHash.h1, dim.valueHash.h2);
+            murmur3Hasher.addLongs(dim.pathHash.h1, dim.pathHash.h2, dim.valueHash.h1, dim.valueHash.h2);
         }
         index = writeHash128(murmur3Hasher.digestHash(hashBuffer), hash, index);
         return new BytesRef(hash, 0, index);
@@ -314,33 +314,4 @@ public class TsidBuilder {
             return Integer.compare(insertionOrder, o.insertionOrder);
         }
     }
-
-    // these methods will be replaced with a more optimized version when https://github.com/elastic/elasticsearch/pull/133226 is merged
-
-    private static void addString(Murmur3Hasher murmur3Hasher, String path) {
-        BytesRef bytesRef = new BytesRef(path);
-        murmur3Hasher.update(bytesRef.bytes, bytesRef.offset, bytesRef.length);
-    }
-
-    private static void addLong(Murmur3Hasher murmur3Hasher, long value) {
-        byte[] bytes = new byte[8];
-        ByteUtils.writeLongLE(value, bytes, 0);
-        murmur3Hasher.update(bytes);
-    }
-
-    private static void addLongs(Murmur3Hasher murmur3Hasher, long v1, long v2) {
-        byte[] bytes = new byte[16];
-        ByteUtils.writeLongLE(v1, bytes, 0);
-        ByteUtils.writeLongLE(v2, bytes, 8);
-        murmur3Hasher.update(bytes);
-    }
-
-    private static void addLongs(Murmur3Hasher murmur3Hasher, long v1, long v2, long v3, long v4) {
-        byte[] bytes = new byte[32];
-        ByteUtils.writeLongLE(v1, bytes, 0);
-        ByteUtils.writeLongLE(v2, bytes, 8);
-        ByteUtils.writeLongLE(v3, bytes, 16);
-        ByteUtils.writeLongLE(v4, bytes, 24);
-        murmur3Hasher.update(bytes);
-    }
 }

+ 140 - 0
server/src/main/java/org/elasticsearch/common/hash/BufferedMurmur3Hasher.java

@@ -0,0 +1,140 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.common.hash;
+
+import org.apache.lucene.util.UnicodeUtil;
+import org.elasticsearch.common.util.ByteUtils;
+
+/**
+ * A buffered Murmur3 hasher that allows hashing strings and longs efficiently.
+ * It uses a byte array buffer to reduce allocations for converting strings and longs to bytes before passing them to the hasher.
+ * The buffer also batches data so that the underlying hasher is updated less often and with more bytes per update call.
+ * NOTE(review): update(...) is not overridden here, so bytes passed to it directly would presumably be hashed ahead of buffered bytes.
+ */
+public class BufferedMurmur3Hasher extends Murmur3Hasher {
+
+    public static final int DEFAULT_BUFFER_SIZE = 32 * 4; // 32 characters, each character may take up to 4 bytes in UTF-8
+    /**
+     * The buffer used for holding the UTF-8 encoded strings before passing them to the hasher.
+     * Should be sized so that it can hold the longest UTF-8 encoded string that is expected to be hashed,
+     * to avoid re-sizing the buffer.
+     * But should also be small enough to not waste memory in case the keys are short.
+     */
+    private byte[] buffer;
+    private int pos; // number of pending bytes in the buffer; also the offset of the next write
+
+    public BufferedMurmur3Hasher(long seed) {
+        this(seed, DEFAULT_BUFFER_SIZE);
+    }
+
+    /**
+     * Constructs a BufferedMurmur3Hasher with a specified seed and buffer size.
+     * @param seed        the seed for the Murmur3 hash function
+     * @param bufferSize  the size of the buffer in bytes, must be at least 32 so that the four-long addLongs always fits
+     * @throws IllegalArgumentException if bufferSize is less than 32
+     */
+    public BufferedMurmur3Hasher(long seed, int bufferSize) {
+        super(seed);
+        if (bufferSize < 32) {
+            throw new IllegalArgumentException("Buffer size must be at least 32 bytes");
+        }
+        this.buffer = new byte[bufferSize];
+    }
+
+    @Override
+    public MurmurHash3.Hash128 digestHash(MurmurHash3.Hash128 hash) {
+        flush(); // hand any still-buffered bytes to the hasher before the digest is computed
+        return super.digestHash(hash);
+    }
+
+    @Override
+    public void reset() {
+        super.reset();
+        pos = 0; // discard buffered bytes that were never flushed; the array itself is kept for reuse
+    }
+
+    /**
+     * Adds a string to the hasher.
+     * The string is converted to UTF-8 and written into the buffer.
+     * If the buffer cannot hold the maximum UTF-8 length of the string, pending bytes are flushed and a larger buffer is allocated.
+     *
+     * @param value the string value to add, must not be null
+     */
+    public void addString(String value) {
+        int requiredBufferLength = UnicodeUtil.maxUTF8Length(value.length()); // reserved up front, before encoding
+        ensureCapacity(requiredBufferLength);
+        flushIfRemainingCapacityLowerThan(requiredBufferLength);
+        pos = UnicodeUtil.UTF16toUTF8(value, 0, value.length(), buffer, pos); // the returned offset becomes the new write position
+    }
+
+    /**
+     * Adds a long value to the hasher.
+     * The long is written in little-endian format.
+     *
+     * @param value the long value to add
+     */
+    public void addLong(long value) {
+        flushIfRemainingCapacityLowerThan(Long.BYTES);
+        ByteUtils.writeLongLE(value, buffer, pos);
+        pos += Long.BYTES;
+    }
+
+    /**
+     * Adds two long values to the hasher.
+     * Each long is written in little-endian format.
+     *
+     * @param v1 the first long value to add
+     * @param v2 the second long value to add
+     */
+    public void addLongs(long v1, long v2) {
+        flushIfRemainingCapacityLowerThan(Long.BYTES * 2);
+        ByteUtils.writeLongLE(v1, buffer, pos);
+        ByteUtils.writeLongLE(v2, buffer, pos + 8);
+        pos += Long.BYTES * 2;
+    }
+
+    /**
+     * Adds four long values to the hasher.
+     * Each long is written in little-endian format.
+     *
+     * @param v1 the first long value to add
+     * @param v2 the second long value to add
+     * @param v3 the third long value to add
+     * @param v4 the fourth long value to add
+     */
+    public void addLongs(long v1, long v2, long v3, long v4) {
+        flushIfRemainingCapacityLowerThan(Long.BYTES * 4); // 32 bytes: the reason for the constructor's minimum buffer size
+        ByteUtils.writeLongLE(v1, buffer, pos);
+        ByteUtils.writeLongLE(v2, buffer, pos + 8);
+        ByteUtils.writeLongLE(v3, buffer, pos + 16);
+        ByteUtils.writeLongLE(v4, buffer, pos + 24);
+        pos += Long.BYTES * 4;
+    }
+
+    private void ensureCapacity(int requiredBufferLength) {
+        if (buffer.length < requiredBufferLength) {
+            flush(); // pending bytes must be hashed first: the old array is dropped, not copied
+            buffer = new byte[requiredBufferLength];
+        }
+    }
+
+    private void flush() {
+        if (pos > 0) { // nothing to pass on when the buffer is empty
+            update(buffer, 0, pos);
+            pos = 0;
+        }
+    }
+
+    private void flushIfRemainingCapacityLowerThan(int requiredCapacity) {
+        if (buffer.length - pos < requiredCapacity) { // not enough free bytes after pos for the upcoming write
+            flush();
+        }
+    }
+}

+ 2 - 3
server/src/main/java/org/elasticsearch/common/hash/MurmurHash3.java

@@ -12,7 +12,6 @@ package org.elasticsearch.common.hash;
 import org.elasticsearch.common.util.ByteUtils;
 
 import java.math.BigInteger;
-import java.util.Objects;
 
 /**
  * MurmurHash3 hashing functions.
@@ -56,12 +55,12 @@ public enum MurmurHash3 {
                 return false;
             }
             Hash128 that = (Hash128) other;
-            return Objects.equals(this.h1, that.h1) && Objects.equals(this.h2, that.h2);
+            return this.h1 == that.h1 && this.h2 == that.h2;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(h1, h2);
+            return (int) (h1 ^ h2);
         }
 
         @Override

+ 152 - 0
server/src/test/java/org/elasticsearch/common/hash/BufferedMurmur3HasherTests.java

@@ -0,0 +1,152 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.common.hash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.util.ByteUtils;
+import org.elasticsearch.test.ESTestCase;
+
+public class BufferedMurmur3HasherTests extends ESTestCase {
+
+    private final BufferedMurmur3Hasher bufferedHasher = new BufferedMurmur3Hasher(0, randomIntBetween(32, 128));
+    private final Murmur3Hasher hasher = new Murmur3Hasher(0); // unbuffered reference, same seed as bufferedHasher
+
+    public void testAddString() { // the string may be longer than the buffer; must hash like the bytes of new BytesRef(string)
+        String testString = randomUnicodeOfLengthBetween(0, 1024);
+        bufferedHasher.addString(testString);
+
+        BytesRef bytesRef = new BytesRef(testString);
+        hasher.update(bytesRef.bytes, bytesRef.offset, bytesRef.length);
+        assertEquals(hasher.digestHash(), bufferedHasher.digestHash());
+    }
+
+    public void testConstructorWithInvalidBufferSize() { // 31 is one below the minimum of 32 enforced by the constructor
+        IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new BufferedMurmur3Hasher(0, 31));
+        assertEquals("Buffer size must be at least 32 bytes", exception.getMessage());
+    }
+
+    public void testAddLong() { // a single long must hash like its little-endian bytes
+        long value = randomLong();
+        bufferedHasher.addLong(value);
+
+        hasher.update(toBytes(value), 0, Long.BYTES);
+
+        assertEquals(hasher.digestHash(), bufferedHasher.digestHash());
+    }
+
+    public void testAddLongs() { // one-, two- and four-long variants in sequence must equal feeding every long separately
+        long value1 = randomLong();
+        long value2 = randomLong();
+        long value3 = randomLong();
+        long value4 = randomLong();
+        bufferedHasher.addLong(value1);
+        bufferedHasher.addLongs(value1, value2);
+        bufferedHasher.addLongs(value1, value2, value3, value4);
+
+        hasher.update(toBytes(value1));
+
+        hasher.update(toBytes(value1));
+        hasher.update(toBytes(value2));
+
+        hasher.update(toBytes(value1));
+        hasher.update(toBytes(value2));
+        hasher.update(toBytes(value3));
+        hasher.update(toBytes(value4));
+
+        assertEquals(hasher.digestHash(), bufferedHasher.digestHash());
+    }
+
+    public void testAddTwoLongs() { // the two-long variant on its own
+        long value1 = randomLong();
+        long value2 = randomLong();
+
+        bufferedHasher.addLongs(value1, value2);
+
+        hasher.update(toBytes(value1));
+        hasher.update(toBytes(value2));
+
+        assertEquals(hasher.digestHash(), bufferedHasher.digestHash());
+    }
+
+    public void testAddFourLongs() { // the four-long variant on its own
+        long value1 = randomLong();
+        long value2 = randomLong();
+        long value3 = randomLong();
+        long value4 = randomLong();
+
+        bufferedHasher.addLongs(value1, value2, value3, value4);
+
+        hasher.update(toBytes(value1));
+        hasher.update(toBytes(value2));
+        hasher.update(toBytes(value3));
+        hasher.update(toBytes(value4));
+
+        assertEquals(hasher.digestHash(), bufferedHasher.digestHash());
+    }
+
+    public void testRandomAdds() { // random interleaving of all add* variants, mirrored on the unbuffered hasher
+        int numAdds = randomIntBetween(128, 1024);
+        for (int i = 0; i < numAdds; i++) {
+            switch (randomIntBetween(0, 4)) {
+                case 0 -> {
+                    String randomString = randomUnicodeOfLengthBetween(0, 64);
+                    bufferedHasher.addString(randomString);
+                    BytesRef bytesRef = new BytesRef(randomString);
+                    hasher.update(bytesRef.bytes, bytesRef.offset, bytesRef.length);
+                }
+                case 1 -> { // the empty string gets its own case so it is exercised regularly
+                    String emptyString = "";
+                    bufferedHasher.addString(emptyString);
+                    BytesRef bytesRef = new BytesRef(emptyString);
+                    hasher.update(bytesRef.bytes, bytesRef.offset, bytesRef.length);
+                }
+                case 2 -> {
+                    long randomLong = randomLong();
+                    bufferedHasher.addLong(randomLong);
+                    hasher.update(toBytes(randomLong));
+                }
+                case 3 -> {
+                    long randomLong1 = randomLong();
+                    long randomLong2 = randomLong();
+                    bufferedHasher.addLongs(randomLong1, randomLong2);
+                    hasher.update(toBytes(randomLong1));
+                    hasher.update(toBytes(randomLong2));
+                }
+                case 4 -> {
+                    long randomLong1 = randomLong();
+                    long randomLong2 = randomLong();
+                    long randomLong3 = randomLong();
+                    long randomLong4 = randomLong();
+                    bufferedHasher.addLongs(randomLong1, randomLong2, randomLong3, randomLong4);
+                    hasher.update(toBytes(randomLong1));
+                    hasher.update(toBytes(randomLong2));
+                    hasher.update(toBytes(randomLong3));
+                    hasher.update(toBytes(randomLong4));
+                }
+            }
+        }
+        assertEquals(hasher.digestHash(), bufferedHasher.digestHash());
+    }
+
+    public void testReset() { // whatever was added before reset(), the digest afterwards is expected to equal Hash128(0, 0)
+        bufferedHasher.addString(randomUnicodeOfLengthBetween(0, 1024));
+        bufferedHasher.addLong(randomLong());
+        bufferedHasher.addLongs(randomLong(), randomLong());
+        bufferedHasher.addLongs(randomLong(), randomLong(), randomLong(), randomLong());
+        bufferedHasher.reset();
+        assertEquals(new MurmurHash3.Hash128(0, 0), bufferedHasher.digestHash());
+    }
+
+    private byte[] toBytes(long value) { // little-endian via ByteUtils.writeLongLE, the same call BufferedMurmur3Hasher.addLong uses
+        byte[] bytes = new byte[Long.BYTES];
+        ByteUtils.writeLongLE(value, bytes, 0);
+        return bytes;
+    }
+}