瀏覽代碼

Allow any map key type when serialising hash maps (#88686)

Before this change only maps with String keys were supported.
There is no reason why we should not support non-String keys
too, provided that they are serialisable. With this PR we include
support for other map key types.
Salvatore Campagna 2 年之前
父節點
當前提交
8f8f55efc3

+ 6 - 0
docs/changelog/88686.yaml

@@ -0,0 +1,6 @@
+pr: 88686
+summary: "Fix: do not allow map key types other than String"
+area: Aggregations
+type: bug
+issues:
+ - 66057

+ 51 - 0
modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/130_metric_agg.yml

@@ -9,6 +9,8 @@ setup:
             properties:
               double:
                 type: double
+              keyword:
+                type: keyword
 
   - do:
       cluster.health:
@@ -20,6 +22,7 @@ setup:
         id: "1"
         body:
           double: 1.0
+          keyword: "first"
 
   - do:
       index:
@@ -27,6 +30,7 @@ setup:
         id: "2"
         body:
           double: 1.0
+          keyword: "second"
 
   - do:
       index:
@@ -34,6 +38,7 @@ setup:
         id: "3"
         body:
           double: 2.0
+          keyword: "third"
 
   - do:
       indices.refresh: {}
@@ -61,3 +66,49 @@ setup:
   - match: { hits.total: 3 }
   - match: { aggregations.total.value: 4.0 }
 
+---
+"Scripted Metric Agg Non String Map Key":
+
+  - do:
+      search:
+        rest_total_hits_as_int: true
+        body: {
+          "size": 0,
+          "aggs": {
+            "total": {
+              "scripted_metric": {
+                "init_script": "state.transactions = [:]",
+                "map_script": "state.transactions[doc['double'].value] = doc['double'].value",
+                "combine_script": "return state.transactions",
+                "reduce_script": "double sum = 0; for (transactions in states) { for (entry in transactions.entrySet()) { sum += entry.getKey() - entry.getValue() } } return sum"
+              }
+            }
+          }
+        }
+
+  - match: { hits.total: 3 }
+  - match: { aggregations.total.value: 0.0 }
+
+---
+"Scripted Metric Agg String Map Key":
+
+  - do:
+      search:
+        rest_total_hits_as_int: true
+        body: {
+          "size": 0,
+          "aggs": {
+            "total": {
+              "scripted_metric": {
+                "init_script": "state.transactions = [:]",
+                "map_script": "state.transactions[doc['keyword'].value] = doc['double'].value",
+                "combine_script": "return state.transactions",
+                "reduce_script": "double sum = 0; for (transactions in states) { for (entry in transactions.entrySet()) { sum += entry.getValue() } } return sum"
+              }
+            }
+          }
+        }
+
+  - match: { hits.total: 3 }
+  - match: { aggregations.total.value: 4.0 }
+

+ 6 - 26
server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

@@ -738,8 +738,12 @@ public abstract class StreamInput extends InputStream {
             case 6 -> readByteArray();
             case 7 -> readArrayList();
             case 8 -> readArray();
-            case 9 -> readLinkedHashMap();
-            case 10 -> readHashMap();
+            case 9 -> getVersion().onOrAfter(Version.V_8_7_0)
+                ? readOrderedMap(StreamInput::readGenericValue, StreamInput::readGenericValue)
+                : readOrderedMap(StreamInput::readString, StreamInput::readGenericValue);
+            case 10 -> getVersion().onOrAfter(Version.V_8_7_0)
+                ? readMap(StreamInput::readGenericValue, StreamInput::readGenericValue)
+                : readMap(StreamInput::readString, StreamInput::readGenericValue);
             case 11 -> readByte();
             case 12 -> readDate();
             case 13 ->
@@ -817,30 +821,6 @@ public abstract class StreamInput extends InputStream {
         return list8;
     }
 
-    private Map<String, Object> readLinkedHashMap() throws IOException {
-        int size9 = readArraySize();
-        if (size9 == 0) {
-            return Collections.emptyMap();
-        }
-        Map<String, Object> map9 = Maps.newLinkedHashMapWithExpectedSize(size9);
-        for (int i = 0; i < size9; i++) {
-            map9.put(readString(), readGenericValue());
-        }
-        return map9;
-    }
-
-    private Map<String, Object> readHashMap() throws IOException {
-        int size10 = readArraySize();
-        if (size10 == 0) {
-            return Collections.emptyMap();
-        }
-        Map<String, Object> map10 = Maps.newMapWithExpectedSize(size10);
-        for (int i = 0; i < size10; i++) {
-            map10.put(readString(), readGenericValue());
-        }
-        return map10;
-    }
-
     private Date readDate() throws IOException {
         return new Date(readLong());
     }

+ 12 - 7
server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java

@@ -561,7 +561,11 @@ public abstract class StreamOutput extends OutputStream {
             .iterator();
         while (iterator.hasNext()) {
             Map.Entry<String, ?> next = iterator.next();
-            this.writeString(next.getKey());
+            if (this.getVersion().onOrAfter(Version.V_8_7_0)) {
+                this.writeGenericValue(next.getKey());
+            } else {
+                this.writeString(next.getKey());
+            }
             this.writeGenericValue(next.getValue());
         }
     }
@@ -688,12 +692,13 @@ public abstract class StreamOutput extends OutputStream {
             } else {
                 o.writeByte((byte) 10);
             }
-            @SuppressWarnings("unchecked")
-            final Map<String, Object> map = (Map<String, Object>) v;
-            o.writeVInt(map.size());
-            for (Map.Entry<String, Object> entry : map.entrySet()) {
-                o.writeString(entry.getKey());
-                o.writeGenericValue(entry.getValue());
+            if (o.getVersion().onOrAfter(Version.V_8_7_0)) {
+                final Map<?, ?> map = (Map<?, ?>) v;
+                o.writeMap(map, StreamOutput::writeGenericValue, StreamOutput::writeGenericValue);
+            } else {
+                @SuppressWarnings("unchecked")
+                final Map<String, ?> map = (Map<String, ?>) v;
+                o.writeMap(map, StreamOutput::writeString, StreamOutput::writeGenericValue);
             }
         }),
         entry(Byte.class, (o, v) -> {