Prechádzať zdrojové kódy

Add method to check if object is generically writeable in stream (#54936)

When calling scripts in metric aggregation, the returned metric state is
passed along to the coordinating node to do the final reduce. However,
it is possible the object could contain nested state which is unknown to
StreamOutput/StreamInput. This would then result in the node crashing as
exceptions are not expected in the middle of serialization.

This commit adds a method to StreamOutput that can determine if an
object is writeable by the stream. It uses the same logic
writeGenericValue, special casing each of the supported collection types
to recursively determine if each contained value is itself writeable.

relates #54708
Ryan Ernst 5 rokov pred
rodič
commit
62b4964f3b

+ 50 - 16
server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java

@@ -805,6 +805,23 @@ public abstract class StreamOutput extends OutputStream {
                     }
             ));
 
+    private static Class<?> getGenericType(Object value) {
+        if (value instanceof List) {
+            return List.class;
+        } else if (value instanceof Object[]) {
+            return Object[].class;
+        } else if (value instanceof Map) {
+            return Map.class;
+        } else if (value instanceof Set) {
+            return Set.class;
+        } else if (value instanceof ReadableInstant) {
+            return ReadableInstant.class;
+        } else if (value instanceof BytesReference) {
+            return BytesReference.class;
+        } else {
+            return value.getClass();
+        }
+    }
     /**
      * Notice: when serialization a map, the stream out map with the stream in map maybe have the
      * different key-value orders, they will maybe have different stream order.
@@ -816,22 +833,7 @@ public abstract class StreamOutput extends OutputStream {
             writeByte((byte) -1);
             return;
         }
-        final Class type;
-        if (value instanceof List) {
-            type = List.class;
-        } else if (value instanceof Object[]) {
-            type = Object[].class;
-        } else if (value instanceof Map) {
-            type = Map.class;
-        } else if (value instanceof Set) {
-            type = Set.class;
-        } else if (value instanceof ReadableInstant) {
-            type = ReadableInstant.class;
-        } else if (value instanceof BytesReference) {
-            type = BytesReference.class;
-        } else {
-            type = value.getClass();
-        }
+        final Class<?> type = getGenericType(value);
         final Writer writer = WRITERS.get(type);
         if (writer != null) {
             writer.write(this, value);
@@ -840,6 +842,38 @@ public abstract class StreamOutput extends OutputStream {
         }
     }
 
+    public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException {
+        if (value == null) {
+            return;
+        }
+        final Class<?> type = getGenericType(value);
+
+        if (type == List.class) {
+            @SuppressWarnings("unchecked") List<Object> list = (List<Object>) value;
+            for (Object v : list) {
+                checkWriteable(v);
+            }
+        } else if (type == Object[].class) {
+            Object[] array = (Object[]) value;
+            for (Object v : array) {
+                checkWriteable(v);
+            }
+        } else if (type == Map.class) {
+            @SuppressWarnings("unchecked") Map<String, Object> map = (Map<String, Object>) value;
+            for (Map.Entry<String, Object> entry : map.entrySet()) {
+                checkWriteable(entry.getKey());
+                checkWriteable(entry.getValue());
+            }
+        } else if (type == Set.class) {
+            @SuppressWarnings("unchecked") Set<Object> set = (Set<Object>) value;
+            for (Object v : set) {
+                checkWriteable(v);
+            }
+        } else if (WRITERS.containsKey(type) == false) {
+            throw new IllegalArgumentException("Cannot write type [" + type.getCanonicalName() + "] to stream");
+        }
+    }
+
     public void writeIntArray(int[] values) throws IOException {
         writeVInt(values.length);
         for (int value : values) {

+ 2 - 0
server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregator.java

@@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.metrics;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreMode;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.script.Script;
 import org.elasticsearch.script.ScriptedMetricAggContexts;
@@ -90,6 +91,7 @@ class ScriptedMetricAggregator extends MetricsAggregator {
         } else {
             aggregation = aggState;
         }
+        StreamOutput.checkWriteable(aggregation);
         return new InternalScriptedMetric(name, aggregation, reduceScript, metadata());
     }
 

+ 31 - 0
server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java

@@ -445,6 +445,37 @@ public class StreamTests extends ESTestCase {
         assertGenericRoundtrip(new LinkedHashSet<>(list));
     }
 
+    private static class Unwriteable {}
+
+    private void assertNotWriteable(Object o, Class<?> type) {
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> StreamOutput.checkWriteable(o));
+        assertThat(e.getMessage(), equalTo("Cannot write type [" + type.getCanonicalName() + "] to stream"));
+    }
+
+    public void testIsWriteable() throws IOException {
+        assertNotWriteable(new Unwriteable(), Unwriteable.class);
+    }
+
+    public void testSetIsWriteable() throws IOException {
+        StreamOutput.checkWriteable(Set.of("a", "b"));
+        assertNotWriteable(Set.of(new Unwriteable()), Unwriteable.class);
+    }
+
+    public void testListIsWriteable() throws IOException {
+        StreamOutput.checkWriteable(List.of("a", "b"));
+        assertNotWriteable(List.of(new Unwriteable()), Unwriteable.class);
+    }
+
+    public void testMapIsWriteable() throws IOException {
+        StreamOutput.checkWriteable(Map.of("a", "b", "c", "d"));
+        assertNotWriteable(Map.of("a", new Unwriteable()), Unwriteable.class);
+    }
+
+    public void testObjectArrayIsWriteable() throws IOException {
+        StreamOutput.checkWriteable(new Object[] {"a", "b"});
+        assertNotWriteable(new Object[] {new Unwriteable()}, Unwriteable.class);
+    }
+
     private void assertSerialization(CheckedConsumer<StreamOutput, IOException> outputAssertions,
                                      CheckedConsumer<StreamInput, IOException> inputAssertions) throws IOException {
         try (BytesStreamOutput output = new BytesStreamOutput()) {