瀏覽代碼

Optimize Reading vInt and vLong from BytesReference (#71522)

Same optimization as in #71181 (also used by buffering Lucene DataInput implementations) but for the variable length encodings.
Benchmarks show a ~50% speedup for the benchmarked mix of values for `vLong`. Generally this change helps the most with large values but shows a slight speedup even for the 1 byte length case by avoiding some indirection and bounds checking.
Armin Braun 4 年之前
父節點
當前提交
2540c7489b

+ 1 - 1
benchmarks/src/main/java/org/elasticsearch/common/bytes/PagedBytesReferenceBenchmark.java → benchmarks/src/main/java/org/elasticsearch/common/bytes/PagedBytesReferenceReadLongBenchmark.java

@@ -32,7 +32,7 @@ import java.util.concurrent.TimeUnit;
 @OutputTimeUnit(TimeUnit.MILLISECONDS)
 @State(Scope.Thread)
 @Fork(value = 1)
-public class PagedBytesReferenceBenchmark {
+public class PagedBytesReferenceReadLongBenchmark {
 
     @Param(value = { "1" })
     private int dataMb;

+ 65 - 0
benchmarks/src/main/java/org/elasticsearch/common/bytes/PagedBytesReferenceReadVIntBenchmark.java

@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+package org.elasticsearch.common.bytes;
+
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * JMH benchmark that measures reading a sequence of variable-length ints back from a {@link PagedBytesReference}.
+ * <p>
+ * The input consists of {@code entries} values: the first half counts up from {@code 0}, the second half counts down
+ * from {@link Integer#MAX_VALUE}.
+ */
+@Warmup(iterations = 5)
+@Measurement(iterations = 7)
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Thread)
+@Fork(value = 1)
+public class PagedBytesReferenceReadVIntBenchmark {
+
+    // Serialized vInts produced by initResults() and consumed by readVInt().
+    private BytesReference pagedBytes;
+
+    // Total number of vInts written during setup and read per benchmark invocation.
+    @Param(value = { "10000000" })
+    int entries;
+
+    /**
+     * Serializes exactly {@code entries} vInts and keeps the resulting bytes for the benchmark method.
+     *
+     * @throws IOException    if writing to the temporary output fails
+     * @throws AssertionError if the written bytes are not backed by a {@link PagedBytesReference}
+     */
+    @Setup
+    public void initResults() throws IOException {
+        final BytesStreamOutput tmp = new BytesStreamOutput();
+        final int smallValues = entries / 2;
+        // Put the remainder into the second half so that exactly `entries` values are written even when `entries` is
+        // odd; otherwise readVInt() would attempt to read one more value than the stream contains.
+        final int largeValues = entries - smallValues;
+        for (int i = 0; i < smallValues; i++) {
+            tmp.writeVInt(i);
+        }
+        for (int i = 0; i < largeValues; i++) {
+            tmp.writeVInt(Integer.MAX_VALUE - i);
+        }
+        pagedBytes = tmp.bytes();
+        if (pagedBytes instanceof PagedBytesReference == false) {
+            throw new AssertionError("expected paged PagedBytesReference but saw [" + pagedBytes.getClass() + "]");
+        }
+    }
+
+    /**
+     * Reads {@code entries} vInts from a fresh stream over the prepared bytes.
+     *
+     * @return the XOR of all values read; returned so the reads contribute to the benchmark result
+     * @throws IOException if reading from the stream fails
+     */
+    @Benchmark
+    public int readVInt() throws IOException {
+        int res = 0;
+        try (StreamInput streamInput = pagedBytes.streamInput()) {
+            for (int i = 0; i < entries; i++) {
+                res = res ^ streamInput.readVInt();
+            }
+        }
+        return res;
+    }
+}

+ 65 - 0
benchmarks/src/main/java/org/elasticsearch/common/bytes/PagedBytesReferenceReadVLongBenchmark.java

@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+package org.elasticsearch.common.bytes;
+
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * JMH benchmark that measures reading a sequence of variable-length longs back from a {@link PagedBytesReference}.
+ * <p>
+ * The input consists of {@code entries} values: the first half counts up from {@code 0}, the second half counts down
+ * from {@link Long#MAX_VALUE}.
+ */
+@Warmup(iterations = 5)
+@Measurement(iterations = 7)
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Thread)
+@Fork(value = 1)
+public class PagedBytesReferenceReadVLongBenchmark {
+
+    // Serialized vLongs produced by initResults() and consumed by readVLong().
+    private BytesReference pagedBytes;
+
+    // Total number of vLongs written during setup and read per benchmark invocation.
+    @Param(value = { "10000000" })
+    int entries;
+
+    /**
+     * Serializes exactly {@code entries} vLongs and keeps the resulting bytes for the benchmark method.
+     *
+     * @throws IOException    if writing to the temporary output fails
+     * @throws AssertionError if the written bytes are not backed by a {@link PagedBytesReference}
+     */
+    @Setup
+    public void initResults() throws IOException {
+        final BytesStreamOutput tmp = new BytesStreamOutput();
+        final int smallValues = entries / 2;
+        // Put the remainder into the second half so that exactly `entries` values are written even when `entries` is
+        // odd; otherwise readVLong() would attempt to read one more value than the stream contains.
+        final int largeValues = entries - smallValues;
+        for (int i = 0; i < smallValues; i++) {
+            tmp.writeVLong(i);
+        }
+        for (int i = 0; i < largeValues; i++) {
+            tmp.writeVLong(Long.MAX_VALUE - i);
+        }
+        pagedBytes = tmp.bytes();
+        if (pagedBytes instanceof PagedBytesReference == false) {
+            throw new AssertionError("expected paged PagedBytesReference but saw [" + pagedBytes.getClass() + "]");
+        }
+    }
+
+    /**
+     * Reads {@code entries} vLongs from a fresh stream over the prepared bytes.
+     *
+     * @return the XOR of all values read; returned so the reads contribute to the benchmark result
+     * @throws IOException if reading from the stream fails
+     */
+    @Benchmark
+    public long readVLong() throws IOException {
+        long res = 0;
+        try (StreamInput streamInput = pagedBytes.streamInput()) {
+            for (int i = 0; i < entries; i++) {
+                res = res ^ streamInput.readVLong();
+            }
+        }
+        return res;
+    }
+}

+ 96 - 0
server/src/main/java/org/elasticsearch/common/bytes/BytesReferenceStreamInput.java

@@ -78,6 +78,102 @@ class BytesReferenceStreamInput extends StreamInput {
         }
     }
 
+    /**
+     * Reads a variable-length int.
+     * <p>
+     * When at least five bytes (the most this decoder ever consumes for one value) remain in the current slice, the
+     * value is decoded directly from the slice's backing array. Otherwise the read is handed to
+     * {@code super.readVInt()}.
+     *
+     * @return the decoded int
+     * @throws IOException if the fifth byte has any of its upper four bits set, or if the fallback read fails
+     */
+    @Override
+    public int readVInt() throws IOException {
+        if (slice.length - sliceIndex >= 5) {
+            final byte[] buf = slice.bytes;
+            final int offset = slice.offset;
+            byte b = buf[offset + sliceIndex++];
+            // A non-negative byte has no continuation bit, so a single-byte value is the byte itself.
+            if (b >= 0) {
+                return b;
+            }
+            int i = b & 0x7F;
+            b = buf[offset + sliceIndex++];
+            i |= (b & 0x7F) << 7;
+            if (b >= 0) {
+                return i;
+            }
+            b = buf[offset + sliceIndex++];
+            i |= (b & 0x7F) << 14;
+            if (b >= 0) {
+                return i;
+            }
+            b = buf[offset + sliceIndex++];
+            i |= (b & 0x7F) << 21;
+            if (b >= 0) {
+                return i;
+            }
+            b = buf[offset + sliceIndex++];
+            // Validate before merging the last byte so that the accumulated value reported in the exception excludes
+            // the rejected byte, matching what readVIntSlow passes to throwOnBrokenVInt.
+            // NOTE(review): this path rejects any of the upper four bits (0xF0), while readVIntSlow only rejects the
+            // continuation bit (0x80); kept as-is, confirm whether the two paths should agree.
+            if ((b & 0xF0) != 0) {
+                throwOnBrokenVInt(b, i);
+            }
+            return i | ((b & 0x0F) << 28);
+        }
+        return super.readVInt();
+    }
+
+    /**
+     * Reads a variable-length long.
+     * <p>
+     * If fewer than ten bytes remain in the current slice the read is handed to {@code super.readVLong()}. Otherwise
+     * the value is decoded directly from the slice's backing array, seven payload bits per byte, stopping at the
+     * first byte without the continuation bit.
+     *
+     * @return the decoded long
+     * @throws IOException if a tenth byte is present and is neither {@code 0} nor {@code 1}, or if the fallback fails
+     */
+    @Override
+    public long readVLong() throws IOException {
+        if (slice.length - sliceIndex < 10) {
+            return super.readVLong();
+        }
+        final byte[] bytes = slice.bytes;
+        final int base = slice.offset;
+        // A byte is non-negative exactly when its continuation bit (0x80) is clear.
+        byte current = bytes[base + sliceIndex++];
+        long value = current & 0x7FL;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 7;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 14;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 21;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 28;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 35;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 42;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 49;
+        if (current >= 0) {
+            return value;
+        }
+        current = bytes[base + sliceIndex++];
+        value |= (current & 0x7FL) << 56;
+        if (current >= 0) {
+            return value;
+        }
+        // Tenth byte: only the single remaining bit (bit 63) may be supplied.
+        current = bytes[base + sliceIndex++];
+        if (current != 0 && current != 1) {
+            throwOnBrokenVLong(current, value);
+        }
+        return value | (((long) current) << 63);
+    }
+
     protected int offset() {
         return sliceStartOffset + sliceIndex;
     }

+ 10 - 0
server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java

@@ -55,6 +55,16 @@ public abstract class FilterStreamInput extends StreamInput {
         return delegate.readLong();
     }
 
+    // Forwards to the wrapped stream so that delegate.readVInt() (rather than the inherited implementation) is used.
+    @Override
+    public int readVInt() throws IOException {
+        return delegate.readVInt();
+    }
+
+    // Forwards to the wrapped stream so that delegate.readVLong() (rather than the inherited implementation) is used.
+    @Override
+    public long readVLong() throws IOException {
+        return delegate.readVLong();
+    }
+
     @Override
     public void reset() throws IOException {
         delegate.reset();

+ 18 - 2
server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

@@ -210,6 +210,10 @@ public abstract class StreamInput extends InputStream {
      * using {@link #readInt}
      */
     public int readVInt() throws IOException {
+        // Overridable entry point; this base version simply runs the final, byte-by-byte readVIntSlow().
+        return readVIntSlow();
+    }
+
+    protected final int readVIntSlow() throws IOException {
         byte b = readByte();
         int i = b & 0x7F;
         if ((b & 0x80) == 0) {
@@ -232,11 +236,15 @@ public abstract class StreamInput extends InputStream {
         }
         b = readByte();
         if ((b & 0x80) != 0) {
-            throw new IOException("Invalid vInt ((" + Integer.toHexString(b) + " & 0x7f) << 28) | " + Integer.toHexString(i));
+            throwOnBrokenVInt(b, i);
         }
         return i | ((b & 0x7F) << 28);
     }
 
+    /**
+     * Always throws an {@link IOException} describing an invalid vInt; this method never returns normally.
+     *
+     * @param b           the offending byte, rendered in hex in the message
+     * @param accumulated the value decoded so far, rendered in hex in the message
+     * @throws IOException always
+     */
+    protected static void throwOnBrokenVInt(byte b, int accumulated) throws IOException {
+        throw new IOException("Invalid vInt ((" + Integer.toHexString(b) + " & 0x7f) << 28) | " + Integer.toHexString(accumulated));
+    }
+
     /**
      * Reads eight bytes and returns a long.
      */
@@ -249,6 +257,10 @@ public abstract class StreamInput extends InputStream {
      * are encoded in ten bytes so prefer {@link #readLong()} or {@link #readZLong()} for negative numbers.
      */
     public long readVLong() throws IOException {
+        // Overridable entry point; this base version simply runs the final, byte-by-byte readVLongSlow().
+        return readVLongSlow();
+    }
+
+    protected final long readVLongSlow() throws IOException {
         byte b = readByte();
         long i = b & 0x7FL;
         if ((b & 0x80) == 0) {
@@ -296,12 +308,16 @@ public abstract class StreamInput extends InputStream {
         }
         b = readByte();
         if (b != 0 && b != 1) {
-            throw new IOException("Invalid vlong (" + Integer.toHexString(b) + " << 63) | " + Long.toHexString(i));
+            throwOnBrokenVLong(b, i);
         }
         i |= ((long) b) << 63;
         return i;
     }
 
+    /**
+     * Always throws an {@link IOException} describing an invalid vLong; this method never returns normally.
+     *
+     * @param b           the offending byte, rendered in hex in the message
+     * @param accumulated the value decoded so far, rendered in hex in the message
+     * @throws IOException always
+     */
+    protected static void throwOnBrokenVLong(byte b, long accumulated) throws IOException {
+        throw new IOException("Invalid vlong (" + Integer.toHexString(b) + " << 63) | " + Long.toHexString(accumulated));
+    }
+
     @Nullable
     public Long readOptionalVLong() throws IOException {
         if (readBoolean()) {

+ 10 - 0
server/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamInput.java

@@ -83,6 +83,16 @@ public final class BufferedChecksumStreamInput extends FilterStreamInput {
         return Numbers.bytesToLong(buf, 0);
     }
 
+    // Uses the byte-by-byte readVIntSlow() instead of inheriting a readVInt() that forwards to the delegate.
+    // NOTE(review): presumably so every byte goes through this class's own readByte() and is checksummed - confirm.
+    @Override
+    public int readVInt() throws IOException {
+        return readVIntSlow();
+    }
+
+    // Uses the byte-by-byte readVLongSlow() instead of inheriting a readVLong() that forwards to the delegate.
+    // NOTE(review): presumably so every byte goes through this class's own readByte() and is checksummed - confirm.
+    @Override
+    public long readVLong() throws IOException {
+        return readVLongSlow();
+    }
+
     @Override
     public void reset() throws IOException {
         delegate.reset();