Forráskód Böngészése

[ML] add new bucket_correlation aggregation with initial count_correlation function (#72133)

This commit adds a new pipeline aggregation that allows correlation within the aggregation frame work in bucketed values. 

The initial function is a `count_correlation` function. The purpose of which is to correlate the count in a consistent number of buckets with a pre calculated indicator. The indicator and the aggregated buckets should related to the same metrics with in documents. 

Example for correlating terms within a `service.version.keyword` with latency percentiles. The percentiles and provided correlation indicator both refer to the same source data where the indicator was previously calculated.:
```
GET apm-7.12.0-transaction-generated/_search
{
  "size": 0,
  "aggs": {
    "field_terms": {
      "terms": {
        "field": "service.version.keyword",
        "size": 20
      },
      "aggs": {
        "latency_range": {
          "range": {
            "field": "transaction.duration.us",
            "ranges": [<snip>],
            "keyed": true
          }
        },
        "correlation": {
          "bucket_correlation": {
            "buckets_path": "latency_range>_count",
            "count_correlation": {
              "indicator": {
                 "expectations": [<snip>],
                 "doc_count": 20000
               }
            }
          }
        }
      }
    }
  }
}
```
Benjamin Trent 4 éve
szülő
commit
8069e9b233
29 módosított fájl, 1733 hozzáadás és 111 törlés
  1. 41 0
      docs/build.gradle
  2. 319 0
      docs/reference/aggregations/pipeline/bucket-correlation-aggregation.asciidoc
  3. 1 1
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java
  4. 0 13
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/AnomalyJobCRUDIT.java
  5. 216 0
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/BucketCorrelationAggregationIT.java
  6. 0 11
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java
  7. 0 16
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsCRUDIT.java
  8. 0 16
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java
  9. 0 8
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlAutoUpdateServiceIT.java
  10. 0 13
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java
  11. 0 14
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
  12. 9 11
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  13. 139 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilder.java
  14. 77 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregator.java
  15. 21 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CorrelationFunction.java
  16. 39 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CorrelationNamedContentProvider.java
  17. 179 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationFunction.java
  18. 147 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationIndicator.java
  19. 15 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java
  20. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregator.java
  21. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregation.java
  22. 13 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java
  23. 97 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilderTests.java
  24. 71 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationFunctionTests.java
  25. 44 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationIndicatorTests.java
  26. 2 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilderTests.java
  27. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregationTests.java
  28. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/ParsedInference.java
  29. 299 0
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/bucket_correlation_agg.yml

+ 41 - 0
docs/build.gradle

@@ -1522,6 +1522,47 @@ setups['setup-repository'] = '''
           body: |
 #atomic_red_data#
 '''
+  // fake data used by the correlation bucket agg
+  buildRestTests.setups['correlate_latency'] = '''
+  - do:
+        indices.create:
+          index: correlate_latency
+          body:
+            settings:
+              number_of_shards: 1
+              number_of_replicas: 0
+            mappings:
+              properties:
+                latency:
+                  type: double
+                version:
+                  type: keyword
+  - do:
+        bulk:
+          index: correlate_latency
+          refresh: true
+          body: |'''
+
+
+  for (int i = 100; i < 200; i++) {
+    def value = i
+    if (i % 10) {
+      value = i * 10
+    }
+    buildRestTests.setups['correlate_latency'] += """
+            {"index":{}}
+            {"latency": "$value", "version": "1.0"}"""
+  }
+  for (int i = 0; i < 100; i++) {
+    def value = i
+    if (i % 10) {
+      value = i * 10
+    }
+    buildRestTests.setups['correlate_latency'] += """
+            {"index":{}}
+            {"latency": "$value", "version": "2.0"}"""
+  }
+
   /* Load the actual events only if we're going to use them. */
   File atomicRedRegsvr32File = new File("$projectDir/src/test/resources/normalized-T1117-AtomicRed-regsvr32.json")
   inputs.file(atomicRedRegsvr32File)

+ 319 - 0
docs/reference/aggregations/pipeline/bucket-correlation-aggregation.asciidoc

@@ -0,0 +1,319 @@
+[role="xpack"]
+[testenv="basic"]
+[[search-aggregations-bucket-correlation-aggregation]]
+=== Bucket correlation aggregation
+++++
+<titleabbrev>Bucket correlation aggregation</titleabbrev>
+++++
+
+experimental::[]
+
+A sibling pipeline aggregation which executes a correlation function on the
+configured sibling multi-bucket aggregation.
+
+
+[[bucket-correlation-agg-syntax]]
+==== Parameters
+
+`buckets_path`::
+(Required, string)
+Path to the buckets that contain one set of values to correlate.
+For syntax, see <<buckets-path-syntax>>.
+
+`function`::
+(Required, object)
+The correlation function to execute.
++
+.Properties of `function`
+[%collapsible%open]
+====
+`count_correlation`:::
+(Required^*^, object)
+The configuration to calculate a count correlation. This function is designed for
+determining the correlation of a term value and a given metric. Consequently, it
+needs to meet the following requirements.
+
+* The `buckets_path` must point to a `_count` metric.
+* The total count of all the `bucket_path` count values must be less than or equal to `indicator.doc_count`.
+* When utilizing this function, an initial calculation to gather the required `indicator` values is required.
+
+.Properties of `count_correlation`
+[%collapsible%open]
+=====
+`indicator`:::
+(Required, object)
+The indicator with which to correlate the configured `bucket_path` values.
+
+.Properties of `indicator`
+[%collapsible%open]
+=====
+`expectations`:::
+(Required, array)
+An array of numbers with which to correlate the configured `bucket_path` values. The length of this value must always equal
+the number of buckets returned by the `bucket_path`.
+
+`fractions`:::
+(Optional, array)
+An array of fractions to use when averaging and calculating variance. This should be used if the pre-calculated data and the
+`buckets_path` have known gaps. The length of `fractions`, if provided, must equal `expectations`.
+
+`doc_count`:::
+(Required, integer)
+The total number of documents that initially created the `expectations`. It's required to be greater than or equal to the sum
+of all values in the `buckets_path` as this is the originating superset of data to which the term values are correlated.
+=====
+=====
+====
+
+==== Syntax
+
+A `bucket_correlation` aggregation looks like this in isolation:
+
+[source,js]
+--------------------------------------------------
+{
+  "bucket_correlation": {
+    "buckets_path": "range_values>_count", <1>
+    "function": {
+      "count_correlation": { <2>
+        "expectations": [...],
+        "doc_count": 10000
+      }
+    }
+  }
+}
+--------------------------------------------------
+// NOTCONSOLE
+<1> The buckets containing the values to correlate against.
+<2> The correlation function definition.
+
+
+[[bucket-correlation-agg-example]]
+==== Example
+
+The following snippet correlates the individual terms in the field `version` with the `latency` metric. Not shown
+is the pre-calculation of the `latency` indicator values, which was done utilizing the
+<<search-aggregations-metrics-percentile-aggregation,percentiles>> aggregation.
+
+This example is only using the 10s percentiles.
+
+[source,console]
+-------------------------------------------------
+POST correlate_latency/_search?size=0&filter_path=aggregations
+{
+  "aggs": {
+    "buckets": {
+      "terms": {
+        "field": "version",
+        "size": 2
+      },
+      "aggs": {
+        "latency_ranges": {
+          "range": {
+            "field": "latency",
+            "ranges": [
+              { "to": 0.0 },
+              { "from": 0, "to": 105 },
+              { "from": 105, "to": 225 },
+              { "from": 225, "to": 445 },
+              { "from": 445, "to": 665 },
+              { "from": 665, "to": 885 },
+              { "from": 885, "to": 1115 },
+              { "from": 1115, "to": 1335 },
+              { "from": 1335, "to": 1555 },
+              { "from": 1555, "to": 1775 },
+              { "from": 1775 }
+            ]
+          }
+        },
+        "bucket_correlation": {
+          "bucket_correlation": {
+            "buckets_path": "latency_ranges>_count",
+            "function": {
+              "count_correlation": {
+                "indicator": {
+                   "expectations": [0, 52.5, 165, 335, 555, 775, 1000, 1225, 1445, 1665, 1775],
+                   "doc_count": 200
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
+-------------------------------------------------
+// TEST[setup:correlate_latency]
+
+<1> The term buckets containing a range aggregation and the bucket correlation aggregation. Both are utilized to calculate
+    the correlation of the term values with the latency.
+<2> The range aggregation on the latency field. The ranges were created referencing the percentiles of the latency field.
+<3> The bucket correlation aggregation that calculates the correlation of the number of term values within each range
+    and the previously calculated indicator values.
+
+And the following may be the response:
+
+[source,console-result]
+----
+{
+  "aggregations" : {
+    "buckets" : {
+      "doc_count_error_upper_bound" : 0,
+      "sum_other_doc_count" : 0,
+      "buckets" : [
+        {
+          "key" : "1.0",
+          "doc_count" : 100,
+          "latency_ranges" : {
+            "buckets" : [
+              {
+                "key" : "*-0.0",
+                "to" : 0.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "0.0-105.0",
+                "from" : 0.0,
+                "to" : 105.0,
+                "doc_count" : 1
+              },
+              {
+                "key" : "105.0-225.0",
+                "from" : 105.0,
+                "to" : 225.0,
+                "doc_count" : 9
+              },
+              {
+                "key" : "225.0-445.0",
+                "from" : 225.0,
+                "to" : 445.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "445.0-665.0",
+                "from" : 445.0,
+                "to" : 665.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "665.0-885.0",
+                "from" : 665.0,
+                "to" : 885.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "885.0-1115.0",
+                "from" : 885.0,
+                "to" : 1115.0,
+                "doc_count" : 10
+              },
+              {
+                "key" : "1115.0-1335.0",
+                "from" : 1115.0,
+                "to" : 1335.0,
+                "doc_count" : 20
+              },
+              {
+                "key" : "1335.0-1555.0",
+                "from" : 1335.0,
+                "to" : 1555.0,
+                "doc_count" : 20
+              },
+              {
+                "key" : "1555.0-1775.0",
+                "from" : 1555.0,
+                "to" : 1775.0,
+                "doc_count" : 20
+              },
+              {
+                "key" : "1775.0-*",
+                "from" : 1775.0,
+                "doc_count" : 20
+              }
+            ]
+          },
+          "bucket_correlation" : {
+            "value" : 0.8402398981360937
+          }
+        },
+        {
+          "key" : "2.0",
+          "doc_count" : 100,
+          "latency_ranges" : {
+            "buckets" : [
+              {
+                "key" : "*-0.0",
+                "to" : 0.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "0.0-105.0",
+                "from" : 0.0,
+                "to" : 105.0,
+                "doc_count" : 19
+              },
+              {
+                "key" : "105.0-225.0",
+                "from" : 105.0,
+                "to" : 225.0,
+                "doc_count" : 11
+              },
+              {
+                "key" : "225.0-445.0",
+                "from" : 225.0,
+                "to" : 445.0,
+                "doc_count" : 20
+              },
+              {
+                "key" : "445.0-665.0",
+                "from" : 445.0,
+                "to" : 665.0,
+                "doc_count" : 20
+              },
+              {
+                "key" : "665.0-885.0",
+                "from" : 665.0,
+                "to" : 885.0,
+                "doc_count" : 20
+              },
+              {
+                "key" : "885.0-1115.0",
+                "from" : 885.0,
+                "to" : 1115.0,
+                "doc_count" : 10
+              },
+              {
+                "key" : "1115.0-1335.0",
+                "from" : 1115.0,
+                "to" : 1335.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "1335.0-1555.0",
+                "from" : 1335.0,
+                "to" : 1555.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "1555.0-1775.0",
+                "from" : 1555.0,
+                "to" : 1775.0,
+                "doc_count" : 0
+              },
+              {
+                "key" : "1775.0-*",
+                "from" : 1775.0,
+                "doc_count" : 0
+              }
+            ]
+          },
+          "bucket_correlation" : {
+            "value" : -0.5759855613334943
+          }
+        }
+      ]
+    }
+  }
+}
+----

+ 1 - 1
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java

@@ -53,7 +53,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
-import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
+import org.elasticsearch.xpack.ml.aggs.inference.InferencePipelineAggregationBuilder;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
 import org.junit.Before;

+ 0 - 13
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/AnomalyJobCRUDIT.java

@@ -6,7 +6,6 @@
  */
 package org.elasticsearch.xpack.ml.integration;
 
-import static java.util.Collections.emptyList;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.hamcrest.Matchers.containsString;
 
@@ -27,14 +26,10 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.MasterService;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentType;
-import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.action.PutJobAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateJobAction;
-import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
-import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -181,12 +176,4 @@ public class AnomalyJobCRUDIT extends MlSingleNodeTestCase {
         return new AnalysisConfig.Builder(Collections.singletonList(detector.build()));
     }
 
-    @Override
-    public NamedXContentRegistry xContentRegistry() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
-        namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents());
-        return new NamedXContentRegistry(namedXContent);
-    }
 }

+ 216 - 0
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/BucketCorrelationAggregationIT.java

@@ -0,0 +1,216 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.action.DocWriteRequest;
+import org.elasticsearch.action.bulk.BulkItemResponse;
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
+import org.elasticsearch.action.bulk.BulkResponse;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.bucket.range.RangeAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.search.aggregations.metrics.Percentiles;
+import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;
+import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
+import org.elasticsearch.xpack.ml.aggs.correlation.BucketCorrelationAggregationBuilder;
+import org.elasticsearch.xpack.ml.aggs.correlation.CountCorrelationIndicator;
+import org.elasticsearch.xpack.ml.aggs.correlation.CountCorrelationFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.closeTo;
+
+public class BucketCorrelationAggregationIT extends MlSingleNodeTestCase {
+
+    public void testCountCorrelation() {
+
+        double[] xs = new double[10000];
+        int[] isCat = new int[10000];
+        int[] isDog = new int[10000];
+
+        client().admin().indices().prepareCreate("data")
+            .setMapping("metric", "type=double", "term", "type=keyword")
+            .get();
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk("data");
+        for (int i = 0; i < 5000; i++) {
+            IndexRequest indexRequest = new IndexRequest("data");
+            double x = randomDoubleBetween(100.0, 1000.0, true);
+            xs[i] = x;
+            isCat[i] = 1;
+            isDog[i] = 0;
+            indexRequest.source("metric", x, "term", "cat").opType(DocWriteRequest.OpType.CREATE);
+            bulkRequestBuilder.add(indexRequest);
+        }
+        sendAndMaybeFail(bulkRequestBuilder);
+        bulkRequestBuilder = client().prepareBulk("data");
+
+        for (int i = 5000; i < 10000; i++) {
+            IndexRequest indexRequest = new IndexRequest("data");
+            double x = randomDoubleBetween(0.0, 100.0, true);
+            xs[i] = x;
+            isCat[i] = 0;
+            isDog[i] = 1;
+            indexRequest.source("metric", x, "term", "dog").opType(DocWriteRequest.OpType.CREATE);
+            bulkRequestBuilder.add(indexRequest);
+        }
+        sendAndMaybeFail(bulkRequestBuilder);
+
+        double catCorrelation = pearsonCorrelation(xs, isCat);
+        double dogCorrelation = pearsonCorrelation(xs, isDog);
+
+        AtomicLong counter = new AtomicLong();
+        double[] steps = Stream.generate(() -> counter.getAndAdd(2L)).limit(50).mapToDouble(l -> (double)l).toArray();
+        SearchResponse percentilesSearch = client().prepareSearch("data")
+            .addAggregation(
+                AggregationBuilders
+                    .percentiles("percentiles")
+                    .field("metric")
+                    .percentiles(steps)
+            )
+            .setSize(0)
+            .setTrackTotalHits(true)
+            .get();
+        long totalHits = percentilesSearch.getHits().getTotalHits().value;
+        Percentiles percentiles = percentilesSearch.getAggregations().get("percentiles");
+        Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> aggs = buildRangeAggAndSetExpectations(
+            percentiles,
+            steps,
+            totalHits,
+            "metric"
+        );
+
+        SearchResponse countCorrelations = client()
+            .prepareSearch("data")
+            .setSize(0)
+            .setTrackTotalHits(false)
+            .addAggregation(AggregationBuilders
+                .terms("buckets")
+                .field("term")
+                .subAggregation(aggs.v1())
+                .subAggregation(aggs.v2())
+            )
+            .get();
+
+        Terms terms = countCorrelations.getAggregations().get("buckets");
+        Terms.Bucket catBucket = terms.getBucketByKey("cat");
+        Terms.Bucket dogBucket = terms.getBucketByKey("dog");
+        NumericMetricsAggregation.SingleValue approxCatCorrelation = catBucket.getAggregations().get("correlates");
+        NumericMetricsAggregation.SingleValue approxDogCorrelation = dogBucket.getAggregations().get("correlates");
+
+        assertThat(approxCatCorrelation.value(), closeTo(catCorrelation, 0.1));
+        assertThat(approxDogCorrelation.value(), closeTo(dogCorrelation, 0.1));
+    }
+
+    private static Tuple<RangeAggregationBuilder, BucketCorrelationAggregationBuilder> buildRangeAggAndSetExpectations(
+        Percentiles raw_percentiles,
+        double[] steps,
+        long totalCount,
+        String indicatorFieldName
+    ) {
+        List<Double> percentiles = new ArrayList<>();
+        List<Double> fractions = new ArrayList<>();
+        RangeAggregationBuilder builder = AggregationBuilders.range("correlation_range").field(indicatorFieldName);
+        double percentile_0 = raw_percentiles.percentile(steps[0]);
+        builder.addUnboundedTo(percentile_0);
+        fractions.add(0.02);
+        percentiles.add(percentile_0);
+        int last_added = 0;
+        for (int i = 1; i < steps.length; i++) {
+            double percentile_l = raw_percentiles.percentile(steps[i - 1]);
+            double percentile_r = raw_percentiles.percentile(steps[i]);
+            if (Double.compare(percentile_l, percentile_r) == 0) {
+                fractions.set(last_added, fractions.get(last_added) + 0.02);
+            } else {
+                last_added = i;
+                fractions.add(0.02);
+                percentiles.add(percentile_r);
+            }
+        }
+        fractions.add(2.0/100);
+        double[] expectations = new double[percentiles.size() + 1];
+        expectations[0] = percentile_0;
+        for (int i = 1; i < percentiles.size(); i++) {
+            double percentile_l = percentiles.get(i - 1);
+            double percentile_r = percentiles.get(i);
+            double fractions_l = fractions.get(i - 1);
+            double fractions_r = fractions.get(i);
+            builder.addRange(percentile_l, percentile_r);
+            expectations[i] = (fractions_l * percentile_l + fractions_r * percentile_r) / (fractions_l + fractions_r);
+        }
+        double percentile_n = percentiles.get(percentiles.size() - 1);
+        builder.addUnboundedFrom(percentile_n);
+        expectations[percentiles.size()] = percentile_n;
+        return Tuple.tuple(
+            builder,
+            new BucketCorrelationAggregationBuilder(
+                "correlates",
+                "correlation_range>_count",
+                new CountCorrelationFunction(
+                    new CountCorrelationIndicator(expectations, fractions.stream().mapToDouble(Double::doubleValue).toArray(), totalCount)
+                )
+            )
+        );
+    }
+
+    private double pearsonCorrelation(double[] xs, int[] ys) {
+        double meanX = MovingFunctions.unweightedAvg(xs);
+        double meanY = sum(ys)/(double)ys.length;
+        double varX = Math.pow(MovingFunctions.stdDev(xs, meanX), 2.0);
+        double varY = 0.0;
+        for (int y : ys) {
+            varY += Math.pow(y - meanY, 2);
+        }
+        varY /= ys.length;
+
+        if (varY == 0 || varX == 0 || Double.isNaN(varX) || Double.isNaN(varY)) {
+            fail("failed to calculate true correlation due to 0 variance in the data");
+        }
+
+        double corXY = 0.0;
+        for (int i = 0; i < xs.length; i++) {
+            corXY += (((xs[i] - meanX)*(ys[i] - meanY))/Math.sqrt(varX*varY));
+        }
+        return corXY/xs.length;
+    }
+
+    private static int sum(int[] xs) {
+        int s = 0;
+        for (int x : xs) {
+            s += x;
+        }
+        return s;
+    }
+
+    private void sendAndMaybeFail(BulkRequestBuilder bulkRequestBuilder) {
+        BulkResponse bulkResponse = bulkRequestBuilder
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .get();
+        if (bulkResponse.hasFailures()) {
+            int failures = 0;
+            for (BulkItemResponse itemResponse : bulkResponse) {
+                if (itemResponse.isFailed()) {
+                    failures++;
+                }
+            }
+            logger.error("Item response failure [{}]", bulkResponse.buildFailureMessage());
+            fail("Bulk response contained " + failures + " failures");
+        }
+    }
+
+
+
+
+}

+ 0 - 11
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java

@@ -11,7 +11,6 @@ import org.elasticsearch.Version;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.collect.Tuple;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.license.License;
 import org.elasticsearch.xpack.core.action.util.PageParams;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -19,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
-import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
@@ -37,7 +35,6 @@ import org.elasticsearch.xpack.ml.extractor.DocValueField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
-import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
 import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
 import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@@ -159,12 +156,4 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
         return subStrings;
     }
 
-    @Override
-    public NamedXContentRegistry xContentRegistry() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
-        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
-        return new NamedXContentRegistry(namedXContent);
-    }
-
 }

+ 0 - 16
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsCRUDIT.java

@@ -6,7 +6,6 @@
  */
 package org.elasticsearch.xpack.ml.integration;
 
-import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.Matchers.equalTo;
@@ -14,25 +13,18 @@ import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.core.IsInstanceOf.instanceOf;
 
-import java.util.ArrayList;
-import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
 
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.client.OriginSettingClient;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
-import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests;
-import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -117,12 +109,4 @@ public class DataFrameAnalyticsCRUDIT extends MlSingleNodeTestCase {
             .value, equalTo(0L));
     }
 
-    @Override
-    public NamedXContentRegistry xContentRegistry() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
-        namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents());
-        return new NamedXContentRegistry(namedXContent);
-    }
 }

+ 0 - 16
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java

@@ -14,12 +14,9 @@ import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -27,20 +24,15 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
-import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
-import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 import org.junit.Before;
 
-import java.util.ArrayList;
 import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
 
-import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.Matchers.equalTo;
@@ -367,12 +359,4 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
             .build();
     }
 
-    @Override
-    public NamedXContentRegistry xContentRegistry() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
-        namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents());
-        return new NamedXContentRegistry(namedXContent);
-    }
 }

+ 0 - 8
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlAutoUpdateServiceIT.java

@@ -15,12 +15,9 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.indices.TestIndexNameExpressionResolver;
-import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.MlConfigIndex;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.ml.MlAutoUpdateService;
@@ -120,9 +117,4 @@ public class MlAutoUpdateServiceIT extends MlSingleNodeTestCase {
         });
     }
 
-    @Override
-    public NamedXContentRegistry xContentRegistry() {
-        return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
-    }
-
 }

+ 0 - 13
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java

@@ -8,11 +8,7 @@ package org.elasticsearch.xpack.ml.integration;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.license.License;
-import org.elasticsearch.search.SearchModule;
-import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
@@ -413,13 +409,4 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
             .build();
     }
 
-
-    @Override
-    public NamedXContentRegistry xContentRegistry() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
-        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
-        return new NamedXContentRegistry(namedXContent);
-    }
-
 }

+ 0 - 14
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java

@@ -11,15 +11,11 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
 import org.elasticsearch.action.delete.DeleteRequest;
 import org.elasticsearch.action.index.IndexResponse;
 import org.elasticsearch.action.support.WriteRequest;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.license.License;
-import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
-import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
@@ -32,7 +28,6 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDo
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.junit.Before;
 
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicReference;
@@ -316,13 +311,4 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         return buildTrainedModelConfigBuilder(modelId).build();
     }
 
-    @Override
-    public NamedXContentRegistry xContentRegistry() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
-        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
-        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
-        return new NamedXContentRegistry(namedXContent);
-
-    }
-
 }

+ 9 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -42,7 +42,6 @@ import org.elasticsearch.common.settings.SettingsModule;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
-import org.elasticsearch.common.xcontent.ContextParser;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
@@ -69,7 +68,6 @@ import org.elasticsearch.repositories.RepositoriesService;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.script.ScriptService;
-import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.threadpool.ExecutorBuilder;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -237,6 +235,8 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction;
 import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction;
 import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction;
 import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction;
+import org.elasticsearch.xpack.ml.aggs.correlation.BucketCorrelationAggregationBuilder;
+import org.elasticsearch.xpack.ml.aggs.correlation.CorrelationNamedContentProvider;
 import org.elasticsearch.xpack.ml.annotations.AnnotationPersister;
 import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingDeciderService;
 import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingNamedWritableProvider;
@@ -256,8 +256,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
-import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
-import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation;
+import org.elasticsearch.xpack.ml.aggs.inference.InferencePipelineAggregationBuilder;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
@@ -1086,13 +1085,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
 
     @Override
     public List<PipelineAggregationSpec> getPipelineAggregations() {
-        PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
-            in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService),
-            (ContextParser<String, ? extends PipelineAggregationBuilder>)
-                (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser));
-        spec.addResultReader(InternalInferenceAggregation::new);
-
-        return Collections.singletonList(spec);
+        return Arrays.asList(
+            InferencePipelineAggregationBuilder.buildSpec(modelLoadingService, getLicenseState()),
+            BucketCorrelationAggregationBuilder.buildSpec()
+        );
     }
 
     @Override
@@ -1149,6 +1145,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                 ModelAliasMetadata::fromXContent
             )
         );
+        namedXContent.addAll(new CorrelationNamedContentProvider().getNamedXContentParsers());
         return namedXContent;
     }
 
@@ -1186,6 +1183,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         namedWriteables.addAll(MlEvaluationNamedXContentProvider.getNamedWriteables());
         namedWriteables.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(MlAutoscalingNamedWritableProvider.getNamedWriteables());
+        namedWriteables.addAll(new CorrelationNamedContentProvider().getNamedWriteables());
         return namedWriteables;
     }
 

+ 139 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilder.java

@@ -0,0 +1,139 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+public class BucketCorrelationAggregationBuilder extends AbstractPipelineAggregationBuilder<BucketCorrelationAggregationBuilder> {
+
+    public static final ParseField NAME = new ParseField("bucket_correlation");
+    private static final ParseField FUNCTION = new ParseField("function");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<BucketCorrelationAggregationBuilder, String> PARSER = new ConstructingObjectParser<>(
+        NAME.getPreferredName(),
+        false,
+        (args, context) -> new BucketCorrelationAggregationBuilder(
+            context,
+            (String)args[0],
+            (CorrelationFunction)args[1]
+        )
+    );
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), BUCKETS_PATH_FIELD);
+        PARSER.declareNamedObject(
+            ConstructingObjectParser.constructorArg(),
+            (p, c, n) -> p.namedObject(CorrelationFunction.class, n, null),
+            FUNCTION
+        );
+    }
+
+    public static SearchPlugin.PipelineAggregationSpec buildSpec() {
+        return new SearchPlugin.PipelineAggregationSpec(
+            NAME,
+            BucketCorrelationAggregationBuilder::new,
+            BucketCorrelationAggregationBuilder.PARSER
+        );
+    }
+
+    private final CorrelationFunction correlationFunction;
+
+    public BucketCorrelationAggregationBuilder(
+        String name,
+        String bucketsPath,
+        CorrelationFunction correlationFunction
+    ) {
+        super(
+            name,
+            NAME.getPreferredName(),
+            new String[] {bucketsPath}
+        );
+        this.correlationFunction = correlationFunction;
+    }
+
+    public BucketCorrelationAggregationBuilder(StreamInput in) throws IOException {
+        super(in, NAME.getPreferredName());
+        this.correlationFunction = in.readNamedWriteable(CorrelationFunction.class);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeNamedWriteable(correlationFunction);
+    }
+
+    @Override
+    protected PipelineAggregator createInternal(Map<String, Object> metadata) {
+        return new BucketCorrelationAggregator(name, correlationFunction, bucketsPaths[0], metadata);
+    }
+
+    @Override
+    protected boolean overrideBucketsPath() {
+        return true;
+    }
+
+    @Override
+    protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketsPaths[0]);
+        NamedXContentObjectHelper.writeNamedObject(builder, params, FUNCTION.getPreferredName(), correlationFunction);
+        return builder;
+    }
+
+    @Override
+    protected void validate(ValidationContext context) {
+
+        final String firstAgg = bucketsPaths[0].split("[>\\.]")[0];
+        Optional<AggregationBuilder> aggBuilder = context.getSiblingAggregations().stream()
+            .filter(builder -> builder.getName().equals(firstAgg))
+            .findAny();
+        if (aggBuilder.isEmpty()) {
+            context.addBucketPathValidationError("aggregation does not exist for aggregation [" + name + "]: " + bucketsPaths[0]);
+            return;
+        }
+        AggregationBuilder aggregationBuilder = aggBuilder.get();
+        if (aggregationBuilder.bucketCardinality() != AggregationBuilder.BucketCardinality.MANY) {
+            context.addValidationError("The first aggregation in " + PipelineAggregator.Parser.BUCKETS_PATH.getPreferredName()
+                + " must be a multi-bucket aggregation for aggregation [" + name + "] found :"
+                + aggBuilder.get().getClass().getName() + " for buckets path: " + bucketsPaths[0]);
+            return;
+        }
+        correlationFunction.validate(context, bucketsPaths[0]);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
+        BucketCorrelationAggregationBuilder that = (BucketCorrelationAggregationBuilder) o;
+        return Objects.equals(correlationFunction, that.correlationFunction);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), correlationFunction);
+    }
+}

+ 77 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregator.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.search.DocValueFormat;
+import org.elasticsearch.search.aggregations.Aggregation;
+import org.elasticsearch.search.aggregations.AggregationExecutionException;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
+import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
+import org.elasticsearch.search.aggregations.pipeline.InternalSimpleValue;
+import org.elasticsearch.search.aggregations.pipeline.SiblingPipelineAggregator;
+import org.elasticsearch.search.aggregations.support.AggregationPath;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+public class BucketCorrelationAggregator extends SiblingPipelineAggregator {
+
+    private final CorrelationFunction correlationFunction;
+
+    public BucketCorrelationAggregator(String name,
+                                       CorrelationFunction correlationFunction,
+                                       String bucketsPath,
+                                       Map<String, Object> metadata) {
+        super(name, new String[]{ bucketsPath }, metadata);
+        this.correlationFunction = correlationFunction;
+    }
+
+    @Override
+    public InternalAggregation doReduce(Aggregations aggregations, InternalAggregation.ReduceContext context) {
+        CountCorrelationIndicator bucketPathValue = null;
+        List<String> parsedPath = AggregationPath.parse(bucketsPaths()[0]).getPathElementsAsStringList();
+        for (Aggregation aggregation : aggregations) {
+            if (aggregation.getName().equals(parsedPath.get(0))) {
+                List<String> sublistedPath = parsedPath.subList(1, parsedPath.size());
+                InternalMultiBucketAggregation<?, ?> multiBucketsAgg = (InternalMultiBucketAggregation<?, ?>) aggregation;
+                List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = multiBucketsAgg.getBuckets();
+                List<Double> values = new ArrayList<>(buckets.size());
+                long docCount = 0;
+                for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
+                    Double bucketValue = BucketHelpers.resolveBucketValue(
+                        multiBucketsAgg,
+                        bucket,
+                        sublistedPath,
+                        BucketHelpers.GapPolicy.INSERT_ZEROS
+                    );
+                    if (bucketValue != null && Double.isNaN(bucketValue) == false) {
+                        values.add(bucketValue);
+                    }
+                    docCount += bucket.getDocCount();
+                }
+                bucketPathValue = new CountCorrelationIndicator(
+                    values.stream().mapToDouble(Double::doubleValue).toArray(),
+                    null,
+                    docCount
+                );
+                break;
+            }
+        }
+        if (bucketPathValue == null) {
+            throw new AggregationExecutionException(
+                "unable to find valid bucket values in path [" + bucketsPaths()[0] + "] for agg [" + name() + "]"
+            );
+        }
+
+        return new InternalSimpleValue(name(), correlationFunction.execute(bucketPathValue), DocValueFormat.RAW, metadata());
+    }
+
+}

+ 21 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CorrelationFunction.java

@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+
+public interface CorrelationFunction extends NamedWriteable, NamedXContentObject {
+
+    double execute(CountCorrelationIndicator y);
+
+    void validate(PipelineAggregationBuilder.ValidationContext context, String bucketPath);
+
+}

+ 39 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CorrelationNamedContentProvider.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.plugins.spi.NamedXContentProvider;
+
+import java.util.Arrays;
+import java.util.List;
+
+public final class CorrelationNamedContentProvider implements NamedXContentProvider {
+
+    @Override
+    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
+        return Arrays.asList(
+            new NamedXContentRegistry.Entry(
+                CorrelationFunction.class,
+                CountCorrelationFunction.NAME,
+                CountCorrelationFunction::fromXContent
+            )
+        );
+    }
+
+    public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
+        return Arrays.asList(
+            new NamedWriteableRegistry.Entry(
+                CorrelationFunction.class,
+                CountCorrelationFunction.NAME.getPreferredName(),
+                CountCorrelationFunction::new
+            )
+        );
+    }
+}

+ 179 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationFunction.java

@@ -0,0 +1,179 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.AggregationExecutionException;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class CountCorrelationFunction implements CorrelationFunction {
+
+    public static final ParseField NAME = new ParseField("count_correlation");
+    public static final ParseField INDICATOR = new ParseField("indicator");
+
+    private static final ConstructingObjectParser<CountCorrelationFunction, Void> PARSER = new ConstructingObjectParser<>(
+        "count_correlation_function",
+        false,
+        a -> new CountCorrelationFunction((CountCorrelationIndicator)a[0])
+    );
+
+    static {
+        PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> CountCorrelationIndicator.fromXContent(p), INDICATOR);
+    }
+
+    private final CountCorrelationIndicator indicator;
+
+    public CountCorrelationFunction(CountCorrelationIndicator indicator) {
+        this.indicator = indicator;
+    }
+
+    public CountCorrelationFunction(StreamInput in) throws IOException {
+        this.indicator = new CountCorrelationIndicator(in);
+    }
+
+    public static CountCorrelationFunction fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(INDICATOR.getPreferredName(), indicator);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        indicator.writeTo(out);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public int hashCode() {
+        return NAME.getPreferredName().hashCode();
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) return true;
+        if (obj == null || getClass() != obj.getClass()) return false;
+        CountCorrelationFunction other = (CountCorrelationFunction) obj;
+        return Objects.equals(indicator, other.indicator);
+    }
+
+    /**
+     * This does an approximate Pearson's correlation with the stored indicator with the passed value `y`.
+     *
+     * This approximation makes many assumptions about the data distribution:
+     *
+     *  - That both the stored `indicator` and `y` are from the same distribution
+     *  - That `y` is effectively a queried subset of the `indicator`
+     *  - That the document count of `y` is always less than or equal to the document count of the `indicator`
+     * @param y the value with which to calculate correlation
+     * @return The correlation
+     */
+    @Override
+    public double execute(CountCorrelationIndicator y) {
+        if (indicator.getExpectations().length != y.getExpectations().length) {
+            throw new AggregationExecutionException(
+                "value lengths do not match; indicator.expectations ["
+                    + indicator.getExpectations().length
+                    + "] and number of buckets ["
+                    + y.getExpectations().length
+                    + "]. Unable to calculate correlation"
+            );
+        }
+        final double xMean;
+        final double xVar;
+        if (indicator.getFractions() == null) {
+            xMean = MovingFunctions.unweightedAvg(indicator.getExpectations());
+            if (Double.isNaN(xMean)) {
+                return Double.NaN;
+            }
+            double stdDev = MovingFunctions.stdDev(indicator.getExpectations(), xMean);
+            if (Double.isNaN(stdDev)) {
+                return Double.NaN;
+            }
+            xVar = Math.pow(stdDev, 2.0);
+        } else {
+            double mean = 0;
+            for (int i = 0; i < indicator.getExpectations().length; i++) {
+                mean += indicator.getExpectations()[i] * indicator.getFractions()[i];
+            }
+            if (Double.isNaN(mean)) {
+                return Double.NaN;
+            }
+            xMean = mean;
+            double var = 0;
+            for (int i = 0; i < indicator.getExpectations().length; i++) {
+                var += Math.pow(indicator.getExpectations()[i] - xMean, 2) * indicator.getFractions()[i];
+            }
+            xVar = var;
+        }
+        final double weight = MovingFunctions.sum(y.getExpectations())/indicator.getDocCount();
+        if (weight > 1.0) {
+            throw new AggregationExecutionException(
+                "doc_count of indicator must be larger than the total count of the correlating values indicator count ["
+                    + indicator.getDocCount()
+                    + "] correlating value total count ["
+                    + MovingFunctions.sum(y.getExpectations())
+                    + "]"
+            );
+        }
+        final double yMean = weight;
+        final double yVar = (1 - weight) * yMean * yMean + weight * (1 - yMean) * (1 - yMean);
+        double xyCov = 0;
+        if (indicator.getFractions() == null) {
+            final double fraction = 1.0 / indicator.getExpectations().length;
+            for (int i = 0; i < indicator.getExpectations().length; i++) {
+                final double xVal = indicator.getExpectations()[i];
+                final double nX = y.getExpectations()[i];
+                xyCov = xyCov
+                    - (indicator.getDocCount() * fraction - nX) * (xVal - xMean) * yMean
+                    + nX * (xVal - xMean) * (1 - yMean);
+            }
+        } else {
+            for (int i = 0; i < indicator.getExpectations().length; i++) {
+                final double fraction = indicator.getFractions()[i];
+                final double xVal = indicator.getExpectations()[i];
+                final double nX = y.getExpectations()[i];
+                xyCov = xyCov
+                    - (indicator.getDocCount() * fraction - nX) * (xVal - xMean) * yMean
+                    + nX * (xVal - xMean) * (1 - yMean);
+            }
+        }
+        xyCov /= indicator.getDocCount();
+        return (xVar * yVar == 0) ? Double.NaN : xyCov / Math.sqrt(xVar * yVar);
+    }
+
+    @Override
+    public void validate(PipelineAggregationBuilder.ValidationContext context, String bucketPath) {
+        if (bucketPath.endsWith("_count") == false) {
+            context.addBucketPathValidationError("count correlation requires that bucket_path points to bucket [_count]");
+        }
+    }
+}

+ 147 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationIndicator.java

@@ -0,0 +1,147 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * This contains values necessary for calculating the specific count correlation function.
+ */
+public class CountCorrelationIndicator implements Writeable, ToXContentObject {
+
+    private static final ParseField EXPECTATIONS = new ParseField("expectations");
+    private static final ParseField FRACTIONS = new ParseField("fractions");
+    private static final ParseField DOC_COUNT = new ParseField("doc_count");
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<CountCorrelationIndicator, Void> PARSER =
+        new ConstructingObjectParser<>(
+            "correlative_value",
+            a -> new CountCorrelationIndicator((List<Double>) a[0], (List<Double>) a[2], (Long) a[1])
+        );
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), EXPECTATIONS);
+        PARSER.declareLong(ConstructingObjectParser.constructorArg(), DOC_COUNT);
+        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), FRACTIONS);
+    }
+
+    private final double[] expectations;
+    private final double[] fractions;
+    private final long docCount;
+    private CountCorrelationIndicator(List<Double> values, List<Double> fractions, long docCount) {
+        this(
+            values.stream().mapToDouble(Double::doubleValue).toArray(),
+            fractions == null ? null : fractions.stream().mapToDouble(Double::doubleValue).toArray(),
+            docCount
+        );
+    }
+
+    public CountCorrelationIndicator(double[] values, double[] fractions, long docCount) {
+        Objects.requireNonNull(values);
+        if (fractions != null) {
+            if (values.length != fractions.length) {
+                throw new IllegalArgumentException("[expectations] and [fractions] must have the same length");
+            }
+        }
+        if (docCount <= 0) {
+            throw new IllegalArgumentException("[doc_count] must be a positive value");
+        }
+        if (values.length < 2) {
+            throw new IllegalArgumentException("[expectations] must have a length of at least 2");
+        }
+        this.expectations = values;
+        this.fractions = fractions;
+        this.docCount = docCount;
+    }
+
+    public CountCorrelationIndicator(StreamInput in) throws IOException {
+        this.expectations = in.readDoubleArray();
+        this.fractions = in.readBoolean() ? in.readDoubleArray() : null;
+        this.docCount = in.readVLong();
+    }
+
+    public static CountCorrelationIndicator fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * @return The expectations with which to correlate
+     */
+    public double[] getExpectations() {
+        return expectations;
+    }
+
+    /**
+     * @return The fractions related to each specific expectation.
+     *         Useful for when there are gaps in the data and one expectation should be weighted higher than others
+     */
+    public double[] getFractions() {
+        return fractions;
+    }
+
+    /**
+     * @return The total doc_count contained in this indicator. Usually simply a sum of the expectations
+     */
+    public long getDocCount() {
+        return docCount;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        CountCorrelationIndicator that =
+            (CountCorrelationIndicator) o;
+        return docCount == that.docCount && Arrays.equals(expectations, that.expectations) && Arrays.equals(fractions, that.fractions);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(docCount, Arrays.hashCode(expectations), Arrays.hashCode(fractions));
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(EXPECTATIONS.getPreferredName(), expectations);
+        if (fractions != null) {
+            builder.field(FRACTIONS.getPreferredName(), fractions);
+        }
+        builder.field(DOC_COUNT.getPreferredName(), docCount);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeDoubleArray(expectations);
+        out.writeBoolean(fractions != null);
+        if (fractions != null) {
+            out.writeDoubleArray(fractions);
+        }
+        out.writeVLong(docCount);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this, true, true);
+    }
+}

+ 15 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.aggs;
+package org.elasticsearch.xpack.ml.aggs.inference;
 
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
@@ -16,11 +16,14 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ContextParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.xpack.core.XPackField;
@@ -71,6 +74,17 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
             (p, c, n) -> p.namedObject(InferenceConfigUpdate.class, n, c), INFERENCE_CONFIG);
     }
 
+    public static SearchPlugin.PipelineAggregationSpec buildSpec(SetOnce<ModelLoadingService> modelLoadingService,
+                                                                 XPackLicenseState xPackLicenseState) {
+        SearchPlugin.PipelineAggregationSpec spec = new SearchPlugin.PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
+            in -> new InferencePipelineAggregationBuilder(in, xPackLicenseState, modelLoadingService),
+            (ContextParser<String, ? extends PipelineAggregationBuilder>)
+                (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, xPackLicenseState, name, parser)
+        );
+        spec.addResultReader(InternalInferenceAggregation::new);
+        return spec;
+    }
+
     private final Map<String, String> bucketPathMap;
     private String modelId;
     private InferenceConfigUpdate inferenceConfig;

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregator.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregator.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.aggs;
+package org.elasticsearch.xpack.ml.aggs.inference;
 
 import org.elasticsearch.search.aggregations.AggregationExecutionException;
 import org.elasticsearch.search.aggregations.InternalAggregation;

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregation.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregation.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.aggs;
+package org.elasticsearch.xpack.ml.aggs.inference;
 
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;

+ 13 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java

@@ -31,11 +31,17 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.XPackSettings;
 import org.elasticsearch.xpack.core.ilm.LifecycleSettings;
 import org.elasticsearch.xpack.core.ml.MachineLearningField;
+import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.datastreams.DataStreamsPlugin;
 import org.elasticsearch.xpack.ilm.IndexLifecycle;
+import org.elasticsearch.xpack.ml.aggs.correlation.CorrelationNamedContentProvider;
+import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
 
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
@@ -72,8 +78,13 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
 
     @Override
     protected NamedXContentRegistry xContentRegistry() {
-        SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList());
-        return new NamedXContentRegistry(searchModule.getNamedXContents());
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new CorrelationNamedContentProvider().getNamedXContentParsers());
+        return new NamedXContentRegistry(namedXContent);
     }
 
     @Override

+ 97 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilderTests.java

@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
+import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+import org.elasticsearch.search.aggregations.support.ValueType;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class BucketCorrelationAggregationBuilderTests extends BasePipelineAggregationTestCase<BucketCorrelationAggregationBuilder> {
+
+    private static final String NAME = "correlation-agg";
+
+    @Override
+    protected List<SearchPlugin> plugins() {
+        return Collections.singletonList(new MachineLearning(Settings.EMPTY, null));
+    }
+
+    @Override
+    protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
+        return new CorrelationNamedContentProvider().getNamedXContentParsers();
+    }
+
+    @Override
+    protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
+        return new CorrelationNamedContentProvider().getNamedWriteables();
+    }
+
+    @Override
+    protected BucketCorrelationAggregationBuilder createTestAggregatorFactory() {
+        List<String> bucketPaths = Stream.generate(() -> randomAlphaOfLength(8))
+            .limit(2)
+            .collect(Collectors.toList());
+
+        CorrelationFunction function = new CountCorrelationFunction(CountCorrelationIndicatorTests.randomInstance());
+        return new BucketCorrelationAggregationBuilder(
+            NAME,
+            randomAlphaOfLength(8),
+            function
+        );
+    }
+
+    public void testValidate() {
+        AggregationBuilder singleBucketAgg = new GlobalAggregationBuilder("global");
+        AggregationBuilder multiBucketAgg = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.STRING);
+        final Set<AggregationBuilder> aggBuilders = new HashSet<>();
+        aggBuilders.add(singleBucketAgg);
+        aggBuilders.add(multiBucketAgg);
+
+        // First try to point to a non-existent agg
+        assertThat(
+            validate(
+                aggBuilders,
+                new BucketCorrelationAggregationBuilder(
+                    NAME,
+                    "missing>metric",
+                    new CountCorrelationFunction(CountCorrelationIndicatorTests.randomInstance())
+                )
+            ),
+            containsString("aggregation does not exist for aggregation")
+        );
+
+        // Now validate with a single bucket agg
+        assertThat(
+            validate(
+                aggBuilders,
+                new BucketCorrelationAggregationBuilder(
+                    NAME,
+                    "global>metric",
+                    new CountCorrelationFunction(CountCorrelationIndicatorTests.randomInstance())
+                )
+            ),
+            containsString("must be a multi-bucket aggregation for aggregation")
+        );
+    }
+
+}

+ 71 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationFunctionTests.java

@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+import org.elasticsearch.search.aggregations.support.ValueType;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+
+public class CountCorrelationFunctionTests extends ESTestCase {
+
+    public void testExecute() {
+        AtomicLong xs = new AtomicLong(1);
+        CountCorrelationIndicator x = new CountCorrelationIndicator(
+            Stream.generate(xs::incrementAndGet)
+                .limit(100)
+                .mapToDouble(l -> (double)l).toArray(),
+            null,
+            1000
+        );
+        CountCorrelationFunction countCorrelationFunction = new CountCorrelationFunction(x);
+        AtomicLong ys = new AtomicLong(0);
+        CountCorrelationIndicator yValues = new CountCorrelationIndicator(
+            Stream.generate(() -> Math.min(ys.incrementAndGet(), 10)).limit(100).mapToDouble(l -> (double)l).toArray(),
+            x.getFractions(),
+           1000
+        );
+        double value = countCorrelationFunction.execute(yValues);
+        assertThat(value, greaterThan(0.0));
+
+        AtomicLong otherYs = new AtomicLong(0);
+        CountCorrelationIndicator lesserYValues = new CountCorrelationIndicator(
+            Stream.generate(() -> Math.min(otherYs.incrementAndGet(), 5)).limit(100).mapToDouble(l -> (double)l).toArray(),
+            x.getFractions(),
+            1000
+        );
+        assertThat(countCorrelationFunction.execute(lesserYValues), allOf(lessThan(value), greaterThan(0.0)));
+    }
+
+    public void testValidation() {
+        AggregationBuilder multiBucketAgg = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.STRING);
+        final Set<AggregationBuilder> aggBuilders = new HashSet<>();
+        aggBuilders.add(multiBucketAgg);
+        CountCorrelationFunction function = new CountCorrelationFunction(CountCorrelationIndicatorTests.randomInstance());
+        PipelineAggregationBuilder.ValidationContext validationContext =
+            PipelineAggregationBuilder.ValidationContext.forTreeRoot(aggBuilders, Collections.emptyList(), null);
+        function.validate(validationContext, "terms>metric_agg");
+
+        assertThat(
+            validationContext.getValidationException().getMessage(),
+            containsString("count correlation requires that bucket_path points to bucket [_count]")
+        );
+    }
+}

+ 44 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationIndicatorTests.java

@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.aggs.correlation;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.stream.Stream;
+
+public class CountCorrelationIndicatorTests extends AbstractSerializingTestCase<CountCorrelationIndicator> {
+
+    public static CountCorrelationIndicator randomInstance() {
+        double[] expectations = Stream.generate(ESTestCase::randomDouble)
+            .limit(randomIntBetween(5, 100))
+            .mapToDouble(Double::doubleValue).toArray();
+        double[] fractions = Stream.generate(ESTestCase::randomDouble)
+            .limit(expectations.length)
+            .mapToDouble(Double::doubleValue).toArray();
+        return new CountCorrelationIndicator(expectations, randomBoolean() ? null : fractions, randomLongBetween(1, Long.MAX_VALUE - 1));
+    }
+
+    @Override
+    protected CountCorrelationIndicator doParseInstance(XContentParser parser) throws IOException {
+        return CountCorrelationIndicator.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<CountCorrelationIndicator> instanceReader() {
+        return CountCorrelationIndicator::new;
+    }
+
+    @Override
+    protected CountCorrelationIndicator createTestInstance() {
+        return randomInstance();
+    }
+}

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilderTests.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.aggs;
+package org.elasticsearch.xpack.ml.aggs.inference;
 
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUp
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.aggs.inference.InferencePipelineAggregationBuilder;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 
 import java.util.Collections;

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/InternalInferenceAggregationTests.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.aggs;
+package org.elasticsearch.xpack.ml.aggs.inference;
 
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.settings.Settings;

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/inference/ParsedInference.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.aggs;
+package org.elasticsearch.xpack.ml.aggs.inference;
 
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;

+ 299 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/bucket_correlation_agg.yml

@@ -0,0 +1,299 @@
+setup:
+  - skip:
+      features: headers
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      indices.create:
+        index: store
+        body:
+          mappings:
+            properties:
+              product:
+                type: keyword
+              cost:
+                type: integer
+              time:
+                type: date
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+        Content-Type: application/json
+      bulk:
+        index: store
+        refresh: true
+        body: |
+          { "index": {} }
+          { "product": "TV", "cost": 200, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "TV", "cost": 400, "time": 1587501233000}
+          { "index": {} }
+          { "product": "TV", "cost": 600, "time": 1587501233000}
+          { "index": {} }
+          { "product": "VCR", "cost": 150, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "VCR", "cost": 350, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "VCR", "cost": 580, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "Laptop", "cost": 100, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "Laptop", "cost": 300, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "Laptop", "cost": 500, "time": 1587501233000 }
+
+---
+"Test correlation bucket agg simple":
+
+  - do:
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "ranged_cost": {
+                    "range": {
+                      "field": "cost",
+                      "ranges": [
+                        {"from": 200},
+                        {"from": 200, "to": 300},
+                        {"from": 300}
+                      ]
+                    }
+                  },
+                  "bucket_correlation": {
+                    "bucket_correlation": {
+                      "buckets_path": "ranged_cost>_count",
+                      "function": { "count_correlation": {
+                          "indicator": {
+                             "expectations": [3, 4, 2],
+                             "doc_count": 9
+                          }
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+  - is_true: aggregations.good.buckets.0.bucket_correlation.value
+  - is_true: aggregations.good.buckets.1.bucket_correlation.value
+  - is_true: aggregations.good.buckets.2.bucket_correlation.value
+
+---
+"Test correlation with missing buckets_path":
+
+  - do:
+      catch: /Required \[buckets_path\]/
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "ranged_cost": {
+                    "range": {
+                      "field": "cost",
+                      "ranges": [
+                        {"from": 200},
+                        {"from": 200, "to": 300},
+                        {"from": 300}
+                      ]
+                    }
+                  },
+                  "bucket_correlation": {
+                    "bucket_correlation": {
+                      "function": { "count_correlation": {
+                          "indicator": {
+                             "expectations": [3, 4, 2],
+                             "doc_count": 9
+                          }
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+
+---
+"Test correlation with missing function":
+
+  - do:
+      catch: /Required \[function\]/
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "ranged_cost": {
+                    "range": {
+                      "field": "cost",
+                      "ranges": [
+                        {"from": 200},
+                        {"from": 200, "to": 300},
+                        {"from": 300}
+                      ]
+                    }
+                  },
+                  "bucket_correlation": {
+                    "bucket_correlation": {
+                      "buckets_path": "ranged_cost>_count"
+                    }
+                  }
+                }
+              }
+            }
+          }
+
+
+---
+"Test correlation with pointing to missing agg":
+  - do:
+      catch: /No aggregation found for path \[missing>_count\]/
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "ranged_cost": {
+                    "range": {
+                      "field": "cost",
+                      "ranges": [
+                        {"from": 200},
+                        {"from": 200, "to": 300},
+                        {"from": 300}
+                      ]
+                    }
+                  },
+                  "bucket_correlation": {
+                    "bucket_correlation": {
+                      "buckets_path": "missing>_count",
+                      "function": { "count_correlation": {
+                          "indicator": {
+                            "expectations": [3, 4, 2],
+                            "doc_count": 9
+                          }
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+
+---
+"Test correlation with pointing to mismatched lengths":
+  - do:
+      catch: /value lengths do not match; indicator.expectations \[4\] and number of buckets \[3\]. Unable to calculate correlation/
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "ranged_cost": {
+                    "range": {
+                      "field": "cost",
+                      "ranges": [
+                        {"from": 200},
+                        {"from": 200, "to": 300},
+                        {"from": 300}
+                      ]
+                    }
+                  },
+                  "bucket_correlation": {
+                    "bucket_correlation": {
+                      "buckets_path": "ranged_cost>_count",
+                      "function": { "count_correlation": {
+                          "indicator": {
+                            "expectations": [3, 4, 2, 10],
+                            "doc_count": 9
+                          }
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+
+  - do:
+      catch: /value lengths do not match; indicator.expectations \[3\] and number of buckets \[4\]. Unable to calculate correlation/
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "ranged_cost": {
+                    "range": {
+                      "field": "cost",
+                      "ranges": [
+                        {"from": 200},
+                        {"from": 200, "to": 300},
+                        {"from": 300, "to": 400},
+                        {"from": 400}
+                      ]
+                    }
+                  },
+                  "bucket_correlation": {
+                    "bucket_correlation": {
+                      "buckets_path": "ranged_cost>_count",
+                      "function": { "count_correlation": {
+                          "indicator": {
+                            "expectations": [3, 4, 2],
+                            "doc_count": 9
+                          }
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }