瀏覽代碼

Murmur3 wrapper for sequential hashing of multiple byte arrays (#69185)

Dan Hermann 4 年之前
父節點
當前提交
47b8ea53cc

+ 107 - 0
server/src/main/java/org/elasticsearch/common/hash/Murmur3Hasher.java

@@ -0,0 +1,107 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.common.hash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.Numbers;
+
+/**
+ * Wraps {@link MurmurHash3} to provide an interface similar to {@link java.security.MessageDigest} that
+ * allows hashing of byte arrays passed through multiple calls to {@link #update(byte[])}. Like
+ * {@link java.security.MessageDigest}, this class maintains internal state during the calculation of the
+ * hash and is <b>not</b> threadsafe. If concurrent hashes are to be computed, each must be done on a
+ * separate instance.
+ */
+public class Murmur3Hasher {
+
+    private final long seed;
+    private final byte[] remainder = new byte[16];
+    private int remainderLength = 0;
+    private int length;
+    private long h1, h2;
+
+    public Murmur3Hasher(long seed) {
+        this.seed = seed;
+        h1 = h2 = seed;
+    }
+
+    /**
+     * Supplies some or all of the bytes to be hashed. Multiple calls to this method may
+     * be made to sequentially supply the bytes for hashing. Once all bytes have been supplied, the
+     * {@link #digest()} method should be called to complete the hash calculation.
+     */
+    public void update(byte[] inputBytes) {
+        int totalLength = remainderLength + inputBytes.length;
+        if (totalLength >= 16) {
+            // hash as many bytes as available in integer multiples of 16
+            int numBytesToHash = totalLength & 0xFFFFFFF0;
+            byte[] bytesToHash;
+            if (remainderLength > 0) {
+                bytesToHash = new byte[numBytesToHash];
+                System.arraycopy(remainder, 0, bytesToHash, 0, remainderLength);
+                System.arraycopy(inputBytes, 0, bytesToHash, remainderLength, numBytesToHash - remainderLength);
+            } else {
+                bytesToHash = inputBytes;
+            }
+
+            MurmurHash3.IntermediateResult result = MurmurHash3.intermediateHash(bytesToHash, 0, numBytesToHash, h1, h2);
+            h1 = result.h1;
+            h2 = result.h2;
+            this.length += numBytesToHash;
+
+            // save the remaining bytes, if any
+            if (totalLength > numBytesToHash) {
+                System.arraycopy(inputBytes, numBytesToHash - remainderLength, remainder, 0, totalLength - numBytesToHash);
+                remainderLength = totalLength - numBytesToHash;
+            } else {
+                remainderLength = 0;
+            }
+        } else {
+            System.arraycopy(inputBytes, 0, remainder, remainderLength, inputBytes.length);
+            remainderLength += inputBytes.length;
+        }
+    }
+
+    /**
+     * Clears all bytes previously passed to {@link #update(byte[])} and prepares for the calculation
+     * of a new hash.
+     */
+    public void reset() {
+        length = 0;
+        remainderLength = 0;
+        h1 = h2 = seed;
+    }
+
+    /**
+     * Completes the hash of all bytes previously passed to {@link #update(byte[])}.
+     */
+    public byte[] digest() {
+        length += remainderLength;
+        MurmurHash3.Hash128 h = MurmurHash3.finalizeHash(new MurmurHash3.Hash128(), remainder, 0, length, h1, h2);
+        byte[] hash = new byte[16];
+        System.arraycopy(Numbers.longToBytes(h.h1), 0, hash, 0, 8);
+        System.arraycopy(Numbers.longToBytes(h.h2), 0, hash, 8, 8);
+        return hash;
+    }
+
+    public String getAlgorithm() {
+        return "MurmurHash3_x64_128";
+    }
+
+    /**
+     * Converts the 128-bit byte array returned by {@link #digest()} to a
+     * {@link org.elasticsearch.common.hash.MurmurHash3.Hash128}
+     */
+    public static MurmurHash3.Hash128 toHash128(byte[] doubleLongBytes) {
+        MurmurHash3.Hash128 hash128 = new MurmurHash3.Hash128();
+        hash128.h1 = Numbers.bytesToLong(new BytesRef(doubleLongBytes, 0, 8));
+        hash128.h2 = Numbers.bytesToLong(new BytesRef(doubleLongBytes, 8, 8));
+        return hash128;
+    }
+}

+ 76 - 40
server/src/main/java/org/elasticsearch/common/hash/MurmurHash3.java

@@ -8,8 +8,10 @@
 
 package org.elasticsearch.common.hash;
 
+import org.elasticsearch.common.Numbers;
 import org.elasticsearch.common.util.ByteUtils;
 
+import java.math.BigInteger;
 import java.util.Objects;
 
 
@@ -45,6 +47,27 @@ public enum MurmurHash3 {
         public int hashCode() {
             return Objects.hash(h1, h2);
         }
+
+        @Override
+        public String toString() {
+            byte[] longBytes = new byte[17];
+            System.arraycopy(Numbers.longToBytes(h1), 0, longBytes, 1, 8);
+            System.arraycopy(Numbers.longToBytes(h2), 0, longBytes, 9, 8);
+            BigInteger bi = new BigInteger(longBytes);
+            return "0x" + bi.toString(16);
+        }
+    }
+
+    static class IntermediateResult {
+        int offset;
+        long h1;
+        long h2;
+
+        IntermediateResult(int offset, long h1, long h2) {
+            this.offset = offset;
+            this.h1 = h1;
+            this.h2 = h2;
+        }
     }
 
     private static long C1 = 0x87c37b91114253d5L;
@@ -77,75 +100,88 @@ public enum MurmurHash3 {
         long h2 = seed;
 
         if (length >= 16) {
+            IntermediateResult result = intermediateHash(key, offset, length, h1, h2);
+            h1 = result.h1;
+            h2 = result.h2;
+            offset = result.offset;
+        }
 
-            final int len16 = length & 0xFFFFFFF0; // higher multiple of 16 that is lower than or equal to length
-            final int end = offset + len16;
-            for (int i = offset; i < end; i += 16) {
-                long k1 = ByteUtils.readLongLE(key, i);
-                long k2 = ByteUtils.readLongLE(key, i + 8);
-
-                k1 *= C1;
-                k1 = Long.rotateLeft(k1, 31);
-                k1 *= C2;
-                h1 ^= k1;
-
-                h1 = Long.rotateLeft(h1, 27);
-                h1 += h2;
-                h1 = h1 * 5 + 0x52dce729;
+        return finalizeHash(hash, key, offset, length, h1, h2);
+    }
 
-                k2 *= C2;
-                k2 = Long.rotateLeft(k2, 33);
-                k2 *= C1;
-                h2 ^= k2;
+    static IntermediateResult intermediateHash(byte[] key, int offset, int length, long h1, long h2) {
+        final int len16 = length & 0xFFFFFFF0; // higher multiple of 16 that is lower than or equal to length
+        final int end = offset + len16;
+        for (int i = offset; i < end; i += 16) {
+            long k1 = ByteUtils.readLongLE(key, i);
+            long k2 = ByteUtils.readLongLE(key, i + 8);
+
+            k1 *= C1;
+            k1 = Long.rotateLeft(k1, 31);
+            k1 *= C2;
+            h1 ^= k1;
+
+            h1 = Long.rotateLeft(h1, 27);
+            h1 += h2;
+            h1 = h1 * 5 + 0x52dce729;
+
+            k2 *= C2;
+            k2 = Long.rotateLeft(k2, 33);
+            k2 *= C1;
+            h2 ^= k2;
+
+            h2 = Long.rotateLeft(h2, 31);
+            h2 += h1;
+            h2 = h2 * 5 + 0x38495ab5;
+        }
 
-                h2 = Long.rotateLeft(h2, 31);
-                h2 += h1;
-                h2 = h2 * 5 + 0x38495ab5;
-            }
+        // Advance offset to the unprocessed tail of the data.
+        offset = end;
 
-            // Advance offset to the unprocessed tail of the data.
-            offset = end;
-        }
+        return new IntermediateResult(offset, h1, h2);
+    }
 
+    @SuppressWarnings("fallthrough") // Intentionally uses fallthrough to implement a well known hashing algorithm
+    static Hash128 finalizeHash(Hash128 hash, byte[] remainder, int offset, int length, long h1, long h2) {
         long k1 = 0;
         long k2 = 0;
 
         switch (length & 15) {
             case 15:
-                k2 ^= (key[offset + 14] & 0xFFL) << 48;
+                k2 ^= (remainder[offset + 14] & 0xFFL) << 48;
             case 14:
-                k2 ^= (key[offset + 13] & 0xFFL) << 40;
+                k2 ^= (remainder[offset + 13] & 0xFFL) << 40;
             case 13:
-                k2 ^= (key[offset + 12] & 0xFFL) << 32;
+                k2 ^= (remainder[offset + 12] & 0xFFL) << 32;
             case 12:
-                k2 ^= (key[offset + 11] & 0xFFL) << 24;
+                k2 ^= (remainder[offset + 11] & 0xFFL) << 24;
             case 11:
-                k2 ^= (key[offset + 10] & 0xFFL) << 16;
+                k2 ^= (remainder[offset + 10] & 0xFFL) << 16;
             case 10:
-                k2 ^= (key[offset + 9] & 0xFFL) << 8;
+                k2 ^= (remainder[offset + 9] & 0xFFL) << 8;
             case 9:
-                k2 ^= (key[offset + 8] & 0xFFL) << 0;
+                k2 ^= (remainder[offset + 8] & 0xFFL) << 0;
                 k2 *= C2;
                 k2 = Long.rotateLeft(k2, 33);
                 k2 *= C1;
                 h2 ^= k2;
 
             case 8:
-                k1 ^= (key[offset + 7] & 0xFFL) << 56;
+                k1 ^= (remainder[offset + 7] & 0xFFL) << 56;
             case 7:
-                k1 ^= (key[offset + 6] & 0xFFL) << 48;
+                k1 ^= (remainder[offset + 6] & 0xFFL) << 48;
             case 6:
-                k1 ^= (key[offset + 5] & 0xFFL) << 40;
+                k1 ^= (remainder[offset + 5] & 0xFFL) << 40;
             case 5:
-                k1 ^= (key[offset + 4] & 0xFFL) << 32;
+                k1 ^= (remainder[offset + 4] & 0xFFL) << 32;
             case 4:
-                k1 ^= (key[offset + 3] & 0xFFL) << 24;
+                k1 ^= (remainder[offset + 3] & 0xFFL) << 24;
             case 3:
-                k1 ^= (key[offset + 2] & 0xFFL) << 16;
+                k1 ^= (remainder[offset + 2] & 0xFFL) << 16;
             case 2:
-                k1 ^= (key[offset + 1] & 0xFFL) << 8;
+                k1 ^= (remainder[offset + 1] & 0xFFL) << 8;
             case 1:
-                k1 ^= (key[offset] & 0xFFL);
+                k1 ^= (remainder[offset] & 0xFFL);
                 k1 *= C1;
                 k1 = Long.rotateLeft(k1, 31);
                 k1 *= C2;

+ 91 - 0
server/src/test/java/org/elasticsearch/common/hashing/Murmur3HasherTests.java

@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.common.hashing;
+
+import org.elasticsearch.common.hash.Murmur3Hasher;
+import org.elasticsearch.common.hash.MurmurHash3;
+import org.elasticsearch.test.ESTestCase;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class Murmur3HasherTests extends ESTestCase {
+
+    public void testKnownValues() {
+        assertHash(0x629942693e10f867L, 0x92db0b82baeb5347L, "hell", 0);
+        assertHash(0xa78ddff5adae8d10L, 0x128900ef20900135L, "hello", 1);
+        assertHash(0x8a486b23f422e826L, 0xf962a2c58947765fL, "hello ", 2);
+        assertHash(0x2ea59f466f6bed8cL, 0xc610990acc428a17L, "hello w", 3);
+        assertHash(0x79f6305a386c572cL, 0x46305aed3483b94eL, "hello wo", 4);
+        assertHash(0xc2219d213ec1f1b5L, 0xa1d8e2e0a52785bdL, "hello wor", 5);
+        assertHash(0xe34bbc7bbc071b6cL, 0x7a433ca9c49a9347L, "The quick brown fox jumps over the lazy dog", 0);
+        assertHash(0x658ca970ff85269aL, 0x43fee3eaa68e5c3eL, "The quick brown fox jumps over the lazy cog", 0);
+    }
+
+    private static void assertHash(long lower, long upper, String inputString, long seed) {
+        MurmurHash3.Hash128 expected = new MurmurHash3.Hash128();
+        expected.h1 = lower;
+        expected.h2 = upper;
+
+        byte[] bytes = inputString.getBytes(StandardCharsets.UTF_8);
+        Murmur3Hasher mh = new Murmur3Hasher(seed);
+        mh.update(bytes);
+        MurmurHash3.Hash128 actual = Murmur3Hasher.toHash128(mh.digest());
+        assertHash(expected, actual);
+    }
+
+    private static void assertHash(MurmurHash3.Hash128 expected, MurmurHash3.Hash128 actual) {
+        assertEquals(expected.h1, actual.h1);
+        assertEquals(expected.h2, actual.h2);
+    }
+
+    public void testSingleVsSequentialMurmur3() {
+        final String inputString = randomAlphaOfLengthBetween(2000, 3000);
+        final int numSplits = randomIntBetween(2, 100); // should produce a good number of byte arrays both longer and shorter than 16
+        final int[] splits = new int[numSplits];
+        int totalLength = 0;
+        for (int k = 0; k < numSplits - 1; k++) {
+            splits[k] = randomIntBetween(1, Math.max(2, inputString.length() / numSplits * 2));
+            totalLength += splits[k];
+        }
+        splits[numSplits - 1] = inputString.length() - totalLength;
+        totalLength = 0;
+        byte[][] splitBytes = new byte[numSplits][];
+        for (int k = 0; k < numSplits - 1; k++) {
+            int end = Math.min(totalLength + splits[k], inputString.length());
+            if (totalLength < end) {
+                splitBytes[k] = inputString.substring(totalLength, end).getBytes(StandardCharsets.UTF_8);
+            } else {
+                splitBytes[k] = new byte[0];
+            }
+            totalLength += splits[k];
+        }
+        if (totalLength < inputString.length()) {
+            splitBytes[numSplits - 1] = inputString.substring(totalLength).getBytes(StandardCharsets.UTF_8);
+        } else {
+            splitBytes[numSplits - 1] = new byte[0];
+        }
+
+        final long seed = randomLong();
+        final byte[] allBytes = inputString.getBytes(StandardCharsets.UTF_8);
+        final MurmurHash3.Hash128 singleHash = MurmurHash3.hash128(allBytes, 0, allBytes.length, seed, new MurmurHash3.Hash128());
+
+        Murmur3Hasher mh = new Murmur3Hasher(seed);
+        totalLength = 0;
+        for (int k = 0; k < numSplits; k++) {
+            totalLength += splitBytes[k].length;
+            if (totalLength <= inputString.length()) {
+                mh.update(splitBytes[k]);
+            }
+        }
+        MurmurHash3.Hash128 sequentialHash = Murmur3Hasher.toHash128(mh.digest());
+        assertThat(singleHash, equalTo(sequentialHash));
+    }
+}