1
0
Эх сурвалжийг харах

Improving parsing of sigma param for Extended Stats Bucket Aggregatio

Improving parsing of sigma param for Extended Stats Bucket Aggregation
Colin Goodheart-Smithe 9 жил өмнө
parent
commit
319ca82510

+ 23 - 22
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/BucketMetricsParser.java

@@ -27,7 +27,6 @@ import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 
 import java.io.IOException;
-import java.text.ParseException;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -53,7 +52,7 @@ public abstract class BucketMetricsParser implements PipelineAggregator.Parser {
         String[] bucketsPaths = null;
         String format = null;
         GapPolicy gapPolicy = null;
-        Map<String, Object> leftover = new HashMap<>(5);
+        Map<String, Object> params = new HashMap<>(5);
 
         while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
             if (token == XContentParser.Token.FIELD_NAME) {
@@ -66,7 +65,7 @@ public abstract class BucketMetricsParser implements PipelineAggregator.Parser {
                 } else if (context.getParseFieldMatcher().match(currentFieldName, GAP_POLICY)) {
                     gapPolicy = GapPolicy.parse(context, parser.text(), parser.getTokenLocation());
                 } else {
-                    leftover.put(currentFieldName, parser.text());
+                    parseToken(pipelineAggregatorName, parser, context, currentFieldName, token, params);
                 }
             } else if (token == XContentParser.Token.START_ARRAY) {
                 if (context.getParseFieldMatcher().match(currentFieldName, BUCKETS_PATH)) {
@@ -77,10 +76,10 @@ public abstract class BucketMetricsParser implements PipelineAggregator.Parser {
                     }
                     bucketsPaths = paths.toArray(new String[paths.size()]);
                 } else {
-                    leftover.put(currentFieldName, parser.list());
+                    parseToken(pipelineAggregatorName, parser, context, currentFieldName, token, params);
                 }
             } else {
-                leftover.put(currentFieldName, parser.objectText());
+                parseToken(pipelineAggregatorName, parser, context, currentFieldName, token, params);
             }
         }
 
@@ -89,30 +88,32 @@ public abstract class BucketMetricsParser implements PipelineAggregator.Parser {
                     "Missing required field [" + BUCKETS_PATH.getPreferredName() + "] for aggregation [" + pipelineAggregatorName + "]");
         }
 
-        BucketMetricsPipelineAggregatorBuilder<?> factory = null;
-        try {
-            factory = buildFactory(pipelineAggregatorName, bucketsPaths[0], leftover);
-            if (format != null) {
-                factory.format(format);
-            }
-            if (gapPolicy != null) {
-                factory.gapPolicy(gapPolicy);
-            }
-        } catch (ParseException exception) {
-            throw new ParsingException(parser.getTokenLocation(),
-                    "Could not parse settings for aggregation [" + pipelineAggregatorName + "].", exception);
+        BucketMetricsPipelineAggregatorBuilder<?> factory  = buildFactory(pipelineAggregatorName, bucketsPaths[0], params);
+        if (format != null) {
+            factory.format(format);
         }
-
-        if (leftover.size() > 0) {
-            throw new ParsingException(parser.getTokenLocation(),
-                    "Unexpected tokens " + leftover.keySet() + " in [" + pipelineAggregatorName + "].");
+        if (gapPolicy != null) {
+            factory.gapPolicy(gapPolicy);
         }
+
         assert(factory != null);
 
         return factory;
     }
 
     protected abstract BucketMetricsPipelineAggregatorBuilder<?> buildFactory(String pipelineAggregatorName, String bucketsPaths,
-            Map<String, Object> unparsedParams) throws ParseException;
+                                                                              Map<String, Object> params);
 
+    protected boolean token(XContentParser parser, QueryParseContext context, String field,
+                            XContentParser.Token token, Map<String, Object> params) throws IOException {
+        return false;
+    }
+
+    private void parseToken(String aggregationName, XContentParser parser, QueryParseContext context, String currentFieldName,
+                       XContentParser.Token currentToken, Map<String, Object> params) throws IOException {
+        if (token(parser, context, currentFieldName, currentToken, params) == false) {
+            throw new ParsingException(parser.getTokenLocation(),
+                "Unexpected token " + currentToken + " [" + currentFieldName + "] in [" + aggregationName + "]");
+        }
+    }
 }

+ 2 - 2
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/avg/AvgBucketPipelineAggregatorBuilder.java

@@ -75,7 +75,7 @@ public class AvgBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public static final PipelineAggregator.Parser PARSER = new BucketMetricsParser() {
         @Override
         protected AvgBucketPipelineAggregatorBuilder buildFactory(String pipelineAggregatorName,
-                String bucketsPath, Map<String, Object> unparsedParams) {
+                String bucketsPath, Map<String, Object> params) {
             return new AvgBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
         }
     };
@@ -94,4 +94,4 @@ public class AvgBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public String getWriteableName() {
         return NAME;
     }
-}
+}

+ 2 - 2
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/max/MaxBucketPipelineAggregatorBuilder.java

@@ -75,7 +75,7 @@ public class MaxBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public static final PipelineAggregator.Parser PARSER = new BucketMetricsParser() {
         @Override
         protected MaxBucketPipelineAggregatorBuilder buildFactory(String pipelineAggregatorName,
-                String bucketsPath, Map<String, Object> unparsedParams) {
+                String bucketsPath, Map<String, Object> params) {
             return new MaxBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
         }
     };
@@ -94,4 +94,4 @@ public class MaxBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public String getWriteableName() {
         return NAME;
     }
-}
+}

+ 2 - 2
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/min/MinBucketPipelineAggregatorBuilder.java

@@ -75,7 +75,7 @@ public class MinBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public static final PipelineAggregator.Parser PARSER = new BucketMetricsParser() {
         @Override
         protected MinBucketPipelineAggregatorBuilder buildFactory(String pipelineAggregatorName,
-                String bucketsPath, Map<String, Object> unparsedParams) {
+                String bucketsPath, Map<String, Object> params) {
             return new MinBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
         }
     };
@@ -94,4 +94,4 @@ public class MinBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public String getWriteableName() {
         return NAME;
     }
-}
+}

+ 25 - 27
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/percentile/PercentilesBucketPipelineAggregatorBuilder.java

@@ -19,10 +19,13 @@
 
 package org.elasticsearch.search.aggregations.pipeline.bucketmetrics.percentile;
 
+import com.carrotsearch.hppc.DoubleArrayList;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryParseContext;
 import org.elasticsearch.search.aggregations.AggregatorFactory;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilder;
@@ -117,41 +120,36 @@ public class PercentilesBucketPipelineAggregatorBuilder
     }
 
     public static final PipelineAggregator.Parser PARSER = new BucketMetricsParser() {
+
         @Override
         protected PercentilesBucketPipelineAggregatorBuilder buildFactory(String pipelineAggregatorName,
-                String bucketsPath, Map<String, Object> unparsedParams) throws ParseException {
-            double[] percents = null;
-            int counter = 0;
-            Object percentParam = unparsedParams.get(PERCENTS_FIELD.getPreferredName());
-
-            if (percentParam != null) {
-                if (percentParam instanceof List) {
-                    percents = new double[((List<?>) percentParam).size()];
-                    for (Object p : (List<?>) percentParam) {
-                        if (p instanceof Double) {
-                            percents[counter] = (Double) p;
-                            counter += 1;
-                        } else {
-                            throw new ParseException(
-                                    "Parameter [" + PERCENTS_FIELD.getPreferredName() + "] must be an array of doubles, type `"
-                                            + percentParam.getClass().getSimpleName() + "` provided instead",
-                                    0);
-                        }
-                    }
-                    unparsedParams.remove(PERCENTS_FIELD.getPreferredName());
-                } else {
-                    throw new ParseException("Parameter [" + PERCENTS_FIELD.getPreferredName() + "] must be an array of doubles, type `"
-                            + percentParam.getClass().getSimpleName() + "` provided instead", 0);
-                }
-            }
+                                                                          String bucketsPath, Map<String, Object> params) {
 
             PercentilesBucketPipelineAggregatorBuilder factory = new
-                    PercentilesBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
+                PercentilesBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
+
+            double[] percents = (double[]) params.get(PERCENTS_FIELD.getPreferredName());
             if (percents != null) {
                 factory.percents(percents);
             }
+
             return factory;
         }
+
+        @Override
+        protected boolean token(XContentParser parser, QueryParseContext context, String field,
+                                XContentParser.Token token, Map<String, Object> params) throws IOException {
+            if (context.getParseFieldMatcher().match(field, PERCENTS_FIELD) && token == XContentParser.Token.START_ARRAY) {
+                DoubleArrayList percents = new DoubleArrayList(10);
+                while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
+                    percents.add(parser.doubleValue());
+                }
+                params.put(PERCENTS_FIELD.getPreferredName(), percents.toArray());
+                return true;
+            }
+            return false;
+        }
+
     };
 
     @Override
@@ -169,4 +167,4 @@ public class PercentilesBucketPipelineAggregatorBuilder
     public String getWriteableName() {
         return NAME;
     }
-}
+}

+ 2 - 2
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/stats/StatsBucketPipelineAggregatorBuilder.java

@@ -77,7 +77,7 @@ public class StatsBucketPipelineAggregatorBuilder extends BucketMetricsPipelineA
     public static final PipelineAggregator.Parser PARSER = new BucketMetricsParser() {
         @Override
         protected StatsBucketPipelineAggregatorBuilder buildFactory(String pipelineAggregatorName,
-                String bucketsPath, Map<String, Object> unparsedParams) {
+                String bucketsPath, Map<String, Object> params) {
             return new StatsBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
         }
     };
@@ -96,4 +96,4 @@ public class StatsBucketPipelineAggregatorBuilder extends BucketMetricsPipelineA
     public String getWriteableName() {
         return NAME;
     }
-}
+}

+ 17 - 16
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/stats/extended/ExtendedStatsBucketParser.java

@@ -20,9 +20,11 @@
 package org.elasticsearch.search.aggregations.pipeline.bucketmetrics.stats.extended;
 
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryParseContext;
 import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.BucketMetricsParser;
 
-import java.text.ParseException;
+import java.io.IOException;
 import java.util.Map;
 
 public class ExtendedStatsBucketParser extends BucketMetricsParser {
@@ -30,25 +32,24 @@ public class ExtendedStatsBucketParser extends BucketMetricsParser {
 
     @Override
     protected ExtendedStatsBucketPipelineAggregatorBuilder buildFactory(String pipelineAggregatorName,
-            String bucketsPath, Map<String, Object> unparsedParams) throws ParseException {
-
-        Double sigma = null;
-        Object param = unparsedParams.get(SIGMA.getPreferredName());
-
-        if (param != null) {
-            if (param instanceof Double) {
-                sigma = (Double) param;
-                unparsedParams.remove(SIGMA.getPreferredName());
-            } else {
-                throw new ParseException("Parameter [" + SIGMA.getPreferredName() + "] must be a Double, type `"
-                        + param.getClass().getSimpleName() + "` provided instead", 0);
-            }
-        }
+            String bucketsPath, Map<String, Object> params) {
         ExtendedStatsBucketPipelineAggregatorBuilder factory =
-                new ExtendedStatsBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
+            new ExtendedStatsBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
+        Double sigma = (Double) params.get(SIGMA.getPreferredName());
         if (sigma != null) {
             factory.sigma(sigma);
         }
+
         return factory;
     }
+
+    @Override
+    protected boolean token(XContentParser parser, QueryParseContext context, String field,
+                            XContentParser.Token token, Map<String, Object> params) throws IOException {
+        if (context.getParseFieldMatcher().match(field, SIGMA) && token == XContentParser.Token.VALUE_NUMBER) {
+            params.put(SIGMA.getPreferredName(), parser.doubleValue());
+            return true;
+        }
+        return false;
+    }
 }

+ 2 - 2
core/src/main/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/sum/SumBucketPipelineAggregatorBuilder.java

@@ -75,7 +75,7 @@ public class SumBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public static final PipelineAggregator.Parser PARSER = new BucketMetricsParser() {
         @Override
         protected SumBucketPipelineAggregatorBuilder buildFactory(String pipelineAggregatorName,
-                String bucketsPath, Map<String, Object> unparsedParams) {
+                String bucketsPath, Map<String, Object> params) {
             return new SumBucketPipelineAggregatorBuilder(pipelineAggregatorName, bucketsPath);
         }
     };
@@ -94,4 +94,4 @@ public class SumBucketPipelineAggregatorBuilder extends BucketMetricsPipelineAgg
     public String getWriteableName() {
         return NAME;
     }
-}
+}

+ 23 - 0
core/src/test/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/ExtendedStatsBucketTests.java

@@ -19,8 +19,14 @@
 
 package org.elasticsearch.search.aggregations.pipeline.bucketmetrics;
 
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryParseContext;
+import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.stats.extended.ExtendedStatsBucketPipelineAggregator;
 import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.stats.extended.ExtendedStatsBucketPipelineAggregatorBuilder;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public class ExtendedStatsBucketTests extends AbstractBucketMetricsTestCase<ExtendedStatsBucketPipelineAggregatorBuilder> {
 
     @Override
@@ -32,5 +38,22 @@ public class ExtendedStatsBucketTests extends AbstractBucketMetricsTestCase<Exte
         return factory;
     }
 
+    public void testSigmaFromInt() throws Exception {
+        String content = XContentFactory.jsonBuilder()
+            .startObject()
+                .field("sigma", 5)
+                .field("buckets_path", "test")
+            .endObject()
+            .string();
+
+        XContentParser parser = XContentFactory.xContent(content).createParser(content);
+        QueryParseContext parseContext = new QueryParseContext(queriesRegistry, parser, parseFieldMatcher);
+        parser.nextToken(); // skip object start
 
+        ExtendedStatsBucketPipelineAggregatorBuilder builder = (ExtendedStatsBucketPipelineAggregatorBuilder) aggParsers
+            .pipelineParser(ExtendedStatsBucketPipelineAggregator.TYPE.name(), parseFieldMatcher)
+            .parse("test", parseContext);
+
+        assertThat(builder.sigma(), equalTo(5.0));
+    }
 }

+ 23 - 0
core/src/test/java/org/elasticsearch/search/aggregations/pipeline/bucketmetrics/PercentilesBucketTests.java

@@ -19,8 +19,14 @@
 
 package org.elasticsearch.search.aggregations.pipeline.bucketmetrics;
 
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryParseContext;
+import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.percentile.PercentilesBucketPipelineAggregator;
 import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.percentile.PercentilesBucketPipelineAggregatorBuilder;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public class PercentilesBucketTests extends AbstractBucketMetricsTestCase<PercentilesBucketPipelineAggregatorBuilder> {
 
     @Override
@@ -37,5 +43,22 @@ public class PercentilesBucketTests extends AbstractBucketMetricsTestCase<Percen
         return factory;
     }
 
+    public void testPercentsFromMixedArray() throws Exception {
+        String content = XContentFactory.jsonBuilder()
+            .startObject()
+                .field("buckets_path", "test")
+                .array("percents", 0, 20.0, 50, 75.99)
+            .endObject()
+            .string();
+
+        XContentParser parser = XContentFactory.xContent(content).createParser(content);
+        QueryParseContext parseContext = new QueryParseContext(queriesRegistry, parser, parseFieldMatcher);
+        parser.nextToken(); // skip object start
 
+        PercentilesBucketPipelineAggregatorBuilder builder = (PercentilesBucketPipelineAggregatorBuilder) aggParsers
+            .pipelineParser(PercentilesBucketPipelineAggregator.TYPE.name(), parseFieldMatcher)
+            .parse("test", parseContext);
+
+        assertThat(builder.percents(), equalTo(new double[]{0.0, 20.0, 50.0, 75.99}));
+    }
 }