Browse Source

Introduce StreamInput.readSlicedBytesReference (#105262)

This is mainly added as a prerequisite to slicing doc sources out of
bulk index requests. The few usages for it added in this PR have limited
performance impact but demonstrate correct functioning of the
implementations.

Co-authored-by: David Turner <david.turner@elastic.co>
Armin Braun 1 year ago
parent
commit
cf27a501aa
21 changed files with 158 additions and 49 deletions
  1. 6 2
      libs/core/src/main/java/org/elasticsearch/core/CharArrays.java
  2. 10 0
      server/src/main/java/org/elasticsearch/common/bytes/BytesReferenceStreamInput.java
  3. 14 0
      server/src/main/java/org/elasticsearch/common/io/stream/ByteBufferStreamInput.java
  4. 6 0
      server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java
  5. 18 6
      server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
  6. 4 4
      server/src/main/java/org/elasticsearch/index/mapper/TimeSeriesIdFieldMapper.java
  7. 7 0
      server/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamInput.java
  8. 1 1
      server/src/main/java/org/elasticsearch/index/translog/Translog.java
  9. 2 1
      server/src/main/java/org/elasticsearch/search/dfs/AggregatedDfs.java
  10. 1 1
      server/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java
  11. 1 1
      server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java
  12. 5 0
      server/src/test/java/org/elasticsearch/common/bytes/BytesArrayTests.java
  13. 12 0
      server/src/test/java/org/elasticsearch/common/bytes/CompositeBytesReferenceTests.java
  14. 7 0
      server/src/test/java/org/elasticsearch/common/bytes/PagedBytesReferenceTests.java
  15. 20 21
      server/src/test/java/org/elasticsearch/common/bytes/ReleasableBytesReferenceTests.java
  16. 28 0
      test/framework/src/main/java/org/elasticsearch/common/bytes/AbstractBytesReferenceTestCase.java
  17. 11 0
      test/framework/src/main/java/org/elasticsearch/common/bytes/ZeroBytesReferenceTests.java
  18. 2 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/DelegatePkiAuthenticationRequest.java
  19. 1 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/apikey/CreateApiKeyResponse.java
  20. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/user/ChangePasswordRequest.java
  21. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/user/PutUserRequest.java

+ 6 - 2
libs/core/src/main/java/org/elasticsearch/core/CharArrays.java

@@ -21,13 +21,17 @@ public final class CharArrays {
 
     private CharArrays() {}
 
+    public static char[] utf8BytesToChars(byte[] utf8Bytes) { // convenience overload: decodes the entire array
+        return utf8BytesToChars(utf8Bytes, 0, utf8Bytes.length); // delegates with offset 0 and the full array length
+    }
+
     /**
      * Decodes the provided byte[] to a UTF-8 char[]. This is done while avoiding
      * conversions to String. The provided byte[] is not modified by this method, so
      * the caller needs to take care of clearing the value if it is sensitive.
      */
-    public static char[] utf8BytesToChars(byte[] utf8Bytes) {
-        final ByteBuffer byteBuffer = ByteBuffer.wrap(utf8Bytes);
+    public static char[] utf8BytesToChars(byte[] utf8Bytes, int offset, int len) {
+        final ByteBuffer byteBuffer = ByteBuffer.wrap(utf8Bytes, offset, len);
         final CharBuffer charBuffer = StandardCharsets.UTF_8.decode(byteBuffer);
         final char[] chars;
         if (charBuffer.hasArray()) {

+ 10 - 0
server/src/main/java/org/elasticsearch/common/bytes/BytesReferenceStreamInput.java

@@ -229,6 +229,16 @@ class BytesReferenceStreamInput extends StreamInput {
         }
     }
 
+    @Override
+    public BytesReference readSlicedBytesReference() throws IOException {
+        int len = readVInt(); // vInt length prefix that precedes the payload
+        int pos = offset(); // captured before skipping; presumably the current read position, i.e. the payload start
+        if (len != skip(len)) { // fewer bytes were skipped than the prefix announced: input is truncated
+            throw new EOFException();
+        }
+        return bytesReference.slice(pos, len); // return a slice of the backing reference for the payload range
+    }
+
     @Override
     public boolean markSupported() {
         return true;

+ 14 - 0
server/src/main/java/org/elasticsearch/common/io/stream/ByteBufferStreamInput.java

@@ -7,6 +7,9 @@
  */
 package org.elasticsearch.common.io.stream;
 
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+
 import java.io.EOFException;
 import java.io.IOException;
 import java.nio.BufferUnderflowException;
@@ -234,6 +237,17 @@ public class ByteBufferStreamInput extends StreamInput {
         }
     }
 
+    @Override
+    public BytesReference readSlicedBytesReference() throws IOException {
+        // Heap buffers expose their backing array, so the payload can be wrapped in place.
+        // Buffers without an accessible array fall back to the default implementation.
+        if (buffer.hasArray()) {
+            int len = readVInt();
+            // remember where the payload starts inside the backing array before the position moves
+            int start = buffer.arrayOffset() + buffer.position();
+            // skip() may skip fewer bytes than requested; treat that as truncated input instead of
+            // returning a reference that extends past the readable part of the buffer
+            if (len != skip(len)) {
+                throw new EOFException();
+            }
+            return new BytesArray(buffer.array(), start, len);
+        }
+        return super.readSlicedBytesReference();
+    }
+
     @Override
     public void mark(int readlimit) {
         buffer.mark();

+ 6 - 0
server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java

@@ -9,6 +9,7 @@
 package org.elasticsearch.common.io.stream;
 
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.bytes.ReleasableBytesReference;
 
 import java.io.EOFException;
@@ -51,6 +52,11 @@ public abstract class FilterStreamInput extends StreamInput {
         return delegate.readAllToReleasableBytesReference();
     }
 
+    @Override
+    public BytesReference readSlicedBytesReference() throws IOException {
+        return delegate.readSlicedBytesReference(); // forward to the wrapped stream so its own implementation is used
+    }
+
     @Override
     public short readShort() throws IOException {
         return delegate.readShort();

+ 18 - 6
server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

@@ -119,6 +119,17 @@ public abstract class StreamInput extends InputStream {
         return ReleasableBytesReference.wrap(readBytesReference());
     }
 
+    /**
+     * Reads the same bytes returned by {@link #readReleasableBytesReference()} but does not retain a reference to these bytes.
+     * The returned {@link BytesReference} thus only contains valid content as long as the underlying buffer has not been released.
+     * This method should be preferred over {@link #readReleasableBytesReference()} when the returned reference is known to not be used
+     * past the lifetime of the underlying buffer as it requires fewer allocations and does not require a potentially costly reference
+     * count change.
+     */
+    public BytesReference readSlicedBytesReference() throws IOException {
+        return readBytesReference(); // default for streams without their own override: plain readBytesReference()
+    }
+
     /**
      * Checks if this {@link InputStream} supports {@link #readAllToReleasableBytesReference()}.
      */
@@ -525,13 +536,14 @@ public abstract class StreamInput extends InputStream {
     }
 
     public SecureString readSecureString() throws IOException {
-        BytesReference bytesRef = readBytesReference();
-        byte[] bytes = BytesReference.toBytes(bytesRef);
-        try {
-            return new SecureString(CharArrays.utf8BytesToChars(bytes));
-        } finally {
-            Arrays.fill(bytes, (byte) 0);
+        BytesReference bytesRef = readSlicedBytesReference();
+        final char[] chars;
+        if (bytesRef.hasArray()) {
+            // decode straight out of the backing array, no intermediate byte[] is created for the secret
+            chars = CharArrays.utf8BytesToChars(bytesRef.array(), bytesRef.arrayOffset(), bytesRef.length());
+        } else {
+            // no single backing array: the secret is gathered into a temporary byte[] first; utf8BytesToChars
+            // leaves its input untouched, so wipe the temporary once decoding is done (as the previous code did)
+            final byte[] bytes = BytesReference.toBytes(bytesRef);
+            try {
+                chars = CharArrays.utf8BytesToChars(bytes);
+            } finally {
+                Arrays.fill(bytes, (byte) 0);
+            }
         }
+        return new SecureString(chars);
     }
 
     public final float readFloat() throws IOException {

+ 4 - 4
server/src/main/java/org/elasticsearch/index/mapper/TimeSeriesIdFieldMapper.java

@@ -161,7 +161,7 @@ public class TimeSeriesIdFieldMapper extends MetadataFieldMapper {
      */
     public static Object encodeTsid(StreamInput in) {
         try {
-            return base64Encode(in.readBytesRef());
+            return base64Encode(in.readSlicedBytesReference().toBytesRef());
         } catch (IOException e) {
             throw new IllegalArgumentException("Unable to read tsid");
         }
@@ -359,7 +359,7 @@ public class TimeSeriesIdFieldMapper extends MetadataFieldMapper {
 
     private static String base64Encode(final BytesRef bytesRef) {
         byte[] bytes = new byte[bytesRef.length];
-        System.arraycopy(bytesRef.bytes, 0, bytes, 0, bytesRef.length);
+        System.arraycopy(bytesRef.bytes, bytesRef.offset, bytes, 0, bytesRef.length);
         return BASE64_ENCODER.encodeToString(bytes);
     }
 
@@ -379,7 +379,7 @@ public class TimeSeriesIdFieldMapper extends MetadataFieldMapper {
             for (int i = 0; i < size; i++) {
                 String name = null;
                 try {
-                    name = in.readBytesRef().utf8ToString();
+                    name = in.readSlicedBytesReference().utf8ToString();
                 } catch (AssertionError ae) {
                     throw new IllegalArgumentException("Error parsing keyword dimension: " + ae.getMessage(), ae);
                 }
@@ -389,7 +389,7 @@ public class TimeSeriesIdFieldMapper extends MetadataFieldMapper {
                     case (byte) 's' -> {
                         // parse a string
                         try {
-                            result.put(name, in.readBytesRef().utf8ToString());
+                            result.put(name, in.readSlicedBytesReference().utf8ToString());
                         } catch (AssertionError ae) {
                             throw new IllegalArgumentException("Error parsing keyword dimension: " + ae.getMessage(), ae);
                         }

+ 7 - 0
server/src/main/java/org/elasticsearch/index/translog/BufferedChecksumStreamInput.java

@@ -10,6 +10,7 @@ package org.elasticsearch.index.translog;
 
 import org.apache.lucene.store.BufferedChecksum;
 import org.elasticsearch.common.Numbers;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.FilterStreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
 
@@ -109,6 +110,12 @@ public final class BufferedChecksumStreamInput extends FilterStreamInput {
         return b;
     }
 
+    @Override
+    public BytesReference readSlicedBytesReference() throws IOException { // replaces the FilterStreamInput delegation
+        // TODO: support slicing here as well
+        return readBytesReference(); // NOTE(review): presumably so the bytes still pass through this stream's checksum - confirm
+    }
+
     @Override
     public boolean markSupported() {
         return delegate.markSupported();

+ 1 - 1
server/src/main/java/org/elasticsearch/index/translog/Translog.java

@@ -1315,7 +1315,7 @@ public class Translog extends AbstractIndexShardComponent implements IndexShardC
             if (format < FORMAT_NO_DOC_TYPE) {
                 final String docType = in.readString();
                 assert docType.equals(IdFieldMapper.NAME) : docType + " != " + IdFieldMapper.NAME;
-                in.readBytesRef(); // uid
+                in.readSlicedBytesReference(); // uid
             }
             long version = in.readLong();
             if (format < FORMAT_NO_VERSION_TYPE) {

+ 2 - 1
server/src/main/java/org/elasticsearch/search/dfs/AggregatedDfs.java

@@ -29,7 +29,8 @@ public class AggregatedDfs implements Writeable {
         int size = in.readVInt();
         termStatistics = new HashMap<>(size);
         for (int i = 0; i < size; i++) {
-            Term term = new Term(in.readString(), in.readBytesRef());
+            // term constructor copies the bytes so we can work with a slice
+            Term term = new Term(in.readString(), in.readSlicedBytesReference().toBytesRef());
             TermStatistics stats = new TermStatistics(in.readBytesRef(), in.readVLong(), DfsSearchResult.subOne(in.readVLong()));
             termStatistics.put(term, stats);
         }

+ 1 - 1
server/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java

@@ -46,7 +46,7 @@ public final class DfsSearchResult extends SearchPhaseResult {
         } else {
             terms = new Term[termsSize];
             for (int i = 0; i < terms.length; i++) {
-                terms[i] = new Term(in.readString(), in.readBytesRef());
+                terms[i] = new Term(in.readString(), in.readSlicedBytesReference().toBytesRef());
             }
         }
         this.termStatistics = readTermStats(in, terms);

+ 1 - 1
server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java

@@ -277,7 +277,7 @@ final class TransportHandshaker {
             super(streamInput);
             BytesReference remainingMessage;
             try {
-                remainingMessage = streamInput.readBytesReference();
+                remainingMessage = streamInput.readSlicedBytesReference();
             } catch (EOFException e) {
                 remainingMessage = null;
             }

+ 5 - 0
server/src/test/java/org/elasticsearch/common/bytes/BytesArrayTests.java

@@ -26,6 +26,11 @@ public class BytesArrayTests extends AbstractBytesReferenceTestCase {
         return newBytesReference(length, 0);
     }
 
+    @Override
+    protected BytesReference newBytesReference(byte[] content) {
+        return new BytesArray(content); // wrap the given array directly as a BytesArray
+    }
+
     private BytesReference newBytesReference(int length, int offset) throws IOException {
         // we know bytes stream output always creates a paged bytes reference, we use it to create randomized content
         final BytesStreamOutput out = new BytesStreamOutput(length + offset);

+ 12 - 0
server/src/test/java/org/elasticsearch/common/bytes/CompositeBytesReferenceTests.java

@@ -38,6 +38,18 @@ public class CompositeBytesReferenceTests extends AbstractBytesReferenceTestCase
         return ref;
     }
 
+    @Override
+    protected BytesReference newBytesReference(byte[] content) {
+        if (content.length > 1) { // two non-empty parts need at least two bytes
+            int splitOffset = randomIntBetween(1, content.length - 1); // random split point that keeps both parts non-empty
+            return CompositeBytesReference.of(
+                new BytesArray(content, 0, splitOffset),
+                new BytesArray(content, splitOffset, content.length - splitOffset)
+            );
+        }
+        return CompositeBytesReference.of(new BytesArray(content)); // zero or one byte: pass a single part
+    }
+
     private List<BytesReference> newRefList(int length) {
         int emptySlices = between(0, 10);
         List<BytesReference> referenceList = new ArrayList<>();

+ 7 - 0
server/src/test/java/org/elasticsearch/common/bytes/PagedBytesReferenceTests.java

@@ -38,6 +38,13 @@ public class PagedBytesReferenceTests extends AbstractBytesReferenceTestCase {
         return ref;
     }
 
+    @Override
+    protected BytesReference newBytesReference(byte[] content) {
+        ByteArray byteArray = bigarrays.newByteArray(content.length); // allocate a ByteArray of the same length
+        byteArray.set(0, content, 0, content.length); // bulk-copy all of content into it starting at index 0
+        return BytesReference.fromByteArray(byteArray, content.length); // expose the ByteArray as a BytesReference
+    }
+
     public void testToBytesRefMaterializedPages() throws IOException {
         // we need a length != (n * pagesize) to avoid page sharing at boundaries
         int length = 0;

+ 20 - 21
server/src/test/java/org/elasticsearch/common/bytes/ReleasableBytesReferenceTests.java

@@ -28,46 +28,45 @@ public class ReleasableBytesReferenceTests extends AbstractBytesReferenceTestCas
 
     @Override
     protected BytesReference newBytesReferenceWithOffsetOfZero(int length) throws IOException {
+        return newBytesReference(randomByteArrayOfLength(length));
+    }
+
+    @Override
+    protected BytesReference newBytesReference(byte[] content) throws IOException {
         BytesReference delegate;
         String composite = "composite";
         String paged = "paged";
         String array = "array";
         String type = randomFrom(composite, paged, array);
         if (array.equals(type)) {
-            final BytesStreamOutput out = new BytesStreamOutput(length);
-            for (int i = 0; i < length; i++) {
-                out.writeByte((byte) random().nextInt(1 << 8));
-            }
-            assertThat(length, equalTo(out.size()));
-            BytesArray ref = new BytesArray(out.bytes().toBytesRef().bytes, 0, length);
-            assertThat(length, equalTo(ref.length()));
-            assertThat(ref.length(), Matchers.equalTo(length));
+            final BytesStreamOutput out = new BytesStreamOutput(content.length);
+            out.writeBytes(content, 0, content.length);
+            assertThat(content.length, equalTo(out.size()));
+            BytesArray ref = new BytesArray(out.bytes().toBytesRef().bytes, 0, content.length);
+            assertThat(content.length, equalTo(ref.length()));
+            assertThat(ref.length(), Matchers.equalTo(content.length));
             delegate = ref;
         } else if (paged.equals(type)) {
-            ByteArray byteArray = bigarrays.newByteArray(length);
-            for (int i = 0; i < length; i++) {
-                byteArray.set(i, (byte) random().nextInt(1 << 8));
-            }
-            assertThat(byteArray.size(), Matchers.equalTo((long) length));
-            BytesReference ref = BytesReference.fromByteArray(byteArray, length);
-            assertThat(ref.length(), Matchers.equalTo(length));
+            ByteArray byteArray = bigarrays.newByteArray(content.length);
+            byteArray.set(0, content, 0, content.length);
+            assertThat(byteArray.size(), Matchers.equalTo((long) content.length));
+            BytesReference ref = BytesReference.fromByteArray(byteArray, content.length);
+            assertThat(ref.length(), Matchers.equalTo(content.length));
             delegate = ref;
         } else {
             assert composite.equals(type);
             List<BytesReference> referenceList = new ArrayList<>();
-            for (int i = 0; i < length;) {
-                int remaining = length - i;
+            for (int i = 0; i < content.length;) {
+                int remaining = content.length - i;
                 int sliceLength = randomIntBetween(1, remaining);
                 ReleasableBytesStreamOutput out = new ReleasableBytesStreamOutput(sliceLength, bigarrays);
-                for (int j = 0; j < sliceLength; j++) {
-                    out.writeByte((byte) random().nextInt(1 << 8));
-                }
+                out.writeBytes(content, content.length - remaining, sliceLength);
                 assertThat(sliceLength, equalTo(out.size()));
                 referenceList.add(out.bytes());
                 i += sliceLength;
             }
             BytesReference ref = CompositeBytesReference.of(referenceList.toArray(new BytesReference[0]));
-            assertThat(length, equalTo(ref.length()));
+            assertThat(content.length, equalTo(ref.length()));
             delegate = ref;
         }
         return ReleasableBytesReference.wrap(delegate);

+ 28 - 0
test/framework/src/main/java/org/elasticsearch/common/bytes/AbstractBytesReferenceTestCase.java

@@ -540,6 +540,8 @@ public abstract class AbstractBytesReferenceTestCase extends ESTestCase {
 
     protected abstract BytesReference newBytesReferenceWithOffsetOfZero(int length) throws IOException;
 
+    protected abstract BytesReference newBytesReference(byte[] content) throws IOException;
+
     public void testCompareTo() throws IOException {
         final int iters = randomIntBetween(5, 10);
         for (int i = 0; i < iters; i++) {
@@ -682,4 +684,30 @@ public abstract class AbstractBytesReferenceTestCase extends ESTestCase {
         }
         assertArrayEquals(bytes, BytesReference.toBytes(bytesReference));
     }
+
+    public void testReadSlices() throws IOException { // sliced and regular reads of the same data must agree
+        final int refs = randomIntBetween(1, 1024);
+        final BytesReference bytesReference;
+        try (BytesStreamOutput out = new BytesStreamOutput()) {
+            for (int i = 0; i < refs; i++) { // serialize `refs` references of random length back to back
+                out.writeBytesReference(newBytesReference(randomIntBetween(1, 1024)));
+            }
+            bytesReference = newBytesReference(out.copyBytes().array()); // re-wrap the serialized bytes in the type under test
+        }
+        try (StreamInput input1 = bytesReference.streamInput(); StreamInput input2 = bytesReference.streamInput()) {
+            for (int i = 0; i < refs; i++) {
+                boolean sliceLeft = randomBoolean(); // each stream independently picks the slicing or the regular read
+                BytesReference left = sliceLeft ? input1.readSlicedBytesReference() : input1.readBytesReference();
+                if (sliceLeft && bytesReference.hasArray()) {
+                    assertSame(left.array(), bytesReference.array()); // a slice must share the source's backing array
+                }
+                boolean sliceRight = randomBoolean();
+                BytesReference right = sliceRight ? input2.readSlicedBytesReference() : input2.readBytesReference();
+                assertEquals(left, right);
+                if (sliceRight && bytesReference.hasArray()) {
+                    assertSame(right.array(), bytesReference.array()); // compare with the source array, not with itself
+                }
+            }
+        }
+    }
 }

+ 11 - 0
test/framework/src/main/java/org/elasticsearch/common/bytes/ZeroBytesReferenceTests.java

@@ -8,6 +8,8 @@
 
 package org.elasticsearch.common.bytes;
 
+import java.io.IOException;
+
 import static org.hamcrest.Matchers.containsString;
 
 public class ZeroBytesReferenceTests extends AbstractBytesReferenceTestCase {
@@ -22,6 +24,11 @@ public class ZeroBytesReferenceTests extends AbstractBytesReferenceTestCase {
         return new ZeroBytesReference(length);
     }
 
+    @Override
+    protected BytesReference newBytesReference(byte[] content) { // unsupported for this test class, always throws
+        throw new AssertionError("can't build a zero bytes reference with arbitrary content");
+    }
+
     @Override
     public void testToBytesRefSharedPage() {
         // ZeroBytesReference doesn't share pages
@@ -44,4 +51,8 @@ public class ZeroBytesReferenceTests extends AbstractBytesReferenceTestCase {
         );
     }
 
+    @Override
+    public void testReadSlices() throws IOException { // overridden with an empty body so the inherited test does nothing here
+        // irrelevant for zero bytes reference
+    }
 }

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/DelegatePkiAuthenticationRequest.java

@@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ssl.CertParsingUtils;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.security.cert.CertificateEncodingException;
 import java.security.cert.CertificateException;
 import java.security.cert.CertificateFactory;
@@ -75,7 +76,7 @@ public final class DelegatePkiAuthenticationRequest extends ActionRequest implem
         try {
             final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
             certificateChain = input.readCollectionAsImmutableList(in -> {
-                try (ByteArrayInputStream bis = new ByteArrayInputStream(in.readByteArray())) {
+                try (InputStream bis = in.readSlicedBytesReference().streamInput()) {
                     return (X509Certificate) certificateFactory.generateCertificate(bis);
                 } catch (CertificateException e) {
                     throw new IOException(e);

+ 1 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/apikey/CreateApiKeyResponse.java

@@ -70,15 +70,7 @@ public final class CreateApiKeyResponse extends ActionResponse implements ToXCon
         super(in);
         this.name = in.readString();
         this.id = in.readString();
-        byte[] bytes = null;
-        try {
-            bytes = in.readByteArray();
-            this.key = new SecureString(CharArrays.utf8BytesToChars(bytes));
-        } finally {
-            if (bytes != null) {
-                Arrays.fill(bytes, (byte) 0);
-            }
-        }
+        this.key = in.readSecureString();
         this.expiration = in.readOptionalInstant();
     }
 

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/user/ChangePasswordRequest.java

@@ -33,7 +33,7 @@ public class ChangePasswordRequest extends ActionRequest implements UserRequest,
     public ChangePasswordRequest(StreamInput in) throws IOException {
         super(in);
         username = in.readString();
-        passwordHash = CharArrays.utf8BytesToChars(BytesReference.toBytes(in.readBytesReference()));
+        passwordHash = CharArrays.utf8BytesToChars(BytesReference.toBytes(in.readSlicedBytesReference()));
         refreshPolicy = RefreshPolicy.readFrom(in);
     }
 

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/user/PutUserRequest.java

@@ -163,7 +163,7 @@ public class PutUserRequest extends ActionRequest implements UserRequest, WriteR
     }
 
     private static char[] readCharArrayFromStream(StreamInput in) throws IOException {
-        BytesReference charBytesRef = in.readBytesReference();
+        BytesReference charBytesRef = in.readSlicedBytesReference();
         if (charBytesRef == BytesArray.EMPTY) {
             return null;
         } else {