Browse Source

Fix scripted metric in ccs (#54776)

`scripted_metric` did not work with cross cluster search because it
assumed that you'd never perform a partial reduction, serialize the
results, and then perform a final reduction. That
serialized-after-partial-reduction step was broken.

This is also required to support #54758.
Nik Everett 5 years ago
parent
commit
171464626c

+ 38 - 0
qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/10_basic.yml

@@ -196,6 +196,44 @@
   - match: { aggregations.cluster.buckets.1.animal.buckets.1.s.value: 0 }
   - match: { aggregations.cluster.buckets.1.average_sum.value: 1 }
 
+  # scripted_metric
+  - do:
+      search:
+        index: test_index,my_remote_cluster:test_index
+        body:
+          seq_no_primary_term: true
+          aggs:
+            cluster:
+              terms:
+                field: f1.keyword
+              aggs:
+                animal_length:
+                  scripted_metric:
+                    init_script: |
+                      state.sum = 0
+                    map_script: |
+                      state.sum += doc['animal.keyword'].value.length()
+                    combine_script: |
+                      state.sum
+                    reduce_script: |
+                      long sum = 0;
+                      for (s in states) {
+                        sum += s;
+                      }
+                      return sum
+  - match: { num_reduce_phases: 3 }
+  - match: {_clusters.total: 2}
+  - match: {_clusters.successful: 2}
+  - match: {_clusters.skipped: 0}
+  - match: { _shards.total: 5 }
+  - match: { hits.total.value: 11 }
+  - length: { aggregations.cluster.buckets: 2 }
+  - match: { aggregations.cluster.buckets.0.key: "remote_cluster" }
+  - match: { aggregations.cluster.buckets.0.doc_count: 6 }
+  - match: { aggregations.cluster.buckets.0.animal_length.value: 34 }
+  - match: { aggregations.cluster.buckets.1.key: "local_cluster" }
+  - match: { aggregations.cluster.buckets.1.animal_length.value: 15 }
+
 ---
 "Add transient remote cluster based on the preset cluster":
   - do:

+ 32 - 13
server/src/main/java/org/elasticsearch/search/aggregations/metrics/InternalScriptedMetric.java

@@ -19,12 +19,13 @@
 
 package org.elasticsearch.search.aggregations.metrics;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.script.ScriptedMetricAggContexts;
 import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptedMetricAggContexts;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 
@@ -36,19 +37,21 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
+import static java.util.Collections.singletonList;
+
 public class InternalScriptedMetric extends InternalAggregation implements ScriptedMetric {
     final Script reduceScript;
-    private final List<Object> aggregation;
+    private final List<Object> aggregations;
 
     InternalScriptedMetric(String name, Object aggregation, Script reduceScript, List<PipelineAggregator> pipelineAggregators,
                                   Map<String, Object> metadata) {
         this(name, Collections.singletonList(aggregation), reduceScript, pipelineAggregators, metadata);
     }
 
-    private InternalScriptedMetric(String name, List<Object> aggregation, Script reduceScript, List<PipelineAggregator> pipelineAggregators,
-            Map<String, Object> metadata) {
+    private InternalScriptedMetric(String name, List<Object> aggregations, Script reduceScript,
+                                   List<PipelineAggregator> pipelineAggregators, Map<String, Object> metadata) {
         super(name, pipelineAggregators, metadata);
-        this.aggregation = aggregation;
+        this.aggregations = aggregations;
         this.reduceScript = reduceScript;
     }
 
@@ -58,13 +61,29 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
     public InternalScriptedMetric(StreamInput in) throws IOException {
         super(in);
         reduceScript = in.readOptionalWriteable(Script::new);
-        aggregation = Collections.singletonList(in.readGenericValue());
+        if (in.getVersion().before(Version.V_7_8_0)) {
+            aggregations = singletonList(in.readGenericValue());
+        } else {
+            aggregations = in.readList(StreamInput::readGenericValue);
+        }
     }
 
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
         out.writeOptionalWriteable(reduceScript);
-        out.writeGenericValue(aggregation());
+        if (out.getVersion().before(Version.V_7_8_0)) {
+            if (aggregations.size() > 1) {
+                /*
+                 * More than one entry means the result was only partially
+                 * reduced. I *believe* that can only happen in cross cluster
+                 * search right now. Thus the message. But computers are hard.
+                 */
+                throw new IllegalArgumentException("scripted_metric doesn't support cross cluster search until 7.8.0");
+            }
+            out.writeGenericValue(aggregations.get(0));
+        } else {
+            out.writeCollection(aggregations, StreamOutput::writeGenericValue);
+        }
     }
 
     @Override
@@ -74,14 +93,14 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
 
     @Override
     public Object aggregation() {
-        if (aggregation.size() != 1) {
+        if (aggregations.size() != 1) {
             throw new IllegalStateException("aggregation was not reduced");
         }
-        return aggregation.get(0);
+        return aggregations.get(0);
     }
 
     List<Object> getAggregation() {
-        return aggregation;
+        return aggregations;
     }
 
     @Override
@@ -89,7 +108,7 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
         List<Object> aggregationObjects = new ArrayList<>();
         for (InternalAggregation aggregation : aggregations) {
             InternalScriptedMetric mapReduceAggregation = (InternalScriptedMetric) aggregation;
-            aggregationObjects.addAll(mapReduceAggregation.aggregation);
+            aggregationObjects.addAll(mapReduceAggregation.aggregations);
         }
         InternalScriptedMetric firstAggregation = ((InternalScriptedMetric) aggregations.get(0));
         List<Object> aggregation;
@@ -142,12 +161,12 @@ public class InternalScriptedMetric extends InternalAggregation implements Scrip
 
         InternalScriptedMetric other = (InternalScriptedMetric) obj;
         return Objects.equals(reduceScript, other.reduceScript) &&
-                Objects.equals(aggregation, other.aggregation);
+                Objects.equals(aggregations, other.aggregations);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), reduceScript, aggregation);
+        return Objects.hash(super.hashCode(), reduceScript, aggregations);
     }
 
 }

+ 1 - 1
server/src/test/java/org/elasticsearch/search/aggregations/metrics/InternalScriptedMetricTests.java

@@ -132,7 +132,7 @@ public class InternalScriptedMetricTests extends InternalAggregationTestCase<Int
         if (hasReduceScript) {
             assertEquals(inputs.size(), reduced.aggregation());
         } else {
-            assertEquals(inputs.size(), ((List<Object>) reduced.aggregation()).size());
+            assertEquals(inputs.size(), ((List<?>) reduced.aggregation()).size());
         }
     }
 

+ 10 - 2
test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java

@@ -281,7 +281,7 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
         return createTestInstance(name, metadata);
     }
 
-    public void testReduceRandom() {
+    public void testReduceRandom() throws IOException {
         String name = randomAlphaOfLength(5);
         List<T> inputs = new ArrayList<>();
         List<InternalAggregation> toReduce = new ArrayList<>();
@@ -296,7 +296,7 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
         ScriptService mockScriptService = mockScriptService();
         MockBigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
         if (randomBoolean() && toReduce.size() > 1) {
-            // sometimes do an incremental reduce
+            // sometimes do a partial reduce
             Collections.shuffle(toReduce, random());
             int r = randomIntBetween(1, toReduceSize);
             List<InternalAggregation> internalAggregations = toReduce.subList(0, r);
@@ -311,6 +311,14 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
             int reducedBucketCount = countInnerBucket(reduced);
             //check that non final reduction never adds buckets
             assertThat(reducedBucketCount, lessThanOrEqualTo(initialBucketCount));
+            /*
+             * Sometimes serialize and deserialize the partially reduced
+             * result to simulate the compaction that we attempt after a
+             * partial reduce, and to simulate cross cluster search.
+             */
+            if (randomBoolean()) {
+                reduced = copyInstance(reduced);
+            }
             toReduce = new ArrayList<>(toReduce.subList(r, toReduceSize));
             toReduce.add(reduced);
         }

+ 10 - 2
x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/stringstats/InternalStringStatsTests.java

@@ -42,8 +42,16 @@ public class InternalStringStatsTests extends InternalAggregationTestCase<Intern
         if (randomBoolean()) {
             return new InternalStringStats(name, 0, 0, 0, 0, emptyMap(), randomBoolean(), DocValueFormat.RAW, emptyList(), metadata);
         }
-        return new InternalStringStats(name, randomLongBetween(1, Long.MAX_VALUE),
-                randomNonNegativeLong(), between(0, Integer.MAX_VALUE), between(0, Integer.MAX_VALUE), randomCharOccurrences(),
+        /*
+         * Pick random count and length that are *much* less than
+         * Long.MAX_VALUE because reduction adds them together and sometimes
+         * serializes them and that serialization would fail if the sum has
+         * wrapped to a negative number.
+         */
+        long count = randomLongBetween(1, Integer.MAX_VALUE);
+        long totalLength = randomLongBetween(0, count * 10);
+        return new InternalStringStats(name, count, totalLength,
+                between(0, Integer.MAX_VALUE), between(0, Integer.MAX_VALUE), randomCharOccurrences(),
                 randomBoolean(), DocValueFormat.RAW,
                 emptyList(), metadata);
     };