Browse Source

Add more flexibility to MovingFunction window alignment (#44360)

Introduce shift field to MovingFunction aggregation.

By default, shift = 0. Behavior, in this case, is the same as before.
Increasing shift by 1 moves the starting position of the window by 1 to the right.

    To simply include the current bucket in the window, use shift = 1
    For center alignment (n/2 values before and after the current bucket), use shift = window / 2
    For right alignment (n values after the current bucket), use shift = window.
Nikita Glashenko 6 years ago
parent
commit
ead4eb5209

+ 15 - 2
docs/reference/aggregations/pipeline/movfn-aggregation.asciidoc

@@ -24,14 +24,15 @@ A `moving_fn` aggregation looks like this in isolation:
 --------------------------------------------------
 // NOTCONSOLE
 
-[[moving-avg-params]]
-.`moving_avg` Parameters
+[[moving-fn-params]]
+.`moving_fn` Parameters
 [options="header"]
 |===
 |Parameter Name |Description |Required |Default Value
 |`buckets_path` |Path to the metric of interest (see <<buckets-path-syntax, `buckets_path` Syntax>> for more details) |Required |
 |`window` |The size of window to "slide" across the histogram. |Required |
 |`script` |The script that should be executed on each window of data |Required |
+|`shift` |<<shift-parameter, Shift>> of window position. |Optional | 0
 |===
 
 `moving_fn` aggregations must be embedded inside of a `histogram` or `date_histogram` aggregation.  They can be
@@ -169,6 +170,18 @@ POST /_search
 // CONSOLE
 // TEST[setup:sales]
 
+[[shift-parameter]]
+==== shift parameter
+
+By default (with `shift = 0`), the window that is offered for calculation is the last `n` values excluding the current bucket.
+Increasing `shift` by 1 moves the starting position of the window by `1` to the right.
+
+- To include the current bucket in the window, use `shift = 1`.
+- For center alignment (`n / 2` values before and after the current bucket), use `shift = window / 2`.
+- For right alignment (`n` values after the current bucket), use `shift = window`.
+
+If either of the window edges moves outside the borders of the data series, the window shrinks to include only the available values.
+
 ==== Pre-built Functions
 
 For convenience, a number of functions have been prebuilt and are available inside the `moving_fn` script context:

+ 22 - 4
server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilder.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.search.aggregations.pipeline;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -48,12 +49,14 @@ import static org.elasticsearch.search.aggregations.pipeline.PipelineAggregator.
 public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<MovFnPipelineAggregationBuilder> {
     public static final String NAME = "moving_fn";
     private static final ParseField WINDOW = new ParseField("window");
+    private static final ParseField SHIFT = new ParseField("shift");
 
     private final Script script;
     private final String bucketsPathString;
     private String format = null;
     private GapPolicy gapPolicy = GapPolicy.SKIP;
     private int window;
+    private int shift;
 
     private static final Function<String, ConstructingObjectParser<MovFnPipelineAggregationBuilder, Void>> PARSER
         = name -> {
@@ -68,6 +71,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
             (p, c) -> Script.parse(p), Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING);
         parser.declareInt(ConstructingObjectParser.constructorArg(), WINDOW);
 
+        parser.declareInt(MovFnPipelineAggregationBuilder::setShift, SHIFT);
         parser.declareString(MovFnPipelineAggregationBuilder::format, FORMAT);
         parser.declareField(MovFnPipelineAggregationBuilder::gapPolicy, p -> {
             if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
@@ -97,6 +101,11 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
         format = in.readOptionalString();
         gapPolicy = GapPolicy.readFrom(in);
         window = in.readInt();
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
+            shift = in.readInt();
+        } else {
+            shift = 0;
+        }
     }
 
     @Override
@@ -106,6 +115,9 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
         out.writeOptionalString(format);
         gapPolicy.writeTo(out);
         out.writeInt(window);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
+            out.writeInt(shift);
+        }
     }
 
     /**
@@ -168,9 +180,13 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
         this.window = window;
     }
 
+    public void setShift(int shift) {
+        this.shift = shift;
+    }
+
     @Override
     public void doValidate(AggregatorFactory parent, Collection<AggregationBuilder> aggFactories,
-                           Collection<PipelineAggregationBuilder> pipelineAggregatoractories) {
+                           Collection<PipelineAggregationBuilder> pipelineAggregatorFactories) {
         if (window <= 0) {
             throw new IllegalArgumentException("[" + WINDOW.getPreferredName() + "] must be a positive, non-zero integer.");
         }
@@ -180,7 +196,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
 
     @Override
     protected PipelineAggregator createInternal(Map<String, Object> metaData) {
-        return new MovFnPipelineAggregator(name, bucketsPathString, script, window, formatter(), gapPolicy, metaData);
+        return new MovFnPipelineAggregator(name, bucketsPathString, script, window, shift, formatter(), gapPolicy, metaData);
     }
 
     @Override
@@ -192,6 +208,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
         }
         builder.field(GAP_POLICY.getPreferredName(), gapPolicy.getName());
         builder.field(WINDOW.getPreferredName(), window);
+        builder.field(SHIFT.getPreferredName(), shift);
         return builder;
     }
 
@@ -225,7 +242,7 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window);
+        return Objects.hash(super.hashCode(), bucketsPathString, script, format, gapPolicy, window, shift);
     }
 
     @Override
@@ -238,7 +255,8 @@ public class MovFnPipelineAggregationBuilder extends AbstractPipelineAggregation
             && Objects.equals(script, other.script)
             && Objects.equals(format, other.format)
             && Objects.equals(gapPolicy, other.gapPolicy)
-            && Objects.equals(window, other.window);
+            && Objects.equals(window, other.window)
+            && Objects.equals(shift, other.shift);
     }
 
     @Override

+ 38 - 6
server/src/main/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregator.java

@@ -19,7 +19,7 @@
 
 package org.elasticsearch.search.aggregations.pipeline;
 
-import org.elasticsearch.common.collect.EvictingQueue;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.script.Script;
@@ -63,8 +63,9 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
     private final Script script;
     private final String bucketsPath;
     private final int window;
+    private final int shift;
 
-    MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, DocValueFormat formatter,
+    MovFnPipelineAggregator(String name, String bucketsPath, Script script, int window, int shift, DocValueFormat formatter,
                             BucketHelpers.GapPolicy gapPolicy, Map<String, Object> metadata) {
         super(name, new String[]{bucketsPath}, metadata);
         this.bucketsPath = bucketsPath;
@@ -72,6 +73,7 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
         this.formatter = formatter;
         this.gapPolicy = gapPolicy;
         this.window = window;
+        this.shift = shift;
     }
 
     public MovFnPipelineAggregator(StreamInput in) throws IOException {
@@ -81,6 +83,11 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
         gapPolicy = BucketHelpers.GapPolicy.readFrom(in);
         bucketsPath = in.readString();
         window = in.readInt();
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
+            shift = in.readInt();
+        } else {
+            shift = 0;
+        }
     }
 
     @Override
@@ -90,6 +97,9 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
         gapPolicy.writeTo(out);
         out.writeString(bucketsPath);
         out.writeInt(window);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO change this after backport
+            out.writeInt(shift);
+        }
     }
 
     @Override
@@ -106,7 +116,6 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
         HistogramFactory factory = (HistogramFactory) histo;
 
         List<MultiBucketsAggregation.Bucket> newBuckets = new ArrayList<>();
-        EvictingQueue<Double> values = new EvictingQueue<>(this.window);
 
         // Initialize the script
         MovingFunctionScript.Factory scriptFactory = reduceContext.scriptService().compile(script, MovingFunctionScript.CONTEXT);
@@ -117,6 +126,12 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
 
         MovingFunctionScript executableScript = scriptFactory.newInstance();
 
+        List<Double> values = buckets.stream()
+            .map(b -> resolveBucketValue(histo, b, bucketsPaths()[0], gapPolicy))
+            .filter(v -> v != null && v.isNaN() == false)
+            .collect(Collectors.toList());
+
+        int index = 0;
         for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
             Double thisBucketValue = resolveBucketValue(histo, bucket, bucketsPaths()[0], gapPolicy);
 
@@ -124,11 +139,18 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
             // since we only change newBucket if we can add to it
             MultiBucketsAggregation.Bucket newBucket = bucket;
 
-            if (thisBucketValue != null && thisBucketValue.equals(Double.NaN) == false) {
+            if (thisBucketValue != null && thisBucketValue.isNaN() == false) {
 
                 // The custom context mandates that the script returns a double (not Double) so we
                 // don't need null checks, etc.
-                double movavg = executableScript.execute(vars, values.stream().mapToDouble(Double::doubleValue).toArray());
+                int fromIndex = clamp(index - window + shift, values);
+                int toIndex = clamp(index + shift, values);
+                double movavg = executableScript.execute(
+                    vars,
+                    values.subList(fromIndex, toIndex).stream()
+                        .mapToDouble(Double::doubleValue)
+                        .toArray()
+                );
 
                 List<InternalAggregation> aggs = StreamSupport
                     .stream(bucket.getAggregations().spliterator(), false)
@@ -136,11 +158,21 @@ public class MovFnPipelineAggregator extends PipelineAggregator {
                     .collect(Collectors.toList());
                 aggs.add(new InternalSimpleValue(name(), movavg, formatter, new ArrayList<>(), metaData()));
                 newBucket = factory.createBucket(factory.getKey(bucket), bucket.getDocCount(), new InternalAggregations(aggs));
-                values.offer(thisBucketValue);
+                index++;
             }
             newBuckets.add(newBucket);
         }
 
         return factory.createAggregation(newBuckets);
     }
+
+    private int clamp(int index, List<Double> list) {
+        if (index < 0) {
+            return 0;
+        }
+        if (index > list.size()) {
+            return list.size();
+        }
+        return index;
+    }
 }

+ 8 - 2
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnPipelineAggregationBuilderSerializationTests.java

@@ -22,7 +22,6 @@ package org.elasticsearch.search.aggregations.pipeline;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.script.Script;
-import org.elasticsearch.search.aggregations.pipeline.MovFnPipelineAggregationBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 
 import java.io.IOException;
@@ -31,7 +30,14 @@ public class MovFnPipelineAggregationBuilderSerializationTests extends AbstractS
 
     @Override
     protected MovFnPipelineAggregationBuilder createTestInstance() {
-        return new MovFnPipelineAggregationBuilder(randomAlphaOfLength(10), "foo", new Script("foo"), randomIntBetween(1, 10));
+        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder(
+            randomAlphaOfLength(10),
+            "foo",
+            new Script("foo"),
+            randomIntBetween(1, 10)
+        );
+        builder.setShift(randomIntBetween(1, 10));
+        return builder;
     }
 
     @Override

+ 29 - 11
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/MovFnUnitTests.java

@@ -53,6 +53,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Consumer;
+import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Mockito.mock;
@@ -79,25 +80,42 @@ public class MovFnUnitTests extends AggregatorTestCase {
     private static final List<Integer> datasetValues = Arrays.asList(1,2,3,4,5,6,7,8,9,10);
 
     public void testMatchAllDocs() throws IOException {
-        Query query = new MatchAllDocsQuery();
+        check(0, List.of(Double.NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0));
+    }
+
+    public void testShift() throws IOException {
+        check(1, List.of(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0));
+        check(5, List.of(5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 10.0, 10.0, Double.NaN, Double.NaN));
+        check(-5, List.of(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN, 1.0, 2.0, 3.0, 4.0));
+    }
+
+    public void testWideWindow() throws IOException {
         Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
+        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 100);
+        builder.setShift(50);
+        check(builder, script, List.of(10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0));
+    }
 
+    private void check(int shift, List<Double> expected) throws IOException {
+        Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, "painless", "test", Collections.emptyMap());
+        MovFnPipelineAggregationBuilder builder = new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3);
+        builder.setShift(shift);
+        check(builder, script, expected);
+    }
+
+    private void check(MovFnPipelineAggregationBuilder builder, Script script, List<Double> expected) throws IOException {
+        Query query = new MatchAllDocsQuery();
         DateHistogramAggregationBuilder aggBuilder = new DateHistogramAggregationBuilder("histo");
         aggBuilder.calendarInterval(DateHistogramInterval.DAY).field(DATE_FIELD);
         aggBuilder.subAggregation(new AvgAggregationBuilder("avg").field(VALUE_FIELD));
-        aggBuilder.subAggregation(new MovFnPipelineAggregationBuilder("mov_fn", "avg", script, 3));
+        aggBuilder.subAggregation(builder);
 
         executeTestCase(query, aggBuilder, histogram -> {
-                assertEquals(10, histogram.getBuckets().size());
                 List<? extends Histogram.Bucket> buckets = histogram.getBuckets();
-                for (int i = 0; i < buckets.size(); i++) {
-                    if (i == 0) {
-                        assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(Double.NaN));
-                    } else {
-                        assertThat(((InternalSimpleValue)(buckets.get(i).getAggregations().get("mov_fn"))).value(), equalTo(((double) i)));
-                    }
-
-                }
+                List<Double> actual = buckets.stream()
+                    .map(bucket -> ((InternalSimpleValue) (bucket.getAggregations().get("mov_fn"))).value())
+                    .collect(Collectors.toList());
+                assertThat(actual, equalTo(expected));
             }, 1000, script);
     }