Bläddra i källkod

Fix Incorrect Uncompressed Error Handling in InboundMessage (#44317)

* Fix Incorrect Uncompressed Error Handling in InboundMessage

* CompressorFactory.compressor does not throw uncompressed exception on uncompressed bytes, it merely returns `null` in this case if the bytes are at least XContent so the current catch and re-throw logic is dead code
* Made it work again by throwing on a `null` return so we get a real error message instead of an NPE
Armin Braun 6 år sedan
förälder
incheckning
9004b468e3

+ 10 - 6
server/src/main/java/org/elasticsearch/transport/InboundMessage.java

@@ -19,10 +19,10 @@
 package org.elasticsearch.transport;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.compress.Compressor;
 import org.elasticsearch.common.compress.CompressorFactory;
-import org.elasticsearch.common.compress.NotCompressedException;
 import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -75,11 +75,8 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
                 final boolean isHandshake = TransportStatus.isHandshake(status);
                 ensureVersionCompatibility(remoteVersion, version, isHandshake);
                 if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamInput.available() > 0) {
-                    Compressor compressor;
-                    try {
-                        final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE;
-                        compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed));
-                    } catch (NotCompressedException ex) {
+                    Compressor compressor = getCompressor(reference);
+                    if (compressor == null) {
                         int maxToRead = Math.min(reference.length(), 10);
                         StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [")
                             .append(maxToRead).append("] content bytes out of [").append(reference.length())
@@ -121,6 +118,13 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
         }
     }
 
+    @Nullable
+    static Compressor getCompressor(BytesReference message) {
+        final int offset = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE;
+        return CompressorFactory.COMPRESSOR.isCompressed(message.slice(offset, message.length() - offset))
+            ? CompressorFactory.COMPRESSOR : null;
+    }
+
     @Override
     public void close() throws IOException {
         streamInput.close();

+ 4 - 7
server/src/main/java/org/elasticsearch/transport/TransportLogger.java

@@ -18,12 +18,11 @@
  */
 package org.elasticsearch.transport;
 
-import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.compress.Compressor;
-import org.elasticsearch.common.compress.CompressorFactory;
 import org.elasticsearch.common.compress.NotCompressedException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
@@ -88,11 +87,9 @@ public final class TransportLogger {
                 if (isRequest) {
                     if (TransportStatus.isCompress(status)) {
                         Compressor compressor;
-                        try {
-                            final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE;
-                            compressor = CompressorFactory.compressor(message.slice(bytesConsumed, message.length() - bytesConsumed));
-                        } catch (NotCompressedException ex) {
-                            throw new IllegalStateException(ex);
+                        compressor = InboundMessage.getCompressor(message);
+                        if (compressor == null) {
+                            throw new IllegalStateException(new NotCompressedException());
                         }
                         streamInput = compressor.streamInput(streamInput);
                     }

+ 20 - 0
server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java

@@ -19,6 +19,7 @@
 package org.elasticsearch.transport;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@@ -28,6 +29,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.VersionUtils;
+import org.hamcrest.Matchers;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -181,6 +183,24 @@ public class InboundMessageTests extends ESTestCase {
             + version.minimumCompatibilityVersion() + "]", ise.getMessage());
     }
 
+    public void testThrowOnNotCompressed() throws Exception {
+        OutboundMessage.Response request = new OutboundMessage.Response(
+            threadContext, Collections.emptySet(), new Message(randomAlphaOfLength(10)), Version.CURRENT, randomLong(), false, false);
+        BytesReference reference;
+        try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
+            reference = request.serialize(streamOutput);
+        }
+        final byte[] serialized = BytesReference.toBytes(reference);
+        final int statusPosition = TcpHeader.HEADER_SIZE - TcpHeader.VERSION_ID_SIZE - 1;
+        // force status byte to signal compressed on the otherwise uncompressed message
+        serialized[statusPosition] = TransportStatus.setCompress(serialized[statusPosition]);
+        reference = new BytesArray(serialized);
+        InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, registry, threadContext);
+        BytesReference sliced = reference.slice(6, reference.length() - 6);
+        final IllegalStateException iste = expectThrows(IllegalStateException.class, () -> reader.deserialize(sliced));
+        assertThat(iste.getMessage(), Matchers.startsWith("stream marked as compressed, but no compressor found,"));
+    }
+
     private void testVersionIncompatibility(Version version, Version currentVersion, boolean isHandshake) throws IOException {
         String[] features = {};
         String value = randomAlphaOfLength(10);