Răsfoiți Sursa

Remove ClassInfo interface and BinaryClassInfo class. (#49649)

Przemysław Witek 5 ani în urmă
părinte
comite
e248610334
16 a modificat fișierele cu 248 adăugiri și 391 ștergeri
  1. 52 14
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java
  2. 18 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java
  3. 13 41
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java
  4. 0 19
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java
  5. 11 39
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
  6. 0 19
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java
  7. 39 38
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java
  8. 29 35
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java
  9. 21 73
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java
  10. 13 12
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java
  11. 10 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java
  12. 10 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java
  13. 3 39
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java
  14. 9 12
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java
  15. 10 16
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java
  16. 10 16
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java

+ 52 - 14
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java

@@ -6,15 +6,23 @@
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
+import java.util.Objects;
 import java.util.Optional;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 /**
@@ -27,37 +35,67 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
      */
     String getName();
 
+    /**
+     * Returns the field containing the actual value
+     */
+    String getActualField();
+
+    /**
+     * Returns the field containing the predicted value
+     */
+    String getPredictedField();
+
     /**
      * Returns the list of metrics to evaluate
      * @return list of metrics to evaluate
      */
     List<? extends EvaluationMetric> getMetrics();
 
+    default <T extends EvaluationMetric> List<T> initMetrics(@Nullable List<T> parsedMetrics, Supplier<List<T>> defaultMetricsSupplier) {
+        List<T> metrics = parsedMetrics == null ? defaultMetricsSupplier.get() : new ArrayList<>(parsedMetrics);
+        if (metrics.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", getName());
+        }
+        Collections.sort(metrics, Comparator.comparing(EvaluationMetric::getName));
+        return metrics;
+    }
+
     /**
      * Builds the search required to collect data to compute the evaluation result
      * @param userProvidedQueryBuilder User-provided query that must be respected when collecting data
      */
-    SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder);
-
-    /**
-     * Builds the search that verifies existence of required fields and applies user-provided query
-     * @param requiredFields fields that must exist
-     * @param userProvidedQueryBuilder user-provided query
-     */
-    default SearchSourceBuilder newSearchSourceBuilder(List<String> requiredFields, QueryBuilder userProvidedQueryBuilder) {
-        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
-        for (String requiredField : requiredFields) {
-            boolQuery.filter(QueryBuilders.existsQuery(requiredField));
+    default SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
+        Objects.requireNonNull(userProvidedQueryBuilder);
+        BoolQueryBuilder boolQuery =
+            QueryBuilders.boolQuery()
+                // Verify existence of required fields
+                .filter(QueryBuilders.existsQuery(getActualField()))
+                .filter(QueryBuilders.existsQuery(getPredictedField()))
+                // Apply user-provided query
+                .filter(userProvidedQueryBuilder);
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
+        for (EvaluationMetric metric : getMetrics()) {
+            // Fetch aggregations requested by individual metrics
+            List<AggregationBuilder> aggs = metric.aggs(getActualField(), getPredictedField());
+            aggs.forEach(searchSourceBuilder::aggregation);
         }
-        boolQuery.filter(userProvidedQueryBuilder);
-        return new SearchSourceBuilder().size(0).query(boolQuery);
+        return searchSourceBuilder;
     }
 
     /**
      * Processes {@link SearchResponse} from the search action
      * @param searchResponse response from the search action
      */
-    void process(SearchResponse searchResponse);
+    default void process(SearchResponse searchResponse) {
+        Objects.requireNonNull(searchResponse);
+        if (searchResponse.getHits().getTotalHits().value == 0) {
+            throw ExceptionsHelper.badRequestException(
+                "No documents found containing both [{}, {}] fields", getActualField(), getPredictedField());
+        }
+        for (EvaluationMetric metric : getMetrics()) {
+            metric.process(searchResponse.getAggregations());
+        }
+    }
 
     /**
      * @return true iff all the metrics have their results computed

+ 18 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java

@@ -5,9 +5,13 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
 
+import java.util.List;
 import java.util.Optional;
 
 /**
@@ -20,6 +24,20 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
      */
     String getName();
 
+    /**
+     * Builds the aggregation that collect required data to compute the metric
+     * @param actualField the field that stores the actual value
+     * @param predictedField the field that stores the predicted value (class name or probability)
+     * @return the aggregations required to compute the metric
+     */
+    List<AggregationBuilder> aggs(String actualField, String predictedField);
+
+    /**
+     * Processes given aggregations as a step towards computing result
+     * @param aggs aggregations from {@link SearchResponse}
+     */
+    void process(Aggregations aggs);
+
     /**
      * Gets the evaluation result for this metric.
      * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise

+ 13 - 41
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java

@@ -5,7 +5,6 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
-import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
 
@@ -55,13 +48,13 @@ public class Classification implements Evaluation {
 
     /**
      * The field containing the actual value
-     * The value of this field is assumed to be numeric
+     * The value of this field is assumed to be categorical
      */
     private final String actualField;
 
     /**
      * The field containing the predicted value
-     * The value of this field is assumed to be numeric
+     * The value of this field is assumed to be categorical
      */
     private final String predictedField;
 
@@ -73,7 +66,11 @@ public class Classification implements Evaluation {
     public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
         this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
         this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
-        this.metrics = initMetrics(metrics);
+        this.metrics = initMetrics(metrics, Classification::defaultMetrics);
+    }
+
+    private static List<ClassificationMetric> defaultMetrics() {
+        return Arrays.asList(new MulticlassConfusionMatrix());
     }
 
     public Classification(StreamInput in) throws IOException {
@@ -82,49 +79,24 @@ public class Classification implements Evaluation {
         this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
     }
 
-    private static List<ClassificationMetric> initMetrics(@Nullable List<ClassificationMetric> parsedMetrics) {
-        List<ClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
-        if (metrics.isEmpty()) {
-            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
-        }
-        Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName));
-        return metrics;
-    }
-
-    private static List<ClassificationMetric> defaultMetrics() {
-        return Arrays.asList(new MulticlassConfusionMatrix());
-    }
-
     @Override
     public String getName() {
         return NAME.getPreferredName();
     }
 
     @Override
-    public List<ClassificationMetric> getMetrics() {
-        return metrics;
+    public String getActualField() {
+        return actualField;
     }
 
     @Override
-    public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
-        ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
-        SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder);
-        for (ClassificationMetric metric : metrics) {
-            List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
-            aggs.forEach(searchSourceBuilder::aggregation);
-        }
-        return searchSourceBuilder;
+    public String getPredictedField() {
+        return predictedField;
     }
 
     @Override
-    public void process(SearchResponse searchResponse) {
-        ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
-        if (searchResponse.getHits().getTotalHits().value == 0) {
-            throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
-        }
-        for (ClassificationMetric metric : metrics) {
-            metric.process(searchResponse.getAggregations());
-        }
+    public List<ClassificationMetric> getMetrics() {
+        return metrics;
     }
 
     @Override

+ 0 - 19
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java

@@ -5,26 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
-import org.elasticsearch.action.search.SearchResponse;
-import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 
-import java.util.List;
-
 public interface ClassificationMetric extends EvaluationMetric {
-
-    /**
-     * Builds the aggregation that collect required data to compute the metric
-     * @param actualField the field that stores the actual value
-     * @param predictedField the field that stores the predicted value
-     * @return the aggregations required to compute the metric
-     */
-    List<AggregationBuilder> aggs(String actualField, String predictedField);
-
-    /**
-     * Processes given aggregations as a step towards computing result
-     * @param aggs aggregations from {@link SearchResponse}
-     */
-    void process(Aggregations aggs);
 }

+ 11 - 39
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java

@@ -5,7 +5,6 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
 
-import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
 
@@ -73,7 +66,11 @@ public class Regression implements Evaluation {
     public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
         this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
         this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
-        this.metrics = initMetrics(metrics);
+        this.metrics = initMetrics(metrics, Regression::defaultMetrics);
+    }
+
+    private static List<RegressionMetric> defaultMetrics() {
+        return Arrays.asList(new MeanSquaredError(), new RSquared());
     }
 
     public Regression(StreamInput in) throws IOException {
@@ -82,49 +79,24 @@ public class Regression implements Evaluation {
         this.metrics = in.readNamedWriteableList(RegressionMetric.class);
     }
 
-    private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
-        List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
-        if (metrics.isEmpty()) {
-            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
-        }
-        Collections.sort(metrics, Comparator.comparing(RegressionMetric::getName));
-        return metrics;
-    }
-
-    private static List<RegressionMetric> defaultMetrics() {
-        return Arrays.asList(new MeanSquaredError(), new RSquared());
-    }
-
     @Override
     public String getName() {
         return NAME.getPreferredName();
     }
 
     @Override
-    public List<RegressionMetric> getMetrics() {
-        return metrics;
+    public String getActualField() {
+        return actualField;
     }
 
     @Override
-    public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
-        ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
-        SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder);
-        for (RegressionMetric metric : metrics) {
-            List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
-            aggs.forEach(searchSourceBuilder::aggregation);
-        }
-        return searchSourceBuilder;
+    public String getPredictedField() {
+        return predictedField;
     }
 
     @Override
-    public void process(SearchResponse searchResponse) {
-        ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
-        if (searchResponse.getHits().getTotalHits().value == 0) {
-            throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
-        }
-        for (RegressionMetric metric : metrics) {
-            metric.process(searchResponse.getAggregations());
-        }
+    public List<RegressionMetric> getMetrics() {
+        return metrics;
     }
 
     @Override

+ 0 - 19
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java

@@ -5,26 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
 
-import org.elasticsearch.action.search.SearchResponse;
-import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 
-import java.util.List;
-
 public interface RegressionMetric extends EvaluationMetric {
-
-    /**
-     * Builds the aggregation that collect required data to compute the metric
-     * @param actualField the field that stores the actual value
-     * @param predictedField the field that stores the predicted value
-     * @return the aggregations required to compute the metric
-     */
-    List<AggregationBuilder> aggs(String actualField, String predictedField);
-
-    /**
-     * Processes given aggregations as a step towards computing result
-     * @param aggs aggregations from {@link SearchResponse}
-     */
-    void process(Aggregations aggs);
 }

+ 39 - 38
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java

@@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
@@ -18,10 +19,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
+
 abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
 
     public static final ParseField AT = new ParseField("at");
@@ -29,8 +31,8 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
     protected final double[] thresholds;
     private EvaluationMetricResult result;
 
-    protected AbstractConfusionMatrixMetric(double[] thresholds) {
-        this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
+    protected AbstractConfusionMatrixMetric(List<Double> at) {
+        this.thresholds = ExceptionsHelper.requireNonNull(at, AT).stream().mapToDouble(Double::doubleValue).toArray();
         if (thresholds.length == 0) {
             throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
         }
@@ -60,20 +62,16 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
     }
 
     @Override
-    public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
+    public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
         if (result != null) {
             return List.of();
         }
-        List<AggregationBuilder> aggs = new ArrayList<>();
-        for (double threshold : thresholds) {
-            aggs.addAll(aggsAt(actualField, classInfos, threshold));
-        }
-        return aggs;
+        return aggsAt(actualField, predictedProbabilityField);
     }
 
     @Override
-    public void process(ClassInfo classInfo, Aggregations aggs) {
-        result = evaluate(classInfo, aggs);
+    public void process(Aggregations aggs) {
+        result = evaluate(aggs);
     }
 
     @Override
@@ -81,40 +79,43 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
         return Optional.ofNullable(result);
     }
 
-    protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
+    protected abstract List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField);
+
+    protected abstract EvaluationMetricResult evaluate(Aggregations aggs);
 
-    protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
+    enum Condition {
+        TP(true, true),
+        FP(false, true),
+        TN(false, false),
+        FN(true, false);
 
-    protected enum Condition {
-        TP, FP, TN, FN;
+        final boolean actual;
+        final boolean predicted;
+
+        Condition(boolean actual, boolean predicted) {
+            this.actual = actual;
+            this.predicted = predicted;
+        }
     }
 
-    protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
-        return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
+    protected String aggName(double threshold, Condition condition) {
+        return getName() + "_at_" + threshold + "_" + condition.name();
     }
 
-    protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {
+    protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) {
         BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
-        switch (condition) {
-            case TP:
-                boolQuery.must(classInfo.matchingQuery());
-                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
-                break;
-            case FP:
-                boolQuery.mustNot(classInfo.matchingQuery());
-                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
-                break;
-            case TN:
-                boolQuery.mustNot(classInfo.matchingQuery());
-                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
-                break;
-            case FN:
-                boolQuery.must(classInfo.matchingQuery());
-                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
-                break;
-            default:
-                throw new IllegalArgumentException("Unknown enum value: " + condition);
+        QueryBuilder actualIsTrueQuery = actualIsTrueQuery(actualField);
+        QueryBuilder predictedIsTrueQuery = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold);
+        if (condition.actual) {
+            boolQuery.must(actualIsTrueQuery);
+        } else {
+            boolQuery.mustNot(actualIsTrueQuery);
+        }
+        if (condition.predicted) {
+            boolQuery.must(predictedIsTrueQuery);
+        } else {
+            boolQuery.mustNot(predictedIsTrueQuery);
         }
-        return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery);
+        return AggregationBuilders.filter(aggName(threshold, condition), boolQuery);
     }
 }

+ 29 - 35
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java

@@ -33,6 +33,8 @@ import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.IntStream;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
+
 /**
  * Area under the curve (AUC) of the receiver operating characteristic (ROC).
  * The ROC curve is a plot of the TPR (true positive rate) against
@@ -66,6 +68,9 @@ public class AucRoc implements SoftClassificationMetric {
 
     private static final String PERCENTILES = "percentiles";
 
+    private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
+    private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
+
     public static AucRoc fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
@@ -118,30 +123,39 @@ public class AucRoc implements SoftClassificationMetric {
     }
 
     @Override
-    public List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
+    public List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
         if (result != null) {
             return List.of();
         }
         double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
-        List<AggregationBuilder> aggs = new ArrayList<>();
-        for (ClassInfo classInfo : classInfos) {
-            AggregationBuilder percentilesForClassValueAgg = AggregationBuilders
-                .filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery())
+        AggregationBuilder percentilesForClassValueAgg =
+            AggregationBuilders
+                .filter(TRUE_AGG_NAME, actualIsTrueQuery(actualField))
                 .subAggregation(
-                    AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles));
-            AggregationBuilder percentilesForRestAgg = AggregationBuilders
-                .filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery()))
+                    AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
+        AggregationBuilder percentilesForRestAgg =
+            AggregationBuilders
+                .filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
                 .subAggregation(
-                    AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles));
-            aggs.add(percentilesForClassValueAgg);
-            aggs.add(percentilesForRestAgg);
-        }
-        return aggs;
+                    AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
+        return List.of(percentilesForClassValueAgg, percentilesForRestAgg);
     }
 
     @Override
-    public void process(ClassInfo classInfo, Aggregations aggs) {
-        result = evaluate(classInfo, aggs);
+    public void process(Aggregations aggs) {
+        Filter classAgg = aggs.get(TRUE_AGG_NAME);
+        Filter restAgg = aggs.get(NON_TRUE_AGG_NAME);
+        double[] tpPercentiles =
+            percentilesArray(
+                classAgg.getAggregations().get(PERCENTILES),
+                "[" + getName() + "] requires at least one actual_field to have the value [true]");
+        double[] fpPercentiles =
+            percentilesArray(
+                restAgg.getAggregations().get(PERCENTILES),
+                "[" + getName() + "] requires at least one actual_field to have a different value than [true]");
+        List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
+        double aucRocScore = calculateAucScore(aucRocCurve);
+        result = new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
     }
 
     @Override
@@ -149,26 +163,6 @@ public class AucRoc implements SoftClassificationMetric {
         return Optional.ofNullable(result);
     }
 
-    private String evaluatedLabelAggName(ClassInfo classInfo) {
-        return getName() + "_" + classInfo.getName();
-    }
-
-    private String restLabelsAggName(ClassInfo classInfo) {
-        return getName() + "_non_" + classInfo.getName();
-    }
-
-    private EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
-        Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo));
-        Filter restAgg = aggs.get(restLabelsAggName(classInfo));
-        double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES),
-            "[" + getName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]");
-        double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES),
-            "[" + getName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]");
-        List<AucRocPoint> aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles);
-        double aucRocScore = calculateAucScore(aucRocCurve);
-        return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList());
-    }
-
     private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) {
         double[] result = new double[99];
         percentiles.forEach(percentile -> {

+ 21 - 73
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java

@@ -5,7 +5,6 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
 
-import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -13,17 +12,11 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
 
@@ -74,16 +67,7 @@ public class BinarySoftClassification implements Evaluation {
                                     @Nullable List<SoftClassificationMetric> metrics) {
         this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
         this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
-        this.metrics = initMetrics(metrics);
-    }
-
-    private static List<SoftClassificationMetric> initMetrics(@Nullable List<SoftClassificationMetric> parsedMetrics) {
-        List<SoftClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics;
-        if (metrics.isEmpty()) {
-            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
-        }
-        Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName));
-        return metrics;
+        this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics);
     }
 
     private static List<SoftClassificationMetric> defaultMetrics() {
@@ -100,6 +84,26 @@ public class BinarySoftClassification implements Evaluation {
         this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class);
     }
 
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getActualField() {
+        return actualField;
+    }
+
+    @Override
+    public String getPredictedField() {
+        return predictedProbabilityField;
+    }
+
+    @Override
+    public List<SoftClassificationMetric> getMetrics() {
+        return metrics;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -142,60 +146,4 @@ public class BinarySoftClassification implements Evaluation {
     public int hashCode() {
         return Objects.hash(actualField, predictedProbabilityField, metrics);
     }
-
-    @Override
-    public String getName() {
-        return NAME.getPreferredName();
-    }
-
-    @Override
-    public List<SoftClassificationMetric> getMetrics() {
-        return metrics;
-    }
-
-    @Override
-    public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
-        ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
-        SearchSourceBuilder searchSourceBuilder =
-            newSearchSourceBuilder(List.of(actualField, predictedProbabilityField), userProvidedQueryBuilder);
-        BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
-        for (SoftClassificationMetric metric : metrics) {
-            List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo));
-            aggs.forEach(searchSourceBuilder::aggregation);
-        }
-        return searchSourceBuilder;
-    }
-
-    @Override
-    public void process(SearchResponse searchResponse) {
-        ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
-        if (searchResponse.getHits().getTotalHits().value == 0) {
-            throw ExceptionsHelper.badRequestException(
-                "No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField);
-        }
-        BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
-        for (SoftClassificationMetric metric : metrics) {
-            metric.process(binaryClassInfo, searchResponse.getAggregations());
-        }
-    }
-
-    private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {
-
-        private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
-
-        @Override
-        public String getName() {
-            return String.valueOf(true);
-        }
-
-        @Override
-        public QueryBuilder matchingQuery() {
-            return matchingQuery;
-        }
-
-        @Override
-        public String getProbabilityField() {
-            return predictedProbabilityField;
-        }
-    }
 }

+ 13 - 12
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java

@@ -37,7 +37,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
     }
 
     public ConfusionMatrix(List<Double> at) {
-        super(at.stream().mapToDouble(Double::doubleValue).toArray());
+        super(at);
     }
 
     public ConfusionMatrix(StreamInput in) throws IOException {
@@ -68,28 +68,29 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
     }
 
     @Override
-    protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
+    protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
         List<AggregationBuilder> aggs = new ArrayList<>();
-        for (ClassInfo classInfo : classInfos) {
-            aggs.add(buildAgg(classInfo, threshold, Condition.TP));
-            aggs.add(buildAgg(classInfo, threshold, Condition.FP));
-            aggs.add(buildAgg(classInfo, threshold, Condition.TN));
-            aggs.add(buildAgg(classInfo, threshold, Condition.FN));
+        for (int i = 0; i < thresholds.length; i++) {
+            double threshold = thresholds[i];
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP));
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TN));
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN));
         }
         return aggs;
     }
 
     @Override
-    public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
+    public EvaluationMetricResult evaluate(Aggregations aggs) {
         long[] tp = new long[thresholds.length];
         long[] fp = new long[thresholds.length];
         long[] tn = new long[thresholds.length];
         long[] fn = new long[thresholds.length];
         for (int i = 0; i < thresholds.length; i++) {
-            Filter tpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TP));
-            Filter fpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FP));
-            Filter tnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TN));
-            Filter fnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FN));
+            Filter tpAgg = aggs.get(aggName(thresholds[i], Condition.TP));
+            Filter fpAgg = aggs.get(aggName(thresholds[i], Condition.FP));
+            Filter tnAgg = aggs.get(aggName(thresholds[i], Condition.TN));
+            Filter fnAgg = aggs.get(aggName(thresholds[i], Condition.FN));
             tp[i] = tpAgg.getDocCount();
             fp[i] = fpAgg.getDocCount();
             tn[i] = tnAgg.getDocCount();

+ 10 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java

@@ -35,7 +35,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
     }
 
     public Precision(List<Double> at) {
-        super(at.stream().mapToDouble(Double::doubleValue).toArray());
+        super(at);
     }
 
     public Precision(StreamInput in) throws IOException {
@@ -66,22 +66,23 @@ public class Precision extends AbstractConfusionMatrixMetric {
     }
 
     @Override
-    protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
+    protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
         List<AggregationBuilder> aggs = new ArrayList<>();
-        for (ClassInfo classInfo : classInfos) {
-            aggs.add(buildAgg(classInfo, threshold, Condition.TP));
-            aggs.add(buildAgg(classInfo, threshold, Condition.FP));
+        for (int i = 0; i < thresholds.length; i++) {
+            double threshold = thresholds[i];
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FP));
         }
         return aggs;
     }
 
     @Override
-    public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
+    public EvaluationMetricResult evaluate(Aggregations aggs) {
         double[] precisions = new double[thresholds.length];
-        for (int i = 0; i < precisions.length; i++) {
+        for (int i = 0; i < thresholds.length; i++) {
             double threshold = thresholds[i];
-            Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
-            Filter fpAgg = aggs.get(aggName(classInfo, threshold, Condition.FP));
+            Filter tpAgg = aggs.get(aggName(threshold, Condition.TP));
+            Filter fpAgg = aggs.get(aggName(threshold, Condition.FP));
             long tp = tpAgg.getDocCount();
             long fp = fpAgg.getDocCount();
             precisions[i] = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);

+ 10 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java

@@ -35,7 +35,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
     }
 
     public Recall(List<Double> at) {
-        super(at.stream().mapToDouble(Double::doubleValue).toArray());
+        super(at);
     }
 
     public Recall(StreamInput in) throws IOException {
@@ -66,22 +66,23 @@ public class Recall extends AbstractConfusionMatrixMetric {
     }
 
     @Override
-    protected List<AggregationBuilder> aggsAt(String actualField, List<ClassInfo> classInfos, double threshold) {
+    protected List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField) {
         List<AggregationBuilder> aggs = new ArrayList<>();
-        for (ClassInfo classInfo : classInfos) {
-            aggs.add(buildAgg(classInfo, threshold, Condition.TP));
-            aggs.add(buildAgg(classInfo, threshold, Condition.FN));
+        for (int i = 0; i < thresholds.length; i++) {
+            double threshold = thresholds[i];
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.TP));
+            aggs.add(buildAgg(actualField, predictedProbabilityField, threshold, Condition.FN));
         }
         return aggs;
     }
 
     @Override
-    public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
+    public EvaluationMetricResult evaluate(Aggregations aggs) {
         double[] recalls = new double[thresholds.length];
-        for (int i = 0; i < recalls.length; i++) {
+        for (int i = 0; i < thresholds.length; i++) {
             double threshold = thresholds[i];
-            Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
-            Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
+            Filter tpAgg = aggs.get(aggName(threshold, Condition.TP));
+            Filter fnAgg = aggs.get(aggName(threshold, Condition.FN));
             long tp = tpAgg.getDocCount();
             long fn = fnAgg.getDocCount();
             recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);

+ 3 - 39
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java

@@ -5,49 +5,13 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
 
-import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.search.aggregations.AggregationBuilder;
-import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 
-import java.util.List;
-
 public interface SoftClassificationMetric extends EvaluationMetric {
 
-    /**
-     * The information of a specific class
-     */
-    interface ClassInfo {
-
-        /**
-         * Returns the class name
-         */
-        String getName();
-
-        /**
-         * Returns a query that matches documents of the class
-         */
-        QueryBuilder matchingQuery();
-
-        /**
-         * Returns the field that has the probability to be of the class
-         */
-        String getProbabilityField();
+    static QueryBuilder actualIsTrueQuery(String actualField) {
+        return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
     }
-
-    /**
-     * Builds the aggregation that collect required data to compute the metric
-     * @param actualField the field that stores the actual class
-     * @param classInfos the information of each class to compute the metric for
-     * @return the aggregations required to compute the metric
-     */
-    List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos);
-
-    /**
-     * Processes given aggregations as a step towards computing result
-     * @param classInfo the class to calculate the metric for
-     * @param aggs aggregations from {@link SearchResponse}
-     */
-    void process(ClassInfo classInfo, Aggregations aggs);
 }

+ 9 - 12
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java

@@ -49,22 +49,19 @@ public class ConfusionMatrixTests extends AbstractSerializingTestCase<ConfusionM
     }
 
     public void testEvaluate() {
-        SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
-        when(classInfo.getName()).thenReturn("foo");
-
         Aggregations aggs = new Aggregations(Arrays.asList(
-            createFilterAgg("confusion_matrix_foo_at_0.25_TP", 1L),
-            createFilterAgg("confusion_matrix_foo_at_0.25_FP", 2L),
-            createFilterAgg("confusion_matrix_foo_at_0.25_TN", 3L),
-            createFilterAgg("confusion_matrix_foo_at_0.25_FN", 4L),
-            createFilterAgg("confusion_matrix_foo_at_0.5_TP", 5L),
-            createFilterAgg("confusion_matrix_foo_at_0.5_FP", 6L),
-            createFilterAgg("confusion_matrix_foo_at_0.5_TN", 7L),
-            createFilterAgg("confusion_matrix_foo_at_0.5_FN", 8L)
+            createFilterAgg("confusion_matrix_at_0.25_TP", 1L),
+            createFilterAgg("confusion_matrix_at_0.25_FP", 2L),
+            createFilterAgg("confusion_matrix_at_0.25_TN", 3L),
+            createFilterAgg("confusion_matrix_at_0.25_FN", 4L),
+            createFilterAgg("confusion_matrix_at_0.5_TP", 5L),
+            createFilterAgg("confusion_matrix_at_0.5_FP", 6L),
+            createFilterAgg("confusion_matrix_at_0.5_TN", 7L),
+            createFilterAgg("confusion_matrix_at_0.5_FN", 8L)
         ));
 
         ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5));
-        EvaluationMetricResult result = confusionMatrix.evaluate(classInfo, aggs);
+        EvaluationMetricResult result = confusionMatrix.evaluate(aggs);
 
         String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}";
         assertThat(Strings.toString(result), equalTo(expected));

+ 10 - 16
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java

@@ -49,36 +49,30 @@ public class PrecisionTests extends AbstractSerializingTestCase<Precision> {
     }
 
     public void testEvaluate() {
-        SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
-        when(classInfo.getName()).thenReturn("foo");
-
         Aggregations aggs = new Aggregations(Arrays.asList(
-            createFilterAgg("precision_foo_at_0.25_TP", 1L),
-            createFilterAgg("precision_foo_at_0.25_FP", 4L),
-            createFilterAgg("precision_foo_at_0.5_TP", 3L),
-            createFilterAgg("precision_foo_at_0.5_FP", 1L),
-            createFilterAgg("precision_foo_at_0.75_TP", 5L),
-            createFilterAgg("precision_foo_at_0.75_FP", 0L)
+            createFilterAgg("precision_at_0.25_TP", 1L),
+            createFilterAgg("precision_at_0.25_FP", 4L),
+            createFilterAgg("precision_at_0.5_TP", 3L),
+            createFilterAgg("precision_at_0.5_FP", 1L),
+            createFilterAgg("precision_at_0.75_TP", 5L),
+            createFilterAgg("precision_at_0.75_FP", 0L)
         ));
 
         Precision precision = new Precision(Arrays.asList(0.25, 0.5, 0.75));
-        EvaluationMetricResult result = precision.evaluate(classInfo, aggs);
+        EvaluationMetricResult result = precision.evaluate(aggs);
 
         String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}";
         assertThat(Strings.toString(result), equalTo(expected));
     }
 
     public void testEvaluate_GivenZeroTpAndFp() {
-        SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
-        when(classInfo.getName()).thenReturn("foo");
-
         Aggregations aggs = new Aggregations(Arrays.asList(
-            createFilterAgg("precision_foo_at_1.0_TP", 0L),
-            createFilterAgg("precision_foo_at_1.0_FP", 0L)
+            createFilterAgg("precision_at_1.0_TP", 0L),
+            createFilterAgg("precision_at_1.0_FP", 0L)
         ));
 
         Precision precision = new Precision(Arrays.asList(1.0));
-        EvaluationMetricResult result = precision.evaluate(classInfo, aggs);
+        EvaluationMetricResult result = precision.evaluate(aggs);
 
         String expected = "{\"1.0\":0.0}";
         assertThat(Strings.toString(result), equalTo(expected));

+ 10 - 16
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java

@@ -49,36 +49,30 @@ public class RecallTests extends AbstractSerializingTestCase<Recall> {
     }
 
     public void testEvaluate() {
-        SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
-        when(classInfo.getName()).thenReturn("foo");
-
         Aggregations aggs = new Aggregations(Arrays.asList(
-            createFilterAgg("recall_foo_at_0.25_TP", 1L),
-            createFilterAgg("recall_foo_at_0.25_FN", 4L),
-            createFilterAgg("recall_foo_at_0.5_TP", 3L),
-            createFilterAgg("recall_foo_at_0.5_FN", 1L),
-            createFilterAgg("recall_foo_at_0.75_TP", 5L),
-            createFilterAgg("recall_foo_at_0.75_FN", 0L)
+            createFilterAgg("recall_at_0.25_TP", 1L),
+            createFilterAgg("recall_at_0.25_FN", 4L),
+            createFilterAgg("recall_at_0.5_TP", 3L),
+            createFilterAgg("recall_at_0.5_FN", 1L),
+            createFilterAgg("recall_at_0.75_TP", 5L),
+            createFilterAgg("recall_at_0.75_FN", 0L)
         ));
 
         Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75));
-        EvaluationMetricResult result = recall.evaluate(classInfo, aggs);
+        EvaluationMetricResult result = recall.evaluate(aggs);
 
         String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}";
         assertThat(Strings.toString(result), equalTo(expected));
     }
 
     public void testEvaluate_GivenZeroTpAndFp() {
-        SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class);
-        when(classInfo.getName()).thenReturn("foo");
-
         Aggregations aggs = new Aggregations(Arrays.asList(
-            createFilterAgg("recall_foo_at_1.0_TP", 0L),
-            createFilterAgg("recall_foo_at_1.0_FN", 0L)
+            createFilterAgg("recall_at_1.0_TP", 0L),
+            createFilterAgg("recall_at_1.0_FN", 0L)
         ));
 
         Recall recall = new Recall(Arrays.asList(1.0));
-        EvaluationMetricResult result = recall.evaluate(classInfo, aggs);
+        EvaluationMetricResult result = recall.evaluate(aggs);
 
         String expected = "{\"1.0\":0.0}";
         assertThat(Strings.toString(result), equalTo(expected));