Browse Source

Allow pipeline aggs to select specific buckets from multi-bucket aggs (#44179)

This adjusts the `buckets_path` parser so that pipeline aggs can
select specific buckets (via their bucket keys) instead of fetching
the entire set of buckets.  This is useful for bucket_script in
particular, which might want specific buckets for calculations.

It's possible to workaround this with `filter` aggs, but the workaround
is hacky and probably less performant.

- Adjusts documentation
- Adds a barebones AggregatorTestCase for bucket_script
- Tweaks AggTestCase to use getMockScriptService() for reductions and
pipelines.  Previously pipelines could just pass in a script service
for testing, but this didnt work for regular aggs.  The new
getMockScriptService() method fixes that issue, but needs to be used
for pipelines too.  This had a knock-on effect of touching MovFn,
AvgBucket and ScriptedMetric
Zachary Tong 6 years ago
parent
commit
ae7c071ec7

+ 50 - 3
docs/reference/aggregations/pipeline.asciidoc

@@ -35,11 +35,12 @@ parameter, which follows a specific format:
 // https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_Form
 [source,ebnf]
 --------------------------------------------------
-AGG_SEPARATOR       =  '>' ;
-METRIC_SEPARATOR    =  '.' ;
+AGG_SEPARATOR       =  `>` ;
+METRIC_SEPARATOR    =  `.` ;
 AGG_NAME            =  <the name of the aggregation> ;
 METRIC              =  <the name of the metric (in case of multi-value metrics aggregation)> ;
-PATH                =  <AGG_NAME> [ <AGG_SEPARATOR>, <AGG_NAME> ]* [ <METRIC_SEPARATOR>, <METRIC> ] ;
+MULTIBUCKET_KEY     =  `[<KEY_NAME>]`
+PATH                =  <AGG_NAME><MULTIBUCKET_KEY>? (<AGG_SEPARATOR>, <AGG_NAME> )* ( <METRIC_SEPARATOR>, <METRIC> ) ;
 --------------------------------------------------
 
 For example, the path `"my_bucket>my_stats.avg"` will path to the `avg` value in the `"my_stats"` metric, which is
@@ -110,6 +111,52 @@ POST /_search
 <1> `buckets_path` instructs this max_bucket aggregation that we want the maximum value of the `sales` aggregation in the
 `sales_per_month` date histogram.
 
+If a Sibling pipeline agg references a multi-bucket aggregation, such as a `terms` agg, it also has the option to
+select specific keys from the multi-bucket.  For example, a `bucket_script` could select two specific buckets (via
+their bucket keys) to perform the calculation:
+
+[source,js]
+--------------------------------------------------
+POST /_search
+{
+    "aggs" : {
+        "sales_per_month" : {
+            "date_histogram" : {
+                "field" : "date",
+                "calendar_interval" : "month"
+            },
+            "aggs": {
+                "sale_type": {
+                    "terms": {
+                        "field": "type"
+                    },
+                    "aggs": {
+                        "sales": {
+                            "sum": {
+                                "field": "price"
+                            }
+                        }
+                    }
+                },
+                "hat_vs_bag_ratio": {
+                    "bucket_script": {
+                        "buckets_path": {
+                            "hats": "sale_type['hat']>sales", <1>
+                            "bags": "sale_type['bag']>sales"  <1>
+                        },
+                        "script": "params.hats / params.bags"
+                    }
+                }
+            }
+        }
+    }
+}
+--------------------------------------------------
+// CONSOLE
+// TEST[setup:sales]
+<1> `buckets_path` selects the hats and bags buckets (via `['hat']`/`['bag']``) to use in the script specifically,
+instead of fetching all the buckets from `sale_type` aggregation
+
 [float]
 === Special Paths
 

+ 38 - 0
modules/lang-painless/src/test/resources/rest-api-spec/test/painless/100_terms_agg.yml

@@ -102,3 +102,41 @@ setup:
   - is_false: aggregations.double_terms.buckets.1.key_as_string
   - match: { aggregations.double_terms.buckets.1.doc_count: 1 }
 
+---
+"Bucket script with keys":
+
+  - do:
+      search:
+        rest_total_hits_as_int: true
+        body:
+          size: 0
+          aggs:
+            placeholder:
+              filters:
+                filters:
+                 - match_all: {}
+              aggs:
+                str_terms:
+                  terms:
+                    field: "str"
+                  aggs:
+                    the_avg:
+                      avg:
+                        field: "number"
+                the_bucket_script:
+                  bucket_script:
+                    buckets_path:
+                      foo: "str_terms['bcd']>the_avg.value"
+                    script: "params.foo"
+
+  - match: { hits.total: 3 }
+
+  - length: { aggregations.placeholder.buckets.0.str_terms.buckets: 2 }
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.key: "abc" }
+  - is_false: aggregations.placeholder.buckets.0.str_terms.buckets.0.key_as_string
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.0.doc_count: 2 }
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.key: "bcd" }
+  - is_false: aggregations.placeholder.buckets.0.str_terms.buckets.1.key_as_string
+  - match: { aggregations.placeholder.buckets.0.str_terms.buckets.1.doc_count: 1 }
+  - match: { aggregations.placeholder.buckets.0.the_bucket_script.value: 2.0 }
+

+ 25 - 8
server/src/main/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregation.java

@@ -73,16 +73,33 @@ public abstract class InternalMultiBucketAggregation<A extends InternalMultiBuck
     public Object getProperty(List<String> path) {
         if (path.isEmpty()) {
             return this;
-        } else if (path.get(0).equals("_bucket_count")) {
-            return getBuckets().size();
-        } else {
-            List<? extends InternalBucket> buckets = getBuckets();
-            Object[] propertyArray = new Object[buckets.size()];
-            for (int i = 0; i < buckets.size(); i++) {
-                propertyArray[i] = buckets.get(i).getProperty(getName(), path);
+        }
+        return resolvePropertyFromPath(path, getBuckets(), getName());
+    }
+
+    static Object resolvePropertyFromPath(List<String> path, List<? extends InternalBucket> buckets, String name) {
+        String aggName = path.get(0);
+        if (aggName.equals("_bucket_count")) {
+            return buckets.size();
+        }
+
+        // This is a bucket key, look through our buckets and see if we can find a match
+        if (aggName.startsWith("'") && aggName.endsWith("'")) {
+            for (InternalBucket bucket : buckets) {
+                if (bucket.getKeyAsString().equals(aggName.substring(1, aggName.length() - 1))) {
+                    return bucket.getProperty(name, path.subList(1, path.size()));
+                }
             }
-            return propertyArray;
+            // No key match, time to give up
+            throw new InvalidAggregationPathException("Cannot find an key [" + aggName + "] in [" + name + "]");
+        }
+
+        Object[] propertyArray = new Object[buckets.size()];
+        for (int i = 0; i < buckets.size(); i++) {
+            propertyArray[i] = buckets.get(i).getProperty(name, path);
         }
+        return propertyArray;
+
     }
 
     /**

+ 183 - 0
server/src/test/java/org/elasticsearch/search/aggregations/InternalMultiBucketAggregationTests.java

@@ -0,0 +1,183 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.search.DocValueFormat;
+import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms;
+import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
+import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
+import org.elasticsearch.search.aggregations.metrics.InternalAvg;
+import org.elasticsearch.search.aggregations.support.AggregationPath;
+import org.elasticsearch.test.ESTestCase;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.search.aggregations.InternalMultiBucketAggregation.resolvePropertyFromPath;
+import static org.hamcrest.Matchers.equalTo;
+
+public class InternalMultiBucketAggregationTests extends ESTestCase {
+
+    public void testResolveToAgg() {
+        AggregationPath path = AggregationPath.parse("the_avg");
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
+
+        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
+        assertThat(value[0], equalTo(agg));
+    }
+
+    public void testResolveToAggValue() {
+        AggregationPath path = AggregationPath.parse("the_avg.value");
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
+
+        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
+        assertThat(value[0], equalTo(2.0));
+    }
+
+    public void testResolveToNothing() {
+        AggregationPath path = AggregationPath.parse("foo.value");
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
+
+        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
+            () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
+        assertThat(e.getMessage(), equalTo("Cannot find an aggregation named [foo] in [the_long_terms]"));
+    }
+
+    public void testResolveToUnknown() {
+        AggregationPath path = AggregationPath.parse("the_avg.unknown");
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
+
+        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
+            () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
+        assertThat(e.getMessage(), equalTo("path not supported for [the_avg]: [unknown]"));
+    }
+
+    public void testResolveToBucketCount() {
+        AggregationPath path = AggregationPath.parse("_bucket_count");
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
+
+        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        Object value = resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
+        assertThat(value, equalTo(1));
+    }
+
+    public void testResolveToCount() {
+        AggregationPath path = AggregationPath.parse("_count");
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
+
+        LongTerms.Bucket bucket = new LongTerms.Bucket(1, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
+        assertThat(value[0], equalTo(1L));
+    }
+
+    public void testResolveToKey() {
+        AggregationPath path = AggregationPath.parse("_key");
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(agg));
+
+        LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
+        assertThat(value[0], equalTo(19L));
+    }
+
+    public void testResolveToSpecificBucket() {
+        AggregationPath path = AggregationPath.parse("string_terms['foo']>the_avg.value");
+
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg));
+        List<StringTerms.Bucket> stringBuckets = Collections.singletonList(new StringTerms.Bucket(
+            new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1,
+            internalStringAggs, false, 0, DocValueFormat.RAW));
+
+        InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(),
+            Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0);
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg));
+        LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        Object[] value = (Object[]) resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms");
+        assertThat(value[0], equalTo(2.0));
+    }
+
+    public void testResolveToMissingSpecificBucket() {
+        AggregationPath path = AggregationPath.parse("string_terms['bar']>the_avg.value");
+
+        List<LongTerms.Bucket> buckets = new ArrayList<>();
+        InternalAggregation agg = new InternalAvg("the_avg", 2, 1,
+            DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap());
+        InternalAggregations internalStringAggs = new InternalAggregations(Collections.singletonList(agg));
+        List<StringTerms.Bucket> stringBuckets = Collections.singletonList(new StringTerms.Bucket(
+            new BytesRef("foo".getBytes(StandardCharsets.UTF_8), 0, "foo".getBytes(StandardCharsets.UTF_8).length), 1,
+            internalStringAggs, false, 0, DocValueFormat.RAW));
+
+        InternalTerms termsAgg = new StringTerms("string_terms", BucketOrder.count(false), 1, 0, Collections.emptyList(),
+            Collections.emptyMap(), DocValueFormat.RAW, 1, false, 0, stringBuckets, 0);
+        InternalAggregations internalAggregations = new InternalAggregations(Collections.singletonList(termsAgg));
+        LongTerms.Bucket bucket = new LongTerms.Bucket(19, 1, internalAggregations, false, 0, DocValueFormat.RAW);
+        buckets.add(bucket);
+
+        InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
+            () -> resolvePropertyFromPath(path.getPathElementsAsStringList(), buckets, "the_long_terms"));
+        assertThat(e.getMessage(), equalTo("Cannot find an key ['bar'] in [string_terms]"));
+    }
+}

+ 12 - 1
server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java

@@ -173,6 +173,17 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
         });
     }
 
+    @Override
+    protected ScriptService getMockScriptService() {
+        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
+            SCRIPTS,
+            Collections.emptyMap());
+        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
+
+        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
+    }
+
+
     @SuppressWarnings("unchecked")
     public void testNoDocs() throws IOException {
         try (Directory directory = newDirectory()) {
@@ -311,7 +322,7 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
                         .initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS)
                         .combineScript(COMBINE_SCRIPT_PARAMS).reduceScript(REDUCE_SCRIPT_PARAMS);
                 ScriptedMetric scriptedMetric = searchAndReduce(
-                        newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0, scriptService);
+                        newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0);
 
                 // The result value depends on the script params.
                 assertEquals(4803, scriptedMetric.aggregation());

+ 2 - 2
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/AvgBucketAggregatorTests.java

@@ -120,9 +120,9 @@ public class AvgBucketAggregatorTests extends AggregatorTestCase {
                 valueFieldType.setName(VALUE_FIELD);
                 valueFieldType.setHasDocValues(true);
 
-                avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000, null,
+                avgResult = searchAndReduce(indexSearcher, query, avgBuilder, 10000,
                     new MappedFieldType[]{fieldType, valueFieldType});
-                histogramResult = searchAndReduce(indexSearcher, query, histo, 10000, null,
+                histogramResult = searchAndReduce(indexSearcher, query, histo, 10000,
                     new MappedFieldType[]{fieldType, valueFieldType});
             }
 

+ 122 - 0
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptAggregatorTests.java

@@ -0,0 +1,122 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.pipeline;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.document.SortedSetDocValuesField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.CheckedConsumer;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.mapper.KeywordFieldMapper;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.script.MockScriptEngine;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptEngine;
+import org.elasticsearch.script.ScriptModule;
+import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.script.ScriptType;
+import org.elasticsearch.search.aggregations.AggregatorTestCase;
+import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.filter.InternalFilters;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
+import org.elasticsearch.search.aggregations.support.ValueType;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+
+public class BucketScriptAggregatorTests extends AggregatorTestCase {
+    private final String SCRIPT_NAME = "script_name";
+
+    @Override
+    protected ScriptService getMockScriptService() {
+        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
+            Collections.singletonMap(SCRIPT_NAME, script -> script.get("the_avg")),
+            Collections.emptyMap());
+        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
+
+        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
+    }
+
+    public void testScript() throws IOException {
+        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
+        fieldType.setName("number_field");
+        fieldType.setHasDocValues(true);
+        MappedFieldType fieldType1 = new KeywordFieldMapper.KeywordFieldType();
+        fieldType1.setName("the_field");
+        fieldType1.setHasDocValues(true);
+
+        FiltersAggregationBuilder filters = new FiltersAggregationBuilder("placeholder", new MatchAllQueryBuilder())
+            .subAggregation(new TermsAggregationBuilder("the_terms", ValueType.STRING).field("the_field")
+                .subAggregation(new AvgAggregationBuilder("the_avg").field("number_field")))
+            .subAggregation(new BucketScriptPipelineAggregationBuilder("bucket_script",
+                Collections.singletonMap("the_avg", "the_terms['test1']>the_avg.value"),
+                new Script(ScriptType.INLINE, MockScriptEngine.NAME, SCRIPT_NAME, Collections.emptyMap())));
+
+
+        testCase(filters, new MatchAllDocsQuery(), iw -> {
+            Document doc = new Document();
+            doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test1")));
+            doc.add(new SortedNumericDocValuesField("number_field", 19));
+            iw.addDocument(doc);
+
+            doc = new Document();
+            doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test2")));
+            doc.add(new SortedNumericDocValuesField("number_field", 55));
+            iw.addDocument(doc);
+        }, f -> {
+           assertThat(((InternalSimpleValue)(f.getBuckets().get(0).getAggregations().get("bucket_script"))).value,
+               equalTo(19.0));
+        }, fieldType, fieldType1);
+    }
+
+    private void testCase(FiltersAggregationBuilder aggregationBuilder, Query query,
+                          CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
+                          Consumer<InternalFilters> verify, MappedFieldType... fieldType) throws IOException {
+
+        try (Directory directory = newDirectory()) {
+            RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
+            buildIndex.accept(indexWriter);
+            indexWriter.close();
+
+            try (IndexReader indexReader = DirectoryReader.open(directory)) {
+                IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
+
+                InternalFilters filters;
+                filters = searchAndReduce(indexSearcher, query, aggregationBuilder, fieldType);
+                verify.accept(filters);
+            }
+        }
+    }
+}

+ 31 - 40
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java

@@ -30,12 +30,17 @@ import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.store.Directory;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.time.DateFormatters;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.script.MockScriptEngine;
 import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptEngine;
+import org.elasticsearch.script.ScriptModule;
 import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.script.ScriptType;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
 import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.TestAggregatorFactory;
@@ -56,8 +61,6 @@ import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.equalTo;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 public class MovFnUnitTests extends AggregatorTestCase {
 
@@ -79,31 +82,35 @@ public class MovFnUnitTests extends AggregatorTestCase {
 
     private static final List<Integer> datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10);
 
+    @Override
+    protected ScriptService getMockScriptService() {
+        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
+            Collections.singletonMap("test", script -> MovingFunctions.max((double[]) script.get("_values"))),
+            Collections.emptyMap());
+        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
+
+        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
+    }
+
     public void testMatchAllDocs() throws IOException {
-        check(0, List.of(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
+        check(0, 3, List.of(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
     }
 
     public void testShift() throws IOException {
-        check(1, List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
-        check(5, List.of(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
-        check(-5, List.of(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
+        check(1, 3, List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
+        check(5, 3, List.of(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
+        check(-5, 3, List.of(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
     }
 
     public void testWideWindow() throws IOException {
-        Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
-        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100);
-        builder.setShift(50);
-        check(builder, script, List.of(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
+        check(50, 100, List.of(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
     }
 
-    private void check(int shift, List<Double> expected) throws IOException {
-        Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
-        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3);
+    private void check(int shift, int window, List<Double> expected) throws IOException {
+        Script script = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "test", Collections.emptyMap());
+        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, window);
         builder.setShift(shift);
-        check(builder, script, expected);
-    }
 
-    private void check(MovFnPipelineAggregationBuilder builder, Script script, List<Double> expected) throws IOException {
         Query query = new MatchAllDocsQuery();
         DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo");
         aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD);
@@ -111,19 +118,17 @@ public class MovFnUnitTests extends AggregatorTestCase {
         aggBuilder.subAggregation(builder);
 
         executeTestCase(query, aggBuilder, histogram -> {
-                List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
-                List<Double> actual = buckets.stream()
-                    .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
-                    .collect(Collectors.toList());
-                assertThat(actual, equalTo(expected));
-            }, 1000, script);
+            List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
+            List<Double> actual = buckets.stream()
+                .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
+                .collect(Collectors.toList());
+            assertThat(actual, equalTo(expected));
+        });
     }
 
-
     private void executeTestCase(Query query,
                                  DateHistogramAggregationBuilder aggBuilder,
-                                 Consumer<Histogram> verify,
-                                 int maxBucket, Script script) throws IOException {
+                                 Consumer<Histogram> verify) throws IOException {
 
         try (Directory directory = newDirectory()) {
             try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
@@ -144,20 +149,6 @@ public class MovFnUnitTests extends AggregatorTestCase {
                 }
             }
 
-            ScriptService scriptService = mock(ScriptService.class);
-            MovingFunctionScript.Factory factory = mock(MovingFunctionScript.Factory.class);
-            when(scriptService.compile(script, MovingFunctionScript.CONTEXT)).thenReturn(factory);
-
-            MovingFunctionScript scriptInstance = new MovingFunctionScript() {
-                @Override
-                public double execute(Map<String, Object> params, double[] values) {
-                    assertNotNull(values);
-                    return MovingFunctions.max(values);
-                }
-            };
-
-            when(factory.newInstance()).thenReturn(scriptInstance);
-
             try (IndexReader indexReader = DirectoryReader.open(directory)) {
                 IndexSearcher indexSearcher = newSearcher(indexReader, true, true);
 
@@ -171,7 +162,7 @@ public class MovFnUnitTests extends AggregatorTestCase {
                 valueFieldType.setName("value_field");
 
                 InternalDateHistogram histogram;
-                histogram = searchAndReduce(indexSearcher, query, aggBuilder, maxBucket, scriptService,
+                histogram = searchAndReduce(indexSearcher, query, aggBuilder, 1000,
                     new MappedFieldType[]{fieldType, valueFieldType});
                 verify.accept(histogram);
             }

+ 7 - 13
test/framework/src/main/java/org/elasticsearch/script/MockScriptEngine.java

@@ -27,7 +27,6 @@ import org.elasticsearch.index.similarity.ScriptedSimilarity.Field;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Query;
 import org.elasticsearch.index.similarity.ScriptedSimilarity.Term;
 import org.elasticsearch.search.aggregations.pipeline.MovingFunctionScript;
-import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;
 import org.elasticsearch.search.lookup.LeafSearchLookup;
 import org.elasticsearch.search.lookup.SearchLookup;
 
@@ -271,7 +270,13 @@ public class MockScriptEngine implements ScriptEngine {
             SimilarityWeightScript.Factory factory = mockCompiled::createSimilarityWeightScript;
             return context.factoryClazz.cast(factory);
         } else if (context.instanceClazz.equals(MovingFunctionScript.class)) {
-            MovingFunctionScript.Factory factory = mockCompiled::createMovingFunctionScript;
+            MovingFunctionScript.Factory factory = () -> new MovingFunctionScript() {
+                @Override
+                public double execute(Map<String, Object> params1, double[] values) {
+                    params1.put("_values", values);
+                    return (double) script.apply(params1);
+                }
+            };
             return context.factoryClazz.cast(factory);
         } else if (context.instanceClazz.equals(ScoreScript.class)) {
             ScoreScript.Factory factory = new MockScoreScript(script);
@@ -335,10 +340,6 @@ public class MockScriptEngine implements ScriptEngine {
             return new MockSimilarityWeightScript(script != null ? script : ctx -> 42d);
         }
 
-        public MovingFunctionScript createMovingFunctionScript() {
-            return new MockMovingFunctionScript();
-        }
-
         public ScriptedMetricAggContexts.InitScript createMetricAggInitScript(Map<String, Object> params, Map<String, Object> state) {
             return new MockMetricAggInitScript(params, state, script != null ? script : ctx -> 42d);
         }
@@ -544,13 +545,6 @@ public class MockScriptEngine implements ScriptEngine {
         return new Script(ScriptType.INLINE, "mock", script, emptyMap());
     }
 
-    public class MockMovingFunctionScript extends MovingFunctionScript {
-        @Override
-        public double execute(Map<String, Object> params, double[] values) {
-            return MovingFunctions.unweightedAvg(values);
-        }
-    }
-
     public class MockScoreScript implements ScoreScript.Factory {
 
         private final Function<Map<String, Object>, Object> script;

+ 3 - 4
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -332,7 +332,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                                                       Query query,
                                                                                       AggregationBuilder builder,
                                                                                       MappedFieldType... fieldTypes) throws IOException {
-        return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, null, fieldTypes);
+        return searchAndReduce(searcher, query, builder, DEFAULT_MAX_BUCKETS, fieldTypes);
     }
 
     /**
@@ -344,7 +344,6 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                                                       Query query,
                                                                                       AggregationBuilder builder,
                                                                                       int maxBucket,
-                                                                                      ScriptService scriptService,
                                                                                       MappedFieldType... fieldTypes) throws IOException {
         final IndexReaderContext ctx = searcher.getTopReaderContext();
 
@@ -389,7 +388,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                 List<InternalAggregation> toReduce = aggs.subList(0, r);
                 MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket);
                 InternalAggregation.ReduceContext context =
-                    new InternalAggregation.ReduceContext(root.context().bigArrays(), null,
+                    new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(),
                         reduceBucketConsumer, false);
                 A reduced = (A) aggs.get(0).doReduce(toReduce, context);
                 doAssertReducedMultiBucketConsumer(reduced, reduceBucketConsumer);
@@ -399,7 +398,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
             // now do the final reduce
             MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer(maxBucket);
             InternalAggregation.ReduceContext context =
-                new InternalAggregation.ReduceContext(root.context().bigArrays(), scriptService, reduceBucketConsumer, true);
+                new InternalAggregation.ReduceContext(root.context().bigArrays(), getMockScriptService(), reduceBucketConsumer, true);
 
             @SuppressWarnings("unchecked")
             A internalAgg = (A) aggs.get(0).doReduce(aggs, context);