瀏覽代碼

Merge pull request #11465 from polyfractal/bugfix/movavg_double_predict

Aggregations: Fix bug where moving_avg prediction keys are appended to previous prediction
Zachary Tong 10 年之前
父節點
當前提交
2dfcefb02c

+ 45 - 24
src/main/java/org/elasticsearch/search/aggregations/pipeline/movavg/MovAvgPipelineAggregator.java

@@ -47,9 +47,7 @@ import org.elasticsearch.search.aggregations.support.format.ValueFormatterStream
 import org.joda.time.DateTime;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 import static org.elasticsearch.search.aggregations.pipeline.BucketHelpers.resolveBucketValue;
 
@@ -110,12 +108,12 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
         List newBuckets = new ArrayList<>();
         EvictingQueue<Double> values = EvictingQueue.create(this.window);
 
-        long lastKey = 0;
-        Object currentKey;
+        long lastValidKey = 0;
+        int lastValidPosition = 0;
+        int counter = 0;
 
         for (InternalHistogram.Bucket bucket : buckets) {
             Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy);
-            currentKey = bucket.getKey();
 
             // Default is to reuse existing bucket.  Simplifies the rest of the logic,
             // since we only change newBucket if we can add to it
@@ -130,22 +128,23 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
 
                     List<InternalAggregation> aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), AGGREGATION_TRANFORM_FUNCTION));
                     aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<PipelineAggregator>(), metaData()));
-                    newBucket = factory.createBucket(currentKey, bucket.getDocCount(), new InternalAggregations(
+                    newBucket = factory.createBucket(bucket.getKey(), bucket.getDocCount(), new InternalAggregations(
                             aggs), bucket.getKeyed(), bucket.getFormatter());
                 }
-            }
-
-            newBuckets.add(newBucket);
 
-            if (predict > 0) {
-                if (currentKey instanceof Number) {
-                    lastKey  = ((Number) bucket.getKey()).longValue();
-                } else if (currentKey instanceof DateTime) {
-                    lastKey = ((DateTime) bucket.getKey()).getMillis();
-                } else {
-                    throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + currentKey + "]");
+                if (predict > 0) {
+                    if (bucket.getKey() instanceof Number) {
+                        lastValidKey  = ((Number) bucket.getKey()).longValue();
+                    } else if (bucket.getKey() instanceof DateTime) {
+                        lastValidKey = ((DateTime) bucket.getKey()).getMillis();
+                    } else {
+                        throw new AggregationExecutionException("Expected key of type Number or DateTime but got [" + lastValidKey + "]");
+                    }
+                    lastValidPosition = counter;
                 }
             }
+            counter += 1;
+            newBuckets.add(newBucket);
 
         }
 
@@ -158,13 +157,35 @@ public class MovAvgPipelineAggregator extends PipelineAggregator {
 
             double[] predictions = model.predict(values, predict);
             for (int i = 0; i < predictions.length; i++) {
-                List<InternalAggregation> aggs = new ArrayList<>();
-                aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
-                long newKey = histo.getRounding().nextRoundingValue(lastKey);
-                InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations(
-                        aggs), keyed, formatter);
-                newBuckets.add(newBucket);
-                lastKey = newKey;
+
+                List<InternalAggregation> aggs;
+                long newKey = histo.getRounding().nextRoundingValue(lastValidKey);
+
+                if (lastValidPosition + i + 1 < newBuckets.size()) {
+                    InternalHistogram.Bucket bucket = (InternalHistogram.Bucket) newBuckets.get(lastValidPosition + i + 1);
+
+                    // Get the existing aggs in the bucket so we don't clobber data
+                    aggs = new ArrayList<>(Lists.transform(bucket.getAggregations().asList(), AGGREGATION_TRANFORM_FUNCTION));
+                    aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
+
+                    InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations(
+                            aggs), keyed, formatter);
+
+                    // Overwrite the existing bucket with the new version
+                    newBuckets.set(lastValidPosition + i + 1, newBucket);
+
+                } else {
+                    // Not seen before, create fresh
+                    aggs = new ArrayList<>();
+                    aggs.add(new InternalSimpleValue(name(), predictions[i], formatter, new ArrayList<PipelineAggregator>(), metaData()));
+
+                    InternalHistogram.Bucket newBucket = factory.createBucket(newKey, 0, new InternalAggregations(
+                            aggs), keyed, formatter);
+
+                    // Since this is a new bucket, simply append it
+                    newBuckets.add(newBucket);
+                }
+                lastValidKey = newKey;
             }
         }
 

+ 131 - 20
src/test/java/org/elasticsearch/search/aggregations/pipeline/moving/avg/MovAvgTests.java

@@ -35,6 +35,7 @@ import org.elasticsearch.search.aggregations.metrics.avg.Avg;
 import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregationHelperTests;
 import org.elasticsearch.search.aggregations.pipeline.SimpleValue;
+import org.elasticsearch.search.aggregations.pipeline.derivative.Derivative;
 import org.elasticsearch.search.aggregations.pipeline.movavg.models.*;
 import org.elasticsearch.test.ElasticsearchIntegrationTest;
 import org.hamcrest.Matchers;
@@ -49,6 +50,7 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.histogra
 import static org.elasticsearch.search.aggregations.AggregationBuilders.max;
 import static org.elasticsearch.search.aggregations.AggregationBuilders.min;
 import static org.elasticsearch.search.aggregations.AggregationBuilders.range;
+import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.derivative;
 import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders.movingAvg;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
 import static org.hamcrest.Matchers.closeTo;
@@ -160,6 +162,11 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
                     jsonBuilder().startObject().field(INTERVAL_FIELD, i).field(VALUE_FIELD, 10).endObject()));
         }
 
+        for (int i = 0; i < 12; i++) {
+            builders.add(client().prepareIndex("double_predict", "type").setSource(
+                    jsonBuilder().startObject().field(INTERVAL_FIELD, i).field(VALUE_FIELD, 10).endObject()));
+        }
+
         indexRandom(true, builders);
         ensureSearchable();
     }
@@ -957,8 +964,10 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
         assertThat(histo, notNullValue());
         assertThat(histo.getName(), equalTo("histo"));
         List<? extends Bucket> buckets = histo.getBuckets();
+
         assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50 + numPredictions));
 
+
         double lastValue = ((SimpleValue)(buckets.get(0).getAggregations().get("movavg_values"))).value();
         assertThat(Double.compare(lastValue, 0.0d), greaterThanOrEqualTo(0));
 
@@ -1073,8 +1082,10 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
         assertThat(histo, notNullValue());
         assertThat(histo.getName(), equalTo("histo"));
         List<? extends Bucket> buckets = histo.getBuckets();
+
         assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50 + numPredictions));
 
+
         double lastValue = 0;
 
         double currentValue;
@@ -1099,8 +1110,7 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
 
     /**
      * This test filters the "gap" data so that the last doc is excluded.  This leaves a long stretch of empty
-     * buckets after the first bucket.  The moving avg should be one at the beginning, then zero for the rest
-     * regardless of mov avg type or gap policy.
+     * buckets after the first bucket.
      */
     @Test
     public void testRightGap() {
@@ -1176,32 +1186,39 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
         assertThat(histo, notNullValue());
         assertThat(histo.getName(), equalTo("histo"));
         List<? extends Bucket> buckets = histo.getBuckets();
-        assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50 + numPredictions));
 
+        // If we are skipping, there will only be predictions at the very beginning and won't append any new buckets
+        if (gapPolicy.equals(BucketHelpers.GapPolicy.SKIP)) {
+            assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50));
+        } else {
+            assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(50 + numPredictions));
+        }
 
+        // Unlike left-gap tests, we cannot check the slope of prediction for right-gap. E.g. linear will
+        // converge on zero, but holt-linear may trend upwards based on the first value
+        // Just check for non-nullness
         SimpleValue current = buckets.get(0).getAggregations().get("movavg_values");
         assertThat(current, notNullValue());
 
-        double lastValue = current.value();
-
-        double currentValue;
-        for (int i = 1; i < 50; i++) {
-            current = buckets.get(i).getAggregations().get("movavg_values");
-            if (current != null) {
-                currentValue = current.value();
-
-                assertThat(Double.compare(lastValue, currentValue), greaterThanOrEqualTo(0));
-                lastValue = currentValue;
+        // If we are skipping, there will only be predictions at the very beginning and won't append any new buckets
+        if (gapPolicy.equals(BucketHelpers.GapPolicy.SKIP)) {
+            // Now check predictions
+            for (int i = 1; i < 1 + numPredictions; i++) {
+                // Unclear at this point which direction the predictions will go, just verify they are
+                // not null
+                assertThat(buckets.get(i).getDocCount(), equalTo(0L));
+                assertThat((buckets.get(i).getAggregations().get("movavg_values")), notNullValue());
+            }
+        } else {
+            // Otherwise we'll have some predictions at the end
+            for (int i = 50; i < 50 + numPredictions; i++) {
+                // Unclear at this point which direction the predictions will go, just verify they are
+                // not null
+                assertThat(buckets.get(i).getDocCount(), equalTo(0L));
+                assertThat((buckets.get(i).getAggregations().get("movavg_values")), notNullValue());
             }
         }
 
-        // Now check predictions
-        for (int i = 50; i < 50 + numPredictions; i++) {
-            // Unclear at this point which direction the predictions will go, just verify they are
-            // not null, and that we don't have the_metric anymore
-            assertThat((buckets.get(i).getAggregations().get("movavg_values")), notNullValue());
-            assertThat((buckets.get(i).getAggregations().get("the_metric")), nullValue());
-        }
     }
 
     @Test
@@ -1232,6 +1249,100 @@ public class MovAvgTests extends ElasticsearchIntegrationTest {
 
     }
 
+    @Test
+    public void testTwoMovAvgsWithPredictions() {
+
+        SearchResponse response = client()
+                .prepareSearch("double_predict")
+                .setTypes("type")
+                .addAggregation(
+                        histogram("histo")
+                                .field(INTERVAL_FIELD)
+                                .interval(1)
+                                .subAggregation(avg("avg").field(VALUE_FIELD))
+                                .subAggregation(derivative("deriv")
+                                        .setBucketsPaths("avg").gapPolicy(gapPolicy))
+                                .subAggregation(
+                                        movingAvg("avg_movavg").window(windowSize).modelBuilder(new SimpleModel.SimpleModelBuilder())
+                                                .gapPolicy(gapPolicy).predict(12).setBucketsPaths("avg"))
+                                .subAggregation(
+                                        movingAvg("deriv_movavg").window(windowSize).modelBuilder(new SimpleModel.SimpleModelBuilder())
+                                                .gapPolicy(gapPolicy).predict(12).setBucketsPaths("deriv"))
+                ).execute().actionGet();
+
+        assertSearchResponse(response);
+
+        InternalHistogram<Bucket> histo = response.getAggregations().get("histo");
+        assertThat(histo, notNullValue());
+        assertThat(histo.getName(), equalTo("histo"));
+        List<? extends Bucket> buckets = histo.getBuckets();
+        assertThat("Size of buckets array is not correct.", buckets.size(), equalTo(24));
+
+        Bucket bucket = buckets.get(0);
+        assertThat(bucket, notNullValue());
+        assertThat((long) bucket.getKey(), equalTo((long) 0));
+        assertThat(bucket.getDocCount(), equalTo(1l));
+
+        Avg avgAgg = bucket.getAggregations().get("avg");
+        assertThat(avgAgg, notNullValue());
+        assertThat(avgAgg.value(), equalTo(10d));
+
+        SimpleValue movAvgAgg = bucket.getAggregations().get("avg_movavg");
+        assertThat(movAvgAgg, notNullValue());
+        assertThat(movAvgAgg.value(), equalTo(10d));
+
+        Derivative deriv = bucket.getAggregations().get("deriv");
+        assertThat(deriv, nullValue());
+
+        SimpleValue derivMovAvg = bucket.getAggregations().get("deriv_movavg");
+        assertThat(derivMovAvg, nullValue());
+
+        for (int i = 1; i < 12; i++) {
+            bucket = buckets.get(i);
+            assertThat(bucket, notNullValue());
+            assertThat((long) bucket.getKey(), equalTo((long) i));
+            assertThat(bucket.getDocCount(), equalTo(1l));
+
+            avgAgg = bucket.getAggregations().get("avg");
+            assertThat(avgAgg, notNullValue());
+            assertThat(avgAgg.value(), equalTo(10d));
+
+            deriv = bucket.getAggregations().get("deriv");
+            assertThat(deriv, notNullValue());
+            assertThat(deriv.value(), equalTo(0d));
+
+            movAvgAgg = bucket.getAggregations().get("avg_movavg");
+            assertThat(movAvgAgg, notNullValue());
+            assertThat(movAvgAgg.value(), equalTo(10d));
+
+            derivMovAvg = bucket.getAggregations().get("deriv_movavg");
+            assertThat(derivMovAvg, notNullValue());
+            assertThat(derivMovAvg.value(), equalTo(0d));
+        }
+
+        // Predictions
+        for (int i = 12; i < 24; i++) {
+            bucket = buckets.get(i);
+            assertThat(bucket, notNullValue());
+            assertThat((long) bucket.getKey(), equalTo((long) i));
+            assertThat(bucket.getDocCount(), equalTo(0l));
+
+            avgAgg = bucket.getAggregations().get("avg");
+            assertThat(avgAgg, nullValue());
+
+            deriv = bucket.getAggregations().get("deriv");
+            assertThat(deriv, nullValue());
+
+            movAvgAgg = bucket.getAggregations().get("avg_movavg");
+            assertThat(movAvgAgg, notNullValue());
+            assertThat(movAvgAgg.value(), equalTo(10d));
+
+            derivMovAvg = bucket.getAggregations().get("deriv_movavg");
+            assertThat(derivMovAvg, notNullValue());
+            assertThat(derivMovAvg.value(), equalTo(0d));
+        }
+    }
+
     @Test
     public void testBadModelParams() {
         try {