1
0
Эх сурвалжийг харах

Add int indicating size of transport header (#48884)

Currently we do not know the size of the transport header (map of
request response headers, features array, and action name). This means
that we must read the entire transport message to dependably act on the
headers. This commit adds an int indicating the size of the transport
headers. With this addition we can act upon the headers prior to reading
the entire message.
Tim Brooks 5 жил өмнө
parent
commit
8c2dda90c0

+ 34 - 28
server/src/main/java/org/elasticsearch/transport/InboundMessage.java

@@ -19,9 +19,7 @@
 package org.elasticsearch.transport;
 
 import org.elasticsearch.Version;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.compress.Compressor;
 import org.elasticsearch.common.compress.CompressorFactory;
 import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@@ -58,10 +56,6 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
         }
 
         InboundMessage deserialize(BytesReference reference) throws IOException {
-            int messageLengthBytes = reference.length();
-            final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
-            // we have additional bytes to read, outside of the header
-            boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0;
             StreamInput streamInput = reference.streamInput();
             boolean success = false;
             try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
@@ -70,23 +64,13 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
                 Version remoteVersion = Version.fromId(streamInput.readInt());
                 final boolean isHandshake = TransportStatus.isHandshake(status);
                 ensureVersionCompatibility(remoteVersion, version, isHandshake);
-                if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamInput.available() > 0) {
-                    Compressor compressor = getCompressor(reference);
-                    if (compressor == null) {
-                        int maxToRead = Math.min(reference.length(), 10);
-                        StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [")
-                            .append(maxToRead).append("] content bytes out of [").append(reference.length())
-                            .append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are [");
-                        for (int i = 0; i < maxToRead; i++) {
-                            sb.append(reference.get(i)).append(",");
-                        }
-                        sb.append("]");
-                        throw new IllegalStateException(sb.toString());
-                    }
-                    streamInput = compressor.streamInput(streamInput);
+
+                if (remoteVersion.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
+                    // Consume the variable header size
+                    streamInput.readInt();
+                } else {
+                    streamInput = decompressingStream(status, remoteVersion, streamInput);
                 }
-                streamInput = new NamedWriteableAwareStreamInput(streamInput, namedWriteableRegistry);
-                streamInput.setVersion(remoteVersion);
 
                 threadContext.readHeaders(streamInput);
 
@@ -97,8 +81,17 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
                         streamInput.readStringArray();
                     }
                     final String action = streamInput.readString();
+
+                    if (remoteVersion.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
+                        streamInput = decompressingStream(status, remoteVersion, streamInput);
+                    }
+                    streamInput = namedWriteableStream(streamInput, remoteVersion);
                     message = new Request(threadContext, remoteVersion, status, requestId, action, streamInput);
                 } else {
+                    if (remoteVersion.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
+                        streamInput = decompressingStream(status, remoteVersion, streamInput);
+                    }
+                    streamInput = namedWriteableStream(streamInput, remoteVersion);
                     message = new Response(threadContext, remoteVersion, status, requestId, streamInput);
                 }
                 success = true;
@@ -109,13 +102,26 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
                 }
             }
         }
-    }
 
-    @Nullable
-    static Compressor getCompressor(BytesReference message) {
-        final int offset = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE;
-        return CompressorFactory.COMPRESSOR.isCompressed(message.slice(offset, message.length() - offset))
-            ? CompressorFactory.COMPRESSOR : null;
+        static StreamInput decompressingStream(byte status, Version remoteVersion, StreamInput streamInput) throws IOException {
+            if (TransportStatus.isCompress(status) && streamInput.available() > 0) {
+                try {
+                    StreamInput decompressor = CompressorFactory.COMPRESSOR.streamInput(streamInput);
+                    decompressor.setVersion(remoteVersion);
+                    return decompressor;
+                } catch (IllegalArgumentException e) {
+                    throw new IllegalStateException("stream marked as compressed, but is missing deflate header");
+                }
+            } else {
+                return streamInput;
+            }
+        }
+
+        private StreamInput namedWriteableStream(StreamInput delegate, Version remoteVersion) {
+            NamedWriteableAwareStreamInput streamInput = new NamedWriteableAwareStreamInput(delegate, namedWriteableRegistry);
+            streamInput.setVersion(remoteVersion);
+            return streamInput;
+        }
     }
 
     @Override

+ 24 - 7
server/src/main/java/org/elasticsearch/transport/OutboundMessage.java

@@ -24,6 +24,7 @@ import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.bytes.CompositeBytesReference;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 
@@ -41,20 +42,36 @@ abstract class OutboundMessage extends NetworkMessage {
     BytesReference serialize(BytesStreamOutput bytesStream) throws IOException {
         storedContext.restore();
         bytesStream.setVersion(version);
-        bytesStream.skip(TcpHeader.HEADER_SIZE);
+        bytesStream.skip(TcpHeader.headerSize(version));
 
         // The compressible bytes stream will not close the underlying bytes stream
         BytesReference reference;
+        int variableHeaderLength = -1;
+        final long preHeaderPosition = bytesStream.position();
+
+        if (version.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
+            writeVariableHeader(bytesStream);
+            variableHeaderLength = Math.toIntExact(bytesStream.position() - preHeaderPosition);
+        }
+
         try (CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) {
             stream.setVersion(version);
-            threadContext.writeTo(stream);
+            if (variableHeaderLength == -1) {
+                writeVariableHeader(stream);
+            }
             reference = writeMessage(stream);
         }
+
         bytesStream.seek(0);
-        TcpHeader.writeHeader(bytesStream, requestId, status, version, reference.length() - TcpHeader.HEADER_SIZE);
+        final int contentSize = reference.length() - TcpHeader.headerSize(version);
+        TcpHeader.writeHeader(bytesStream, requestId, status, version, contentSize, variableHeaderLength);
         return reference;
     }
 
+    protected void writeVariableHeader(StreamOutput stream) throws IOException {
+        threadContext.writeTo(stream);
+    }
+
     protected BytesReference writeMessage(CompressibleBytesOutputStream stream) throws IOException {
         final BytesReference zeroCopyBuffer;
         if (message instanceof BytesTransportRequest) {
@@ -92,13 +109,13 @@ abstract class OutboundMessage extends NetworkMessage {
         }
 
         @Override
-        protected BytesReference writeMessage(CompressibleBytesOutputStream out) throws IOException {
+        protected void writeVariableHeader(StreamOutput stream) throws IOException {
+            super.writeVariableHeader(stream);
             if (version.before(Version.V_8_0_0)) {
                 // empty features array
-                out.writeStringArray(Strings.EMPTY_ARRAY);
+                stream.writeStringArray(Strings.EMPTY_ARRAY);
             }
-            out.writeString(action);
-            return super.writeMessage(out);
+            stream.writeString(action);
         }
 
         private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) {

+ 29 - 4
server/src/main/java/org/elasticsearch/transport/TcpHeader.java

@@ -25,7 +25,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import java.io.IOException;
 
 public class TcpHeader {
-    public static final int MARKER_BYTES_SIZE = 2 * 1;
+
+    // TODO: Change to 7.6 after backport
+    public static final Version VERSION_WITH_HEADER_SIZE = Version.V_8_0_0;
+
+    public static final int MARKER_BYTES_SIZE = 2;
 
     public static final int MESSAGE_LENGTH_SIZE = 4;
 
@@ -35,15 +39,36 @@ public class TcpHeader {
 
     public static final int VERSION_ID_SIZE = 4;
 
-    public static final int HEADER_SIZE = MARKER_BYTES_SIZE + MESSAGE_LENGTH_SIZE + REQUEST_ID_SIZE + STATUS_SIZE + VERSION_ID_SIZE;
+    public static final int VARIABLE_HEADER_SIZE = 4;
+
+    private static final int PRE_76_HEADER_SIZE = MARKER_BYTES_SIZE + MESSAGE_LENGTH_SIZE + REQUEST_ID_SIZE + STATUS_SIZE + VERSION_ID_SIZE;
+
+    private static final int HEADER_SIZE = PRE_76_HEADER_SIZE + VARIABLE_HEADER_SIZE;
+
+    public static int headerSize(Version version) {
+        if (version.onOrAfter(VERSION_WITH_HEADER_SIZE)) {
+            return HEADER_SIZE;
+        } else {
+            return PRE_76_HEADER_SIZE;
+        }
+    }
 
-    public static void writeHeader(StreamOutput output, long requestId, byte status, Version version, int messageSize) throws IOException {
+    public static void writeHeader(StreamOutput output, long requestId, byte status, Version version, int contentSize,
+                                   int variableHeaderSize) throws IOException {
         output.writeByte((byte)'E');
         output.writeByte((byte)'S');
         // write the size, the size indicates the remaining message size, not including the size int
-        output.writeInt(messageSize + REQUEST_ID_SIZE + STATUS_SIZE + VERSION_ID_SIZE);
+        if (version.onOrAfter(VERSION_WITH_HEADER_SIZE)) {
+            output.writeInt(contentSize + REQUEST_ID_SIZE + STATUS_SIZE + VERSION_ID_SIZE + VARIABLE_HEADER_SIZE);
+        } else {
+            output.writeInt(contentSize + REQUEST_ID_SIZE + STATUS_SIZE + VERSION_ID_SIZE);
+        }
         output.writeLong(requestId);
         output.writeByte(status);
         output.writeInt(version.id);
+        if (version.onOrAfter(VERSION_WITH_HEADER_SIZE)) {
+            assert variableHeaderSize != -1 : "Variable header size not set";
+            output.writeInt(variableHeaderSize);
+        }
     }
 }

+ 12 - 16
server/src/main/java/org/elasticsearch/transport/TransportLogger.java

@@ -22,8 +22,6 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.compress.Compressor;
-import org.elasticsearch.common.compress.NotCompressedException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -77,26 +75,24 @@ public final class TransportLogger {
                 final byte status = streamInput.readByte();
                 final boolean isRequest = TransportStatus.isRequest(status);
                 final String type = isRequest ? "request" : "response";
-                final String version = Version.fromId(streamInput.readInt()).toString();
+                Version version = Version.fromId(streamInput.readInt());
                 sb.append(" [length: ").append(messageLengthWithHeader);
                 sb.append(", request id: ").append(requestId);
                 sb.append(", type: ").append(type);
                 sb.append(", version: ").append(version);
 
-                if (isRequest) {
-                    if (TransportStatus.isCompress(status)) {
-                        Compressor compressor;
-                        compressor = InboundMessage.getCompressor(message);
-                        if (compressor == null) {
-                            throw new IllegalStateException(new NotCompressedException());
-                        }
-                        streamInput = compressor.streamInput(streamInput);
-                    }
+                if (version.onOrAfter(TcpHeader.VERSION_WITH_HEADER_SIZE)) {
+                    sb.append(", header size: ").append(streamInput.readInt()).append('B');
+                } else {
+                    streamInput = InboundMessage.Reader.decompressingStream(status, version, streamInput);
+                }
+
+                // read and discard headers
+                ThreadContext.readHeadersFromStream(streamInput);
 
-                    // read and discard headers
-                    ThreadContext.readHeadersFromStream(streamInput);
-                    if (streamInput.getVersion().before(Version.V_8_0_0)) {
-                        // discard the features
+                if (isRequest) {
+                    if (version.before(Version.V_8_0_0)) {
+                        // discard features
                         streamInput.readStringArray();
                     }
                     sb.append(", action: ").append(streamInput.readString());

+ 2 - 8
server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java

@@ -32,9 +32,7 @@ import org.elasticsearch.test.VersionUtils;
 import org.hamcrest.Matchers;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashSet;
 
 public class InboundMessageTests extends ESTestCase {
 
@@ -42,7 +40,6 @@ public class InboundMessageTests extends ESTestCase {
     private final NamedWriteableRegistry registry = new NamedWriteableRegistry(Collections.emptyList());
 
     public void testReadRequest() throws IOException {
-        String[] features = {"feature1", "feature2"};
         String value = randomAlphaOfLength(10);
         Message message = new Message(value);
         String action = randomAlphaOfLength(10);
@@ -81,7 +78,6 @@ public class InboundMessageTests extends ESTestCase {
     }
 
     public void testReadResponse() throws IOException {
-        HashSet<String> features = new HashSet<>(Arrays.asList("feature1", "feature2"));
         String value = randomAlphaOfLength(10);
         Message message = new Message(value);
         long requestId = randomLong();
@@ -118,7 +114,6 @@ public class InboundMessageTests extends ESTestCase {
     }
 
     public void testReadErrorResponse() throws IOException {
-        HashSet<String> features = new HashSet<>(Arrays.asList("feature1", "feature2"));
         RemoteTransportException exception = new RemoteTransportException("error", new IOException());
         long requestId = randomLong();
         boolean isHandshake = randomBoolean();
@@ -190,18 +185,17 @@ public class InboundMessageTests extends ESTestCase {
             reference = request.serialize(streamOutput);
         }
         final byte[] serialized = BytesReference.toBytes(reference);
-        final int statusPosition = TcpHeader.HEADER_SIZE - TcpHeader.VERSION_ID_SIZE - 1;
+        final int statusPosition = TcpHeader.headerSize(Version.CURRENT) - TcpHeader.VERSION_ID_SIZE - TcpHeader.VARIABLE_HEADER_SIZE - 1;
         // force status byte to signal compressed on the otherwise uncompressed message
         serialized[statusPosition] = TransportStatus.setCompress(serialized[statusPosition]);
         reference = new BytesArray(serialized);
         InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, registry, threadContext);
         BytesReference sliced = reference.slice(6, reference.length() - 6);
         final IllegalStateException iste = expectThrows(IllegalStateException.class, () -> reader.deserialize(sliced));
-        assertThat(iste.getMessage(), Matchers.startsWith("stream marked as compressed, but no compressor found,"));
+        assertThat(iste.getMessage(), Matchers.equalTo("stream marked as compressed, but is missing deflate header"));
     }
 
     private void testVersionIncompatibility(Version version, Version currentVersion, boolean isHandshake) throws IOException {
-        String[] features = {};
         String value = randomAlphaOfLength(10);
         Message message = new Message(value);
         String action = randomAlphaOfLength(10);

+ 7 - 21
server/src/test/java/org/elasticsearch/transport/TransportLoggerTests.java

@@ -24,7 +24,6 @@ import org.elasticsearch.Version;
 import org.elasticsearch.action.admin.cluster.stats.ClusterStatsAction;
 import org.elasticsearch.action.admin.cluster.stats.ClusterStatsRequest;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.bytes.CompositeBytesReference;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.common.settings.Settings;
@@ -61,6 +60,7 @@ public class TransportLoggerTests extends ESTestCase {
                 ", request id: \\d+" +
                 ", type: request" +
                 ", version: .*" +
+                ", header size: \\d+B" +
                 ", action: cluster:monitor/stats]" +
                 " WRITE: \\d+B";
         final MockLogAppender.LoggingExpectation writeExpectation =
@@ -72,6 +72,7 @@ public class TransportLoggerTests extends ESTestCase {
                 ", request id: \\d+" +
                 ", type: request" +
                 ", version: .*" +
+                ", header size: \\d+B" +
                 ", action: cluster:monitor/stats]" +
                 " READ: \\d+B";
 
@@ -88,26 +89,11 @@ public class TransportLoggerTests extends ESTestCase {
     }
 
     private BytesReference buildRequest() throws IOException {
-        try (BytesStreamOutput messageOutput = new BytesStreamOutput()) {
-            messageOutput.setVersion(Version.CURRENT);
-            ThreadContext context = new ThreadContext(Settings.EMPTY);
-            context.writeTo(messageOutput);
-            messageOutput.writeString(ClusterStatsAction.NAME);
-            new ClusterStatsRequest().writeTo(messageOutput);
-            BytesReference messageBody = messageOutput.bytes();
-            final BytesReference header = buildHeader(randomInt(30), messageBody.length());
-            return new CompositeBytesReference(header, messageBody);
-        }
-    }
-
-    private BytesReference buildHeader(long requestId, int length) throws IOException {
-        try (BytesStreamOutput headerOutput = new BytesStreamOutput(TcpHeader.HEADER_SIZE)) {
-            headerOutput.setVersion(Version.CURRENT);
-            TcpHeader.writeHeader(headerOutput, requestId, TransportStatus.setRequest((byte) 0), Version.CURRENT, length);
-            final BytesReference bytes = headerOutput.bytes();
-            assert bytes.length() == TcpHeader.HEADER_SIZE : "header size mismatch expected: " + TcpHeader.HEADER_SIZE + " but was: "
-                + bytes.length();
-            return bytes;
+        boolean compress = randomBoolean();
+        try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) {
+            OutboundMessage.Request request = new OutboundMessage.Request(new ThreadContext(Settings.EMPTY), new ClusterStatsRequest(),
+                Version.CURRENT, ClusterStatsAction.NAME, randomInt(30), false, compress);
+            return request.serialize(bytesStreamOutput);
         }
     }
 }