Răsfoiți Sursa

Throw an exception if Writeable.Reader reads null

If a Writeable.Reader returns null it is always a bug, probably one that
will cause corruption in the StreamInput it was trying to read from. This
commit adds a check that attempts to catch these errors quickly including
the name of the reader.
Nik Everett 9 ani în urmă
părinte
comite
5e8656aff0

+ 7 - 1
core/src/main/java/org/elasticsearch/common/io/stream/NamedWriteableAwareStreamInput.java

@@ -36,6 +36,12 @@ public class NamedWriteableAwareStreamInput extends FilterStreamInput {
     @Override
     <C> C readNamedWriteable(Class<C> categoryClass) throws IOException {
         String name = readString();
-        return namedWriteableRegistry.getReader(categoryClass, name).read(this);
+        Writeable.Reader<? extends C> reader = namedWriteableRegistry.getReader(categoryClass, name);
+        C c = reader.read(this);
+        if (c == null) {
+            throw new IOException(
+                    "Writeable.Reader [" + reader + "] returned null which is not allowed and probably means it screwed up the stream.");
+        }
+        return c;
     }
 }

+ 7 - 2
core/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

@@ -566,9 +566,14 @@ public abstract class StreamInput extends InputStream {
         }
     }
 
-    public <T extends Writeable> T readOptionalWriteable(Writeable.Reader<T> provider) throws IOException {
+    public <T extends Writeable> T readOptionalWriteable(Writeable.Reader<T> reader) throws IOException {
         if (readBoolean()) {
-            return provider.read(this);
+            T t = reader.read(this);
+            if (t == null) {
+                throw new IOException("Writeable.Reader [" + reader
+                        + "] returned null which is not allowed and probably means it screwed up the stream.");
+            }
+            return t;
         } else {
             return null;
         }

+ 2 - 1
core/src/main/java/org/elasticsearch/common/io/stream/Writeable.java

@@ -51,7 +51,8 @@ public interface Writeable<T> extends StreamableReader<T> { // TODO remove exten
 
     /**
      * Reference to a method that can read some object from a stream. By convention this is a constructor that takes
-     * {@linkplain StreamInput} as an argument for most classes and a static method for things like enums.
+     * {@linkplain StreamInput} as an argument for most classes and a static method for things like enums. Returning null from one of these
+     * is always wrong - for that we use methods like {@link StreamInput#readOptionalWriteable(Reader)}.
      */
     @FunctionalInterface
     interface Reader<R> {

+ 22 - 0
core/src/test/java/org/elasticsearch/common/io/stream/BytesStreamsTests.java

@@ -29,6 +29,7 @@ import java.io.IOException;
 import java.util.Objects;
 
 import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.endsWith;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.startsWith;
@@ -373,6 +374,27 @@ public class BytesStreamsTests extends ESTestCase {
         }
     }
 
+    public void testNamedWriteableReaderReturnsNull() throws IOException {
+        BytesStreamOutput out = new BytesStreamOutput();
+        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry();
+        namedWriteableRegistry.register(BaseNamedWriteable.class, TestNamedWriteable.NAME, (StreamInput in) -> null);
+        TestNamedWriteable namedWriteableIn = new TestNamedWriteable(randomAsciiOfLengthBetween(1, 10), randomAsciiOfLengthBetween(1, 10));
+        out.writeNamedWriteable(namedWriteableIn);
+        byte[] bytes = out.bytes().toBytes();
+        StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(bytes), namedWriteableRegistry);
+        assertEquals(in.available(), bytes.length);
+        IOException e = expectThrows(IOException.class, () -> in.readNamedWriteable(BaseNamedWriteable.class));
+        assertThat(e.getMessage(), endsWith("] returned null which is not allowed and probably means it screwed up the stream."));
+    }
+
+    public void testOptionalWriteableReaderReturnsNull() throws IOException {
+        BytesStreamOutput out = new BytesStreamOutput();
+        out.writeOptionalWriteable(new TestNamedWriteable(randomAsciiOfLengthBetween(1, 10), randomAsciiOfLengthBetween(1, 10)));
+        StreamInput in = StreamInput.wrap(out.bytes().toBytes());
+        IOException e = expectThrows(IOException.class, () -> in.readOptionalWriteable((StreamInput ignored) -> null));
+        assertThat(e.getMessage(), endsWith("] returned null which is not allowed and probably means it screwed up the stream."));
+    }
+
     private static abstract class BaseNamedWriteable<T> implements NamedWriteable<T> {
 
     }