Browse Source

[ML] slight refactor of the bucket_correlation aggregation (#72924)

This PR moves the bucket_correlation aggregation to inherit from pipeline bucket metrics.

This allows the validation code to be reused from the parent object.
Benjamin Trent 4 years ago
parent
commit
a4e722e139

+ 8 - 2
server/src/main/java/org/elasticsearch/search/aggregations/pipeline/BucketMetricsPipelineAggregationBuilder.java

@@ -23,11 +23,17 @@ import java.util.Optional;
 public abstract class BucketMetricsPipelineAggregationBuilder<AF extends BucketMetricsPipelineAggregationBuilder<AF>>
         extends AbstractPipelineAggregationBuilder<AF> {
 
-    private String format = null;
-    private GapPolicy gapPolicy = GapPolicy.SKIP;
+    private String format;
+    private GapPolicy gapPolicy;
 
     protected BucketMetricsPipelineAggregationBuilder(String name, String type, String[] bucketsPaths) {
+        this(name, type, bucketsPaths, null, GapPolicy.SKIP);
+    }
+
+    protected BucketMetricsPipelineAggregationBuilder(String name, String type, String[] bucketsPaths, String format, GapPolicy gapPolicy) {
         super(name, type, bucketsPaths);
+        this.format = format;
+        this.gapPolicy = gapPolicy;
     }
 
     /**

+ 34 - 28
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilder.java

@@ -11,19 +11,23 @@ import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.plugins.SearchPlugin;
-import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
+import org.elasticsearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
-import java.util.Optional;
 
-public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggregationBuilder<BucketCorrelationAggregationBuilder> {
+import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.Parser.GAP_POLICY;
+
+public class BucketCorrelationAggregationBuilder extends BucketMetricsPipelineAggregationBuilder<BucketCorrelationAggregationBuilder> {
 
     public static final ParseField NAME = new ParseField("bucket_correlation");
     private static final ParseField FUNCTION = new ParseField("function");
@@ -35,7 +39,8 @@ public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggrega
         (args, context) -> new BucketCorrelationAggregationBuilder(
             context,
             (String)args[0],
-            (CorrelationFunction)args[1]
+            (CorrelationFunction)args[1],
+            (BucketHelpers.GapPolicy)args[2]
         )
     );
     static {
@@ -45,6 +50,12 @@ public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggrega
             (p, c, n) -> p.namedObject(CorrelationFunction.class, n, null),
             FUNCTION
         );
+        PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
+            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                return BucketHelpers.GapPolicy.parse(p.text().toLowerCase(Locale.ROOT), p.getTokenLocation());
+            }
+            throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
+        }, GAP_POLICY, ObjectParser.ValueType.STRING);
     }
 
     public static SearchPlugin.PipelineAggregationSpec buildSpec() {
@@ -57,17 +68,29 @@ public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggrega
 
     private final CorrelationFunction correlationFunction;
 
-    public BucketCorrelationAggregationBuilder(
+    public BucketCorrelationAggregationBuilder(String name, String bucketsPath, CorrelationFunction correlationFunction) {
+        this(name, bucketsPath, correlationFunction, BucketHelpers.GapPolicy.INSERT_ZEROS);
+    }
+
+    private BucketCorrelationAggregationBuilder(
         String name,
         String bucketsPath,
-        CorrelationFunction correlationFunction
+        CorrelationFunction correlationFunction,
+        BucketHelpers.GapPolicy gapPolicy
     ) {
         super(
             name,
             NAME.getPreferredName(),
-            new String[] {bucketsPath}
+            new String[] {bucketsPath},
+            null,
+            gapPolicy == null ? BucketHelpers.GapPolicy.INSERT_ZEROS : gapPolicy
         );
         this.correlationFunction = correlationFunction;
+        if (gapPolicy != null && gapPolicy.equals(BucketHelpers.GapPolicy.INSERT_ZEROS) == false) {
+            throw new IllegalArgumentException(
+                "only [gap_policy] of [" + BucketHelpers.GapPolicy.INSERT_ZEROS.getName() + "] is supported"
+            );
+        }
     }
 
     public BucketCorrelationAggregationBuilder(StreamInput in) throws IOException {
@@ -81,7 +104,7 @@ public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggrega
     }
 
     @Override
-    protected void doWriteTo(StreamOutput out) throws IOException {
+    protected void innerWriteTo(StreamOutput out) throws IOException {
         out.writeNamedWriteable(correlationFunction);
     }
 
@@ -96,7 +119,7 @@ public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggrega
     }
 
     @Override
-    protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
+    protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketsPaths[0]);
         NamedXContentObjectHelper.writeNamedObject(builder, params, FUNCTION.getPreferredName(), correlationFunction);
         return builder;
@@ -104,29 +127,12 @@ public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggrega
 
     @Override
     protected void validate(ValidationContext context) {
-
-        final String firstAgg = bucketsPaths[0].split("[>\\.]")[0];
-        Optional<AggregationBuilder> aggBuilder = context.getSiblingAggregations().stream()
-            .filter(builder -> builder.getName().equals(firstAgg))
-            .findAny();
-        if (aggBuilder.isEmpty()) {
-            context.addBucketPathValidationError("aggregation does not exist for aggregation [" + name + "]: " + bucketsPaths[0]);
-            return;
-        }
-        AggregationBuilder aggregationBuilder = aggBuilder.get();
-        if (aggregationBuilder.bucketCardinality() != AggregationBuilder.BucketCardinality.MANY) {
-            context.addValidationError("The first aggregation in " + PipelineAggregator.Parser.BUCKETS_PATH.getPreferredName()
-                + " must be a multi-bucket aggregation for aggregation [" + name + "] found :"
-                + aggBuilder.get().getClass().getName() + " for buckets path: " + bucketsPaths[0]);
-            return;
-        }
+        super.validate(context);
         correlationFunction.validate(context, bucketsPaths[0]);
     }
 
     @Override
     public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
         if (super.equals(o) == false) return false;
         BucketCorrelationAggregationBuilder that = (BucketCorrelationAggregationBuilder) o;
         return Objects.equals(correlationFunction, that.correlationFunction);

+ 0 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilderTests.java

@@ -22,8 +22,6 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.containsString;
 
@@ -48,10 +46,6 @@ public class BucketCorrelationAggregationBuilderTests extends BasePipelineAggreg
 
     @Override
     protected BucketCorrelationAggregationBuilder createTestAggregatorFactory() {
-        List<String> bucketPaths = Stream.generate(() -> randomAlphaOfLength(8))
-            .limit(2)
-            .collect(Collectors.toList());
-
         CorrelationFunction function = new CountCorrelationFunction(CountCorrelationIndicatorTests.randomInstance());
         return new BucketCorrelationAggregationBuilder(
             NAME,