Преглед на файлове

Allow evaluation to consist of multiple steps. (#46653)

This is groundwork for introducing classification evaluation which actually needs multistep evaluation.
Przemysław Witek преди 6 години
родител
ревизия
41d82f658e
променени са 21 файла, в които са добавени 335 реда и са изтрити 175 реда
  1. 15 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java
  2. 48 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java
  3. 28 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java
  4. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java
  5. 17 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java
  6. 17 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java
  7. 17 25
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
  8. 6 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java
  9. 22 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java
  10. 22 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java
  11. 23 29
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java
  12. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java
  13. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java
  14. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java
  15. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java
  16. 6 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java
  17. 5 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java
  18. 25 11
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java
  19. 10 10
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java
  20. 8 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java
  21. 58 17
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java

+ 15 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java

@@ -105,28 +105,31 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
             return indices;
         }
 
-        public final void setIndices(List<String> indices) {
+        public final Request setIndices(List<String> indices) {
             ExceptionsHelper.requireNonNull(indices, INDEX);
             if (indices.isEmpty()) {
                 throw ExceptionsHelper.badRequestException("At least one index must be specified");
             }
             this.indices = indices.toArray(new String[indices.size()]);
+            return this;
         }
 
         public QueryBuilder getParsedQuery() {
             return Optional.ofNullable(queryProvider).orElseGet(QueryProvider::defaultQuery).getParsedQuery();
         }
 
-        public final void setQueryProvider(QueryProvider queryProvider) {
+        public final Request setQueryProvider(QueryProvider queryProvider) {
             this.queryProvider = queryProvider;
+            return this;
         }
 
         public Evaluation getEvaluation() {
             return evaluation;
         }
 
-        public final void setEvaluation(Evaluation evaluation) {
+        public final Request setEvaluation(Evaluation evaluation) {
             this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION);
+            return this;
         }
 
         @Override
@@ -203,6 +206,14 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
             this.metrics = Objects.requireNonNull(metrics);
         }
 
+        public String getEvaluationName() {
+            return evaluationName;
+        }
+
+        public List<EvaluationMetricResult> getMetrics() {
+            return metrics;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(evaluationName);
@@ -214,7 +225,7 @@ public class EvaluateDataFrameAction extends ActionType<EvaluateDataFrameAction.
             builder.startObject();
             builder.startObject(evaluationName);
             for (EvaluationMetricResult metric : metrics) {
-                builder.field(metric.getName(), metric);
+                builder.field(metric.getMetricName(), metric);
             }
             builder.endObject();
             builder.endObject();

+ 48 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java

@@ -5,14 +5,17 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 
 import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
 
 /**
  * Defines an evaluation
@@ -24,16 +27,54 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
      */
     String getName();
 
+    /**
+     * Returns the list of metrics to evaluate
+     * @return list of metrics to evaluate
+     */
+    List<? extends EvaluationMetric> getMetrics();
+
     /**
      * Builds the search required to collect data to compute the evaluation result
-     * @param queryBuilder User-provided query that must be respected when collecting data
+     * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
+     */
+    SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);
+
+    /**
+     * Builds the search that verifies existence of required fields and applies user-provided query
+     * @param requiredFields fields that must exist
+     * @param userProvidedQueryBuilder user-provided query
+     */
+    default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
+        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
+        for (String requiredField : requiredFields) {
+            boolQuery.filter(QueryBuilders.existsQuery(requiredField));
+        }
+        boolQuery.filter(userProvidedQueryBuilder);
+        return new SearchSourceBuilder().size(0).query(boolQuery);
+    }
+
+    /**
+     * Processes {@link SearchResponse} from the search action
+     * @param searchResponse response from the search action
+     */
+    void process(SearchResponse searchResponse);
+
+    /**
+     * @return true iff all the metrics have their results computed
      */
-    SearchSourceBuilder buildSearch(QueryBuilder queryBuilder);
+    default boolean hasAllResults() {
+        return getMetrics().stream().map(EvaluationMetric::getResult).allMatch(Optional::isPresent);
+    }
 
     /**
-     * Computes the evaluation result
-     * @param searchResponse The search response required to compute the result
-     * @param listener A listener of the results
+     * Returns the list of evaluation results
+     * @return list of evaluation results
      */
-    void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener);
+    default List<EvaluationMetricResult> getResults() {
+        return getMetrics().stream()
+            .map(EvaluationMetric::getResult)
+            .filter(Optional::isPresent)
+            .map(Optional::get)
+            .collect(Collectors.toList());
+    }
 }

+ 28 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java

@@ -0,0 +1,28 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+
+import java.util.Optional;
+
+/**
+ * {@link EvaluationMetric} class represents a metric to evaluate.
+ */
+public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
+
+    /**
+     * Returns the name of the metric (which may differ to the writeable name)
+     */
+    String getName();
+
+    /**
+     * Gets the evaluation result for this metric.
+     * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
+     */
+    Optional<EvaluationMetricResult> getResult();
+}

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java

@@ -14,7 +14,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
 public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable {
 
     /**
-     * Returns the name of the metric
+     * Returns the name of the metric (which may differ to the writeable name)
      */
-    String getName();
+    String getMetricName();
 }

+ 17 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java

@@ -20,10 +20,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
 
 import java.io.IOException;
 import java.text.MessageFormat;
-import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
+import java.util.Optional;
 
 /**
  * Calculates the mean squared error between two known numerical fields.
@@ -48,28 +48,34 @@ public class MeanSquaredError implements RegressionMetric {
         return PARSER.apply(parser, null);
     }
 
-    public MeanSquaredError(StreamInput in) {
+    private EvaluationMetricResult result;
 
-    }
-
-    public MeanSquaredError() {
+    public MeanSquaredError(StreamInput in) {}
 
-    }
+    public MeanSquaredError() {}
 
     @Override
-    public String getMetricName() {
+    public String getName() {
         return NAME.getPreferredName();
     }
 
     @Override
     public List<AggregationBuilder> aggs(String actualField, String predictedField) {
-        return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
+        if (result != null) {
+            return List.of();
+        }
+        return List.of(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
     }
 
     @Override
-    public EvaluationMetricResult evaluate(Aggregations aggs) {
+    public void process(Aggregations aggs) {
         NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
-        return value == null ? new Result(0.0) : new Result(value.value());
+        result = value == null ? new Result(0.0) : new Result(value.value());
+    }
+
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
     }
 
     @Override
@@ -121,7 +127,7 @@ public class MeanSquaredError implements RegressionMetric {
         }
 
         @Override
-        public String getName() {
+        public String getMetricName() {
             return NAME.getPreferredName();
         }
 

+ 17 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

@@ -22,10 +22,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
 
 import java.io.IOException;
 import java.text.MessageFormat;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
+import java.util.Optional;
 
 /**
  * Calculates R-Squared between two known numerical fields.
@@ -53,36 +53,42 @@ public class RSquared implements RegressionMetric {
         return PARSER.apply(parser, null);
     }
 
-    public RSquared(StreamInput in) {
+    private EvaluationMetricResult result;
 
-    }
-
-    public RSquared() {
+    public RSquared(StreamInput in) {}
 
-    }
+    public RSquared() {}
 
     @Override
-    public String getMetricName() {
+    public String getName() {
         return NAME.getPreferredName();
     }
 
     @Override
     public List<AggregationBuilder> aggs(String actualField, String predictedField) {
-        return Arrays.asList(
+        if (result != null) {
+            return List.of();
+        }
+        return List.of(
             AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
             AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
     }
 
     @Override
-    public EvaluationMetricResult evaluate(Aggregations aggs) {
+    public void process(Aggregations aggs) {
         NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
         ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
         // extendedStats.getVariance() is the statistical sumOfSquares divided by count
-        return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
+        result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
             new Result(0.0) :
             new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
     }
 
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -132,7 +138,7 @@ public class RSquared implements RegressionMetric {
         }
 
         @Override
-        public String getName() {
+        public String getMetricName() {
             return NAME.getPreferredName();
         }
 

+ 17 - 25
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java

@@ -5,7 +5,6 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
 
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
@@ -14,17 +13,15 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
@@ -86,19 +83,16 @@ public class Regression implements Evaluation {
     }
 
     private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
-        List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics;
+        List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
         if (metrics.isEmpty()) {
             throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
         }
-        Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName));
+        Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName));
         return metrics;
     }
 
     private static List<RegressionMetric> defaultMetrics() {
-        List<RegressionMetric> defaultMetrics = new ArrayList<>(2);
-        defaultMetrics.add(new MeanSquaredError());
-        defaultMetrics.add(new RSquared());
-        return defaultMetrics;
+        return Arrays.asList(new MeanSquaredError(), new RSquared());
     }
 
     @Override
@@ -107,12 +101,14 @@ public class Regression implements Evaluation {
     }
 
     @Override
-    public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
-        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
-            .filter(QueryBuilders.existsQuery(actualField))
-            .filter(QueryBuilders.existsQuery(predictedField))
-            .filter(queryBuilder);
-        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
+    public List<RegressionMetric> getMetrics() {
+        return metrics;
+    }
+
+    @Override
+    public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
+        ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
+        SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder);
         for (RegressionMetric metric : metrics) {
             List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
             aggs.forEach(searchSourceBuilder::aggregation);
@@ -121,18 +117,14 @@ public class Regression implements Evaluation {
     }
 
     @Override
-    public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
-        List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
+    public void process(SearchResponse searchResponse) {
+        ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
         if (searchResponse.getHits().getTotalHits().value == 0) {
-            listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
-                actualField,
-                predictedField));
-            return;
+            throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
         }
         for (RegressionMetric metric : metrics) {
-            results.add(metric.evaluate(searchResponse.getAggregations()));
+            metric.process(searchResponse.getAggregations());
         }
-        listener.onResponse(results);
     }
 
     @Override

+ 6 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java

@@ -5,20 +5,14 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
 
-import org.elasticsearch.common.io.stream.NamedWriteable;
-import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.Aggregations;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 
 import java.util.List;
 
-public interface RegressionMetric extends ToXContentObject, NamedWriteable {
-
-    /**
-     * Returns the name of the metric (which may differ to the writeable name)
-     */
-    String getMetricName();
+public interface RegressionMetric extends EvaluationMetric {
 
     /**
      * Builds the aggregation that collect required data to compute the metric
@@ -29,9 +23,8 @@ public interface RegressionMetric extends ToXContentObject, NamedWriteable {
     List<AggregationBuilder> aggs(String actualField, String predictedField);
 
     /**
-     * Calculates the metric result
-     * @param aggs the aggregations
-     * @return the metric result
+     * Processes given aggregations as a step towards computing result
+     * @param aggs aggregations from {@link SearchResponse}
      */
-    EvaluationMetricResult evaluate(Aggregations aggs);
+    void process(Aggregations aggs);
 }

+ 22 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java

@@ -13,27 +13,30 @@ import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Optional;
 
 abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
 
     public static final ParseField AT = new ParseField("at");
 
     protected final double[] thresholds;
+    private EvaluationMetricResult result;
 
     protected AbstractConfusionMatrixMetric(double[] thresholds) {
         this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
         if (thresholds.length == 0) {
-            throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName()
-                + "] must have at least one value");
+            throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
         }
         for (double threshold : thresholds) {
             if (threshold < 0 || threshold > 1.0) {
-                throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName()
+                throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName()
                     + "] values must be in [0.0, 1.0]");
             }
         }
@@ -58,6 +61,9 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
 
     @Override
     public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
+        if (result != null) {
+            return List.of();
+        }
         List<AggregationBuilder> aggs = new ArrayList<>();
         for (double threshold : thresholds) {
             aggs.addAll(aggsAt(actualField, classInfos, threshold));
@@ -65,14 +71,26 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
         return aggs;
     }
 
+    @Override
+    public void process(ClassInfo classInfo, Aggregations aggs) {
+        result = evaluate(classInfo, aggs);
+    }
+
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
     protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
 
+    protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
+
     protected enum Condition {
         TP, FP, TN, FN;
     }
 
     protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
-        return getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
+        return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
     }
 
     protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {

+ 22 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java

@@ -30,6 +30,7 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.stream.IntStream;
 
 /**
@@ -70,6 +71,7 @@ public class AucRoc implements SoftClassificationMetric {
     }
 
     private final boolean includeCurve;
+    private EvaluationMetricResult result;
 
     public AucRoc(Boolean includeCurve) {
         this.includeCurve = includeCurve == null ? false : includeCurve;
@@ -98,7 +100,7 @@ public class AucRoc implements SoftClassificationMetric {
     }
 
     @Override
-    public String getMetricName() {
+    public String getName() {
         return NAME.getPreferredName();
     }
 
@@ -117,6 +119,9 @@ public class AucRoc implements SoftClassificationMetric {
 
     @Override
     public List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
+        if (result != null) {
+            return List.of();
+        }
         double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
         List<AggregationBuilder> aggs = new ArrayList<>();
         for (ClassInfo classInfo : classInfos) {
@@ -134,22 +139,31 @@ public class AucRoc implements SoftClassificationMetric {
         return aggs;
     }
 
+    @Override
+    public void process(ClassInfo classInfo, Aggregations aggs) {
+        result = evaluate(classInfo, aggs);
+    }
+
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
     private String evaluatedLabelAggName(ClassInfo classInfo) {
-        return getMetricName() + "_" + classInfo.getName();
+        return getName() + "_" + classInfo.getName();
     }
 
     private String restLabelsAggName(ClassInfo classInfo) {
-        return getMetricName() + "_non_" + classInfo.getName();
+        return getName() + "_non_" + classInfo.getName();
     }
 
-    @Override
-    public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
+    private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
         Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo));
         Filter restAgg = aggs.get(restLabelsAggName(classInfo));
         double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES),
-            "[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
+            "[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
         double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES),
-            "[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
+            "[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
         List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
         double aucRocScore = calculateAucScore(aucRocCurve);
         return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
@@ -326,7 +340,7 @@ public class AucRoc implements SoftClassificationMetric {
         }
 
         @Override
-        public String getName() {
+        public String getMetricName() {
             return NAME.getPreferredName();
         }
 

+ 23 - 29
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java

@@ -5,7 +5,6 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
 
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
@@ -14,18 +13,14 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
@@ -87,17 +82,16 @@ public class BinarySoftClassification implements Evaluation {
         if (metrics.isEmpty()) {
             throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
         }
-        Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getMetricName));
+        Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName));
         return metrics;
     }
 
     private static List<SoftClassificationMetric> defaultMetrics() {
-        List<SoftClassificationMetric> defaultMetrics = new ArrayList<>(4);
-        defaultMetrics.add(new AucRoc(false));
-        defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75)));
-        defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75)));
-        defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
-        return defaultMetrics;
+        return Arrays.asList(
+            new AucRoc(false),
+            new Precision(Arrays.asList(0.25, 0.5, 0.75)),
+            new Recall(Arrays.asList(0.25, 0.5, 0.75)),
+            new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
     }
 
     public BinarySoftClassification(StreamInput in) throws IOException {
@@ -126,7 +120,7 @@ public class BinarySoftClassification implements Evaluation {
 
         builder.startObject(METRICS.getPreferredName());
         for (SoftClassificationMetric metric : metrics) {
-            builder.field(metric.getMetricName(), metric);
+            builder.field(metric.getName(), metric);
         }
         builder.endObject();
 
@@ -155,34 +149,34 @@ public class BinarySoftClassification implements Evaluation {
     }
 
     @Override
-    public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
-        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
-            .filter(QueryBuilders.existsQuery(actualField))
-            .filter(QueryBuilders.existsQuery(predictedProbabilityField))
-            .filter(queryBuilder);
-        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
+    public List<SoftClassificationMetric> getMetrics() {
+        return metrics;
+    }
+
+    @Override
+    public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
+        ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
+        SearchSourceBuilder searchSourceBuilder =
+            newSearchSourceBuilder(List.of(actualField, predictedProbabilityField), userProvidedQueryBuilder);
+        BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
         for (SoftClassificationMetric metric : metrics) {
-            List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo()));
+            List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo));
             aggs.forEach(searchSourceBuilder::aggregation);
         }
         return searchSourceBuilder;
     }
 
     @Override
-    public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
+    public void process(SearchResponse searchResponse) {
+        ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
         if (searchResponse.getHits().getTotalHits().value == 0) {
-            listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField,
-                predictedProbabilityField));
-            return;
+            throw ExceptionsHelper.badRequestException(
+                "No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField);
         }
-
-        List<EvaluationMetricResult> results = new ArrayList<>();
-        Aggregations aggs = searchResponse.getAggregations();
         BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
         for (SoftClassificationMetric metric : metrics) {
-            results.add(metric.evaluate(binaryClassInfo, aggs));
+            metric.process(binaryClassInfo, searchResponse.getAggregations());
         }
-        listener.onResponse(results);
     }
 
     private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java

@@ -50,7 +50,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
     }
 
     @Override
-    public String getMetricName() {
+    public String getName() {
         return NAME.getPreferredName();
     }
 
@@ -132,7 +132,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
         }
 
         @Override
-        public String getName() {
+        public String getMetricName() {
             return NAME.getPreferredName();
         }
 

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java

@@ -48,7 +48,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
     }
 
     @Override
-    public String getMetricName() {
+    public String getName() {
         return NAME.getPreferredName();
     }
 

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java

@@ -48,7 +48,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
     }
 
     @Override
-    public String getMetricName() {
+    public String getName() {
         return NAME.getPreferredName();
     }
 
@@ -68,7 +68,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
     @Override
     protected List<AggregationBuilder> aggsAt(String actualField, List<ClassInfo> classInfos, double threshold) {
         List<AggregationBuilder> aggs = new ArrayList<>();
-        for (ClassInfo classInfo: classInfos) {
+        for (ClassInfo classInfo : classInfos) {
             aggs.add(buildAgg(classInfo, threshold, Condition.TP));
             aggs.add(buildAgg(classInfo, threshold, Condition.FN));
         }

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java

@@ -40,7 +40,7 @@ public class ScoreByThresholdResult implements EvaluationMetricResult {
     }
 
     @Override
-    public String getName() {
+    public String getMetricName() {
         return name;
     }
 

+ 6 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java

@@ -5,16 +5,15 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
 
-import org.elasticsearch.common.io.stream.NamedWriteable;
-import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.Aggregations;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 
 import java.util.List;
 
-public interface SoftClassificationMetric extends ToXContentObject, NamedWriteable {
+public interface SoftClassificationMetric extends EvaluationMetric {
 
     /**
      * The information of a specific class
@@ -37,11 +36,6 @@ public interface SoftClassificationMetric extends ToXContentObject, NamedWriteab
         String getProbabilityField();
     }
 
-    /**
-     * Returns the name of the metric (which may differ to the writeable name)
-     */
-    String getMetricName();
-
     /**
      * Builds the aggregation that collect required data to compute the metric
      * @param actualField the field that stores the actual class
@@ -51,10 +45,9 @@ public interface SoftClassificationMetric extends ToXContentObject, NamedWriteab
     List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos);
 
     /**
-     * Calculates the metric result for a given class
+     * Processes given aggregations as a step towards computing result
      * @param classInfo the class to calculate the metric for
-     * @param aggs the aggregations
-     * @return the metric result
+     * @param aggs aggregations from {@link SearchResponse}
      */
-    EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
+    void process(ClassInfo classInfo, Aggregations aggs);
 }

+ 5 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java

@@ -49,8 +49,9 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
         ));
 
         MeanSquaredError mse = new MeanSquaredError();
-        EvaluationMetricResult result = mse.evaluate(aggs);
+        mse.process(aggs);
 
+        EvaluationMetricResult result = mse.getResult().get();
         String expected = "{\"error\":0.8123}";
         assertThat(Strings.toString(result), equalTo(expected));
     }
@@ -61,7 +62,9 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
         ));
 
         MeanSquaredError mse = new MeanSquaredError();
-        EvaluationMetricResult result = mse.evaluate(aggs);
+        mse.process(aggs);
+
+        EvaluationMetricResult result = mse.getResult().get();
         assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
     }
 

+ 25 - 11
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java

@@ -52,8 +52,9 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
         ));
 
         RSquared rSquared = new RSquared();
-        EvaluationMetricResult result = rSquared.evaluate(aggs);
+        rSquared.process(aggs);
 
+        EvaluationMetricResult result = rSquared.getResult().get();
         String expected = "{\"value\":0.9348643947690524}";
         assertThat(Strings.toString(result), equalTo(expected));
     }
@@ -67,35 +68,48 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
         ));
 
         RSquared rSquared = new RSquared();
-        EvaluationMetricResult result = rSquared.evaluate(aggs);
+        rSquared.process(aggs);
+
+        EvaluationMetricResult result = rSquared.getResult().get();
         assertThat(result, equalTo(new RSquared.Result(0.0)));
     }
 
     public void testEvaluate_GivenMissingAggs() {
-        EvaluationMetricResult zeroResult = new RSquared.Result(0.0);
         Aggregations aggs = new Aggregations(Collections.singletonList(
             createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
         ));
 
         RSquared rSquared = new RSquared();
-        EvaluationMetricResult result = rSquared.evaluate(aggs);
-        assertThat(result, equalTo(zeroResult));
+        rSquared.process(aggs);
+
+        EvaluationMetricResult result = rSquared.getResult().get();
+        assertThat(result, equalTo(new RSquared.Result(0.0)));
+    }
 
-        aggs = new Aggregations(Arrays.asList(
+    public void testEvaluate_GivenMissingExtendedStatsAgg() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
             createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
             createSingleMetricAgg("residual_sum_of_squares", 0.2377)
         ));
 
-        result = rSquared.evaluate(aggs);
-        assertThat(result, equalTo(zeroResult));
+        RSquared rSquared = new RSquared();
+        rSquared.process(aggs);
 
-        aggs = new Aggregations(Arrays.asList(
+        EvaluationMetricResult result = rSquared.getResult().get();
+        assertThat(result, equalTo(new RSquared.Result(0.0)));
+    }
+
+    public void testEvaluate_GivenMissingResidualSumOfSquaresAgg() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
             createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
             createExtendedStatsAgg("extended_stats_actual",100, 50)
         ));
 
-        result = rSquared.evaluate(aggs);
-        assertThat(result, equalTo(zeroResult));
+        RSquared rSquared = new RSquared();
+        rSquared.process(aggs);
+
+        EvaluationMetricResult result = rSquared.getResult().get();
+        assertThat(result, equalTo(new RSquared.Result(0.0)));
     }
 
     private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {

+ 10 - 10
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 
@@ -22,6 +23,7 @@ import java.util.Collections;
 import java.util.List;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
 
 public class RegressionTests extends AbstractSerializingTestCase<Regression> {
 
@@ -43,13 +45,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
         if (randomBoolean()) {
             metrics.add(RSquaredTests.createRandom());
         }
-        return new Regression(randomAlphaOfLength(10),
-            randomAlphaOfLength(10),
-            randomBoolean() ?
-                null :
-                metrics.isEmpty() ?
-                    null :
-                    metrics);
+        return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
     }
 
     @Override
@@ -74,7 +70,6 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
     }
 
     public void testBuildSearch() {
-        Regression evaluation = new Regression("act", "prob", Arrays.asList(new MeanSquaredError()));
         QueryBuilder userProvidedQuery =
             QueryBuilders.boolQuery()
                 .filter(QueryBuilders.termQuery("field_A", "some-value"))
@@ -82,10 +77,15 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
         QueryBuilder expectedSearchQuery =
             QueryBuilders.boolQuery()
                 .filter(QueryBuilders.existsQuery("act"))
-                .filter(QueryBuilders.existsQuery("prob"))
+                .filter(QueryBuilders.existsQuery("pred"))
                 .filter(QueryBuilders.boolQuery()
                     .filter(QueryBuilders.termQuery("field_A", "some-value"))
                     .filter(QueryBuilders.termQuery("field_B", "some-other-value")));
-        assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
+
+        Regression evaluation = new Regression("act", "pred", Arrays.asList(new MeanSquaredError()));
+
+        SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
+        assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
+        assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
     }
 }

+ 8 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 
@@ -22,6 +23,7 @@ import java.util.Collections;
 import java.util.List;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
 
 public class BinarySoftClassificationTests extends AbstractSerializingTestCase<BinarySoftClassification> {
 
@@ -81,7 +83,6 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
     }
 
     public void testBuildSearch() {
-        BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
         QueryBuilder userProvidedQuery =
             QueryBuilders.boolQuery()
                 .filter(QueryBuilders.termQuery("field_A", "some-value"))
@@ -93,6 +94,11 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
                 .filter(QueryBuilders.boolQuery()
                     .filter(QueryBuilders.termQuery("field_A", "some-value"))
                     .filter(QueryBuilders.termQuery("field_B", "some-other-value")));
-        assertThat(evaluation.buildSearch(userProvidedQuery).query(), equalTo(expectedSearchQuery));
+
+        BinarySoftClassification evaluation = new BinarySoftClassification("act", "prob", Arrays.asList(new Precision(Arrays.asList(0.7))));
+
+        SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
+        assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
+        assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
     }
 }

+ 58 - 17
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java

@@ -12,12 +12,13 @@ import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
 
 import java.util.List;
 
@@ -38,24 +39,64 @@ public class TransportEvaluateDataFrameAction extends HandledTransportAction<Eva
     @Override
     protected void doExecute(Task task, EvaluateDataFrameAction.Request request,
                              ActionListener<EvaluateDataFrameAction.Response> listener) {
-        Evaluation evaluation = request.getEvaluation();
-        SearchRequest searchRequest = new SearchRequest(request.getIndices());
-        searchRequest.source(evaluation.buildSearch(request.getParsedQuery()));
-
-        ActionListener<List<EvaluationMetricResult>> resultsListener = ActionListener.wrap(
-            results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)),
+        ActionListener<List<Void>> resultsListener = ActionListener.wrap(
+            unused -> {
+                EvaluateDataFrameAction.Response response =
+                    new EvaluateDataFrameAction.Response(request.getEvaluation().getName(), request.getEvaluation().getResults());
+                listener.onResponse(response);
+            },
             listener::onFailure
         );
 
-        client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
-            searchResponse -> threadPool.generic().execute(() -> {
-                try {
-                    evaluation.evaluate(searchResponse, resultsListener);
-                } catch (Exception e) {
-                    listener.onFailure(e);
-                };
-            }),
-            listener::onFailure
-        ));
+        EvaluationExecutor evaluationExecutor = new EvaluationExecutor(threadPool, client, request);
+        evaluationExecutor.execute(resultsListener);
+    }
+
+    /**
+     * {@link EvaluationExecutor} class allows for serial execution of evaluation steps.
+     *
+     * Each step consists of the following phases:
+     *  1. build search request with aggs requested by individual metrics
+     *  2. execute search action with the request built in (1.)
+     *  3. make all individual metrics process the search response obtained in (2.)
+     *  4. check if all the metrics have their results computed
+     *      a) If so, call the final listener and finish
+     *      b) Otherwise, add another step to the queue
+     *
+     * To avoid infinite loop it is essential that every metric *does* compute its result at some point.
+     * */
+    private static final class EvaluationExecutor extends TypedChainTaskExecutor<Void> {
+
+        private final Client client;
+        private final EvaluateDataFrameAction.Request request;
+        private final Evaluation evaluation;
+
+        EvaluationExecutor(ThreadPool threadPool, Client client, EvaluateDataFrameAction.Request request) {
+            super(threadPool.generic(), unused -> true, unused -> true);
+            this.client = client;
+            this.request = request;
+            this.evaluation = request.getEvaluation();
+            // Add one task only. Other tasks will be added as needed by the nextTask method itself.
+            add(nextTask());
+        }
+
+        private TypedChainTaskExecutor.ChainTask<Void> nextTask() {
+            return listener -> {
+                SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(request.getParsedQuery());
+                SearchRequest searchRequest = new SearchRequest(request.getIndices()).source(searchSourceBuilder);
+                client.execute(
+                    SearchAction.INSTANCE,
+                    searchRequest,
+                    ActionListener.wrap(
+                        searchResponse -> {
+                            evaluation.process(searchResponse);
+                            if (evaluation.hasAllResults() == false) {
+                                add(nextTask());
+                            }
+                            listener.onResponse(null);
+                        },
+                        listener::onFailure));
+            };
+        }
     }
 }