Ver código fonte

Support BucketScript paths of type string and array. (#44694)

Ignacio Vera 6 anos atrás
pai
commit
02ff060605

+ 26 - 1
server/src/main/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptPipelineAggregationBuilder.java

@@ -30,6 +30,7 @@ import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Locale;
 import java.util.Map;
@@ -59,7 +60,10 @@ public class BucketScriptPipelineAggregationBuilder extends AbstractPipelineAggr
             false,
             o -> new BucketScriptPipelineAggregationBuilder(name, (Map<String, String>) o[0], (Script) o[1]));
 
-        parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
+        parser.declareField(ConstructingObjectParser.constructorArg()
+            , BucketScriptPipelineAggregationBuilder::extractBucketPath
+            , BUCKETS_PATH_FIELD
+            , ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING);
         parser.declareField(ConstructingObjectParser.constructorArg(),
             (p, c) -> Script.parse(p), Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING);
 
@@ -112,6 +116,27 @@ public class BucketScriptPipelineAggregationBuilder extends AbstractPipelineAggr
         gapPolicy.writeTo(out);
     }
 
+    private static Map<String, String> extractBucketPath(XContentParser parser) throws IOException {
+        XContentParser.Token token = parser.currentToken();
+       if (token == XContentParser.Token.VALUE_STRING) {
+           // input is a string, name of the path set to '_value'.
+           // This is a bit odd as there is not constructor for it
+           return Collections.singletonMap("_value", parser.text());
+       } else if (token == XContentParser.Token.START_ARRAY) {
+           // input is an array, name of the path set to '_value' + position
+           Map<String, String> bucketsPathsMap = new HashMap<>();
+           int i =0;
+           while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) {
+               String path = parser.text();
+               bucketsPathsMap.put("_value" + i++, path);
+           }
+           return bucketsPathsMap;
+       } else  {
+           // input is an object, it should contain name / value pairs
+           return parser.mapStrings();
+       }
+    }
+
     private static Map<String, String> convertToBucketsPathMap(String[] bucketsPaths) {
         Map<String, String> bucketsPathsMap = new HashMap<>();
         for (int i = 0; i < bucketsPaths.length; i++) {

+ 161 - 0
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptIT.java

@@ -23,6 +23,7 @@ import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.script.MockScriptPlugin;
@@ -117,6 +118,11 @@ public class BucketScriptIT extends ESIntegTestCase {
                 return value0 + value1 + value2;
             });
 
+            scripts.put("single_input", vars -> {
+                double value = (double) vars.get("_value");
+                return value;
+            });
+
             scripts.put("return null", vars -> null);
 
             return scripts;
@@ -628,4 +634,159 @@ public class BucketScriptIT extends ESIntegTestCase {
             }
         }
     }
+
+    public void testSingleBucketPathAgg() throws Exception {
+        XContentBuilder content = XContentFactory.jsonBuilder()
+            .startObject()
+            .field("buckets_path", "field2Sum")
+            .startObject("script")
+            .field("source", "single_input")
+            .field("lang", CustomScriptPlugin.NAME)
+            .endObject()
+            .endObject();
+        BucketScriptPipelineAggregationBuilder bucketScriptAgg =
+            BucketScriptPipelineAggregationBuilder.parse("seriesArithmetic", createParser(content));
+
+        SearchResponse response = client()
+            .prepareSearch("idx", "idx_unmapped")
+            .addAggregation(
+                histogram("histo")
+                    .field(FIELD_1_NAME)
+                    .interval(interval)
+                    .subAggregation(sum("field2Sum").field(FIELD_2_NAME))
+                    .subAggregation(bucketScriptAgg)).get();
+
+        assertSearchResponse(response);
+
+        Histogram histo = response.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+
+        for (int i = 0; i < buckets.size(); ++i) {
+            Histogram.Bucket bucket = buckets.get(i);
+            if (bucket.getDocCount() == 0) {
+                SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
+                assertThat(seriesArithmetic, nullValue());
+            } else {
+                Sum field2Sum = bucket.getAggregations().get("field2Sum");
+                assertThat(field2Sum, notNullValue());
+                double field2SumValue = field2Sum.getValue();
+                SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
+                assertThat(seriesArithmetic, notNullValue());
+                double seriesArithmeticValue = seriesArithmetic.value();
+                assertThat(seriesArithmeticValue, equalTo(field2SumValue));
+            }
+        }
+    }
+
+    public void testArrayBucketPathAgg() throws Exception {
+        XContentBuilder content = XContentFactory.jsonBuilder()
+            .startObject()
+            .array("buckets_path", "field2Sum", "field3Sum", "field4Sum")
+            .startObject("script")
+            .field("source", "_value0 + _value1 + _value2")
+            .field("lang", CustomScriptPlugin.NAME)
+            .endObject()
+            .endObject();
+        BucketScriptPipelineAggregationBuilder bucketScriptAgg =
+            BucketScriptPipelineAggregationBuilder.parse("seriesArithmetic", createParser(content));
+
+        SearchResponse response = client()
+            .prepareSearch("idx", "idx_unmapped")
+            .addAggregation(
+                histogram("histo")
+                    .field(FIELD_1_NAME)
+                    .interval(interval)
+                    .subAggregation(sum("field2Sum").field(FIELD_2_NAME))
+                    .subAggregation(sum("field3Sum").field(FIELD_3_NAME))
+                    .subAggregation(sum("field4Sum").field(FIELD_4_NAME))
+                    .subAggregation(bucketScriptAgg)).get();
+
+        assertSearchResponse(response);
+
+        Histogram histo = response.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+
+        for (int i = 0; i < buckets.size(); ++i) {
+            Histogram.Bucket bucket = buckets.get(i);
+            if (bucket.getDocCount() == 0) {
+                SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
+                assertThat(seriesArithmetic, nullValue());
+            } else {
+                Sum field2Sum = bucket.getAggregations().get("field2Sum");
+                assertThat(field2Sum, notNullValue());
+                double field2SumValue = field2Sum.getValue();
+                Sum field3Sum = bucket.getAggregations().get("field3Sum");
+                assertThat(field3Sum, notNullValue());
+                double field3SumValue = field3Sum.getValue();
+                Sum field4Sum = bucket.getAggregations().get("field4Sum");
+                assertThat(field4Sum, notNullValue());
+                double field4SumValue = field4Sum.getValue();
+                SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
+                assertThat(seriesArithmetic, notNullValue());
+                double seriesArithmeticValue = seriesArithmetic.value();
+                assertThat(seriesArithmeticValue, equalTo(field2SumValue + field3SumValue + field4SumValue));
+            }
+        }
+    }
+
+    public void testObjectBucketPathAgg() throws Exception {
+        XContentBuilder content = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("buckets_path")
+               .field("_value0", "field2Sum")
+               .field("_value1", "field3Sum")
+               .field("_value2", "field4Sum")
+            .endObject()
+            .startObject("script")
+            .field("source", "_value0 + _value1 + _value2")
+            .field("lang", CustomScriptPlugin.NAME)
+            .endObject()
+            .endObject();
+        BucketScriptPipelineAggregationBuilder bucketScriptAgg =
+            BucketScriptPipelineAggregationBuilder.parse("seriesArithmetic", createParser(content));
+
+        SearchResponse response = client()
+            .prepareSearch("idx", "idx_unmapped")
+            .addAggregation(
+                histogram("histo")
+                    .field(FIELD_1_NAME)
+                    .interval(interval)
+                    .subAggregation(sum("field2Sum").field(FIELD_2_NAME))
+                    .subAggregation(sum("field3Sum").field(FIELD_3_NAME))
+                    .subAggregation(sum("field4Sum").field(FIELD_4_NAME))
+                    .subAggregation(bucketScriptAgg)).get();
+
+        assertSearchResponse(response);
+
+        Histogram histo = response.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Histogram.Bucket> buckets = histo.getBuckets();
+
+        for (int i = 0; i < buckets.size(); ++i) {
+            Histogram.Bucket bucket = buckets.get(i);
+            if (bucket.getDocCount() == 0) {
+                SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
+                assertThat(seriesArithmetic, nullValue());
+            } else {
+                Sum field2Sum = bucket.getAggregations().get("field2Sum");
+                assertThat(field2Sum, notNullValue());
+                double field2SumValue = field2Sum.getValue();
+                Sum field3Sum = bucket.getAggregations().get("field3Sum");
+                assertThat(field3Sum, notNullValue());
+                double field3SumValue = field3Sum.getValue();
+                Sum field4Sum = bucket.getAggregations().get("field4Sum");
+                assertThat(field4Sum, notNullValue());
+                double field4SumValue = field4Sum.getValue();
+                SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
+                assertThat(seriesArithmetic, notNullValue());
+                double seriesArithmeticValue = seriesArithmetic.value();
+                assertThat(seriesArithmeticValue, equalTo(field2SumValue + field3SumValue + field4SumValue));
+            }
+        }
+    }
 }

+ 46 - 0
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptTests.java

@@ -19,11 +19,14 @@
 
 package org.elasticsearch.search.aggregations.pipeline;
 
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.script.Script;
 import org.elasticsearch.script.ScriptType;
 import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
 import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
 
+import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -59,4 +62,47 @@ public class BucketScriptTests extends BasePipelineAggregationTestCase<BucketScr
         return factory;
     }
 
+    public void testParseBucketPath() throws IOException  {
+        XContentBuilder content = XContentFactory.jsonBuilder()
+            .startObject()
+              .field("buckets_path", "_count")
+              .startObject("script")
+                   .field("source", "value")
+                   .field("lang", "expression")
+              .endObject()
+            .endObject();
+        BucketScriptPipelineAggregationBuilder builder1 = BucketScriptPipelineAggregationBuilder.parse("count", createParser(content));
+        assertEquals(builder1.getBucketsPaths().length , 1);
+        assertEquals(builder1.getBucketsPaths()[0], "_count");
+
+        content = XContentFactory.jsonBuilder()
+            .startObject()
+              .startObject("buckets_path")
+                .field("path1", "_count1")
+                .field("path2", "_count2")
+              .endObject()
+              .startObject("script")
+                .field("source", "value")
+                .field("lang", "expression")
+              .endObject()
+            .endObject();
+        BucketScriptPipelineAggregationBuilder builder2 = BucketScriptPipelineAggregationBuilder.parse("count", createParser(content));
+        assertEquals(builder2.getBucketsPaths().length , 2);
+        assertEquals(builder2.getBucketsPaths()[0], "_count1");
+        assertEquals(builder2.getBucketsPaths()[1], "_count2");
+
+        content = XContentFactory.jsonBuilder()
+            .startObject()
+              .array("buckets_path","_count1", "_count2")
+              .startObject("script")
+                .field("source", "value")
+                .field("lang", "expression")
+               .endObject()
+            .endObject();
+        BucketScriptPipelineAggregationBuilder builder3 = BucketScriptPipelineAggregationBuilder.parse("count", createParser(content));
+        assertEquals(builder3.getBucketsPaths().length , 2);
+        assertEquals(builder3.getBucketsPaths()[0], "_count1");
+        assertEquals(builder3.getBucketsPaths()[1], "_count2");
+    }
+
 }