瀏覽代碼

Fix accuracy metric (#50310)

Przemysław Witek 5 年之前
父節點
當前提交
e901f90afb
共有 14 個文件被更改,包括 472 次插入295 次删除
  1. 53 44
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java
  2. 7 7
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  3. 4 4
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java
  4. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java
  5. 131 81
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java
  6. 58 29
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java
  7. 17 16
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java
  8. 11 10
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java
  9. 5 5
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java
  10. 106 41
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java
  11. 24 20
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java
  12. 44 21
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java
  13. 4 6
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java
  14. 7 10
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+ 53 - 44
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java

@@ -20,6 +20,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation.classification;
 
 import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContent;
@@ -35,10 +36,25 @@ import java.util.Objects;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 
 /**
- * {@link AccuracyMetric} is a metric that answers the question:
- *   "What fraction of examples have been classified correctly by the classifier?"
+ * {@link AccuracyMetric} is a metric that answers the following two questions:
  *
- * equation: accuracy = 1/n * Σ(y == y´)
+ *   1. What is the fraction of documents for which predicted class equals the actual class?
+ *
+ *      equation: overall_accuracy = 1/n * Σ(y == y')
+ *      where: n  = total number of documents
+ *             y  = document's actual class
+ *             y' = document's predicted class
+ *
+ *   2. For any given class X, what is the fraction of documents for which either
+ *       a) both actual and predicted class are equal to X (true positives)
+ *      or
+ *       b) both actual and predicted class are not equal to X (true negatives)
+ *
+ *      equation: accuracy(X) = 1/n * (TP(X) + TN(X))
+ *      where: X     = class being examined
+ *             n     = total number of documents
+ *             TP(X) = number of true positives wrt X
+ *             TN(X) = number of true negatives wrt X
  */
 public class AccuracyMetric implements EvaluationMetric {
 
@@ -78,15 +94,15 @@ public class AccuracyMetric implements EvaluationMetric {
 
     public static class Result implements EvaluationMetric.Result {
 
-        private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
+        private static final ParseField CLASSES = new ParseField("classes");
         private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
 
         @SuppressWarnings("unchecked")
         private static final ConstructingObjectParser<Result, Void> PARSER =
-            new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
+            new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
 
         static {
-            PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
+            PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
             PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
         }
 
@@ -94,13 +110,13 @@ public class AccuracyMetric implements EvaluationMetric {
             return PARSER.apply(parser, null);
         }
 
-        /** List of actual classes. */
-        private final List<ActualClass> actualClasses;
-        /** Fraction of documents predicted correctly. */
+        /** List of per-class results. */
+        private final List<PerClassResult> classes;
+        /** Fraction of documents for which predicted class equals the actual class. */
         private final double overallAccuracy;
 
-        public Result(List<ActualClass> actualClasses, double overallAccuracy) {
-            this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
+        public Result(List<PerClassResult> classes, double overallAccuracy) {
+            this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
             this.overallAccuracy = overallAccuracy;
         }
 
@@ -109,8 +125,8 @@ public class AccuracyMetric implements EvaluationMetric {
             return NAME;
         }
 
-        public List<ActualClass> getActualClasses() {
-            return actualClasses;
+        public List<PerClassResult> getClasses() {
+            return classes;
         }
 
         public double getOverallAccuracy() {
@@ -120,7 +136,7 @@ public class AccuracyMetric implements EvaluationMetric {
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
-            builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
+            builder.field(CLASSES.getPreferredName(), classes);
             builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
             builder.endObject();
             return builder;
@@ -131,52 +147,42 @@ public class AccuracyMetric implements EvaluationMetric {
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
             Result that = (Result) o;
-            return Objects.equals(this.actualClasses, that.actualClasses)
+            return Objects.equals(this.classes, that.classes)
                 && this.overallAccuracy == that.overallAccuracy;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(actualClasses, overallAccuracy);
+            return Objects.hash(classes, overallAccuracy);
         }
     }
 
-    public static class ActualClass implements ToXContentObject {
+    public static class PerClassResult implements ToXContentObject {
 
-        private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
-        private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
+        private static final ParseField CLASS_NAME = new ParseField("class_name");
         private static final ParseField ACCURACY = new ParseField("accuracy");
 
         @SuppressWarnings("unchecked")
-        private static final ConstructingObjectParser<ActualClass, Void> PARSER =
-            new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
+        private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
+            new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
 
         static {
-            PARSER.declareString(constructorArg(), ACTUAL_CLASS);
-            PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
+            PARSER.declareString(constructorArg(), CLASS_NAME);
             PARSER.declareDouble(constructorArg(), ACCURACY);
         }
 
-        /** Name of the actual class. */
-        private final String actualClass;
-        /** Number of documents (examples) belonging to the {code actualClass} class. */
-        private final long actualClassDocCount;
-        /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
+        /** Name of the class. */
+        private final String className;
+        /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
         private final double accuracy;
 
-        public ActualClass(
-            String actualClass, long actualClassDocCount, double accuracy) {
-            this.actualClass = Objects.requireNonNull(actualClass);
-            this.actualClassDocCount = actualClassDocCount;
+        public PerClassResult(String className, double accuracy) {
+            this.className = Objects.requireNonNull(className);
             this.accuracy = accuracy;
         }
 
-        public String getActualClass() {
-            return actualClass;
-        }
-
-        public long getActualClassDocCount() {
-            return actualClassDocCount;
+        public String getClassName() {
+            return className;
         }
 
         public double getAccuracy() {
@@ -186,8 +192,7 @@ public class AccuracyMetric implements EvaluationMetric {
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
-            builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
-            builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
+            builder.field(CLASS_NAME.getPreferredName(), className);
             builder.field(ACCURACY.getPreferredName(), accuracy);
             builder.endObject();
             return builder;
@@ -197,15 +202,19 @@ public class AccuracyMetric implements EvaluationMetric {
         public boolean equals(Object o) {
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
-            ActualClass that = (ActualClass) o;
-            return Objects.equals(this.actualClass, that.actualClass)
-                && this.actualClassDocCount == that.actualClassDocCount
+            PerClassResult that = (PerClassResult) o;
+            return Objects.equals(this.className, that.className)
                 && this.accuracy == that.accuracy;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(actualClass, actualClassDocCount, accuracy);
+            return Objects.hash(className, accuracy);
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
         }
     }
 }

+ 7 - 7
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -1819,15 +1819,15 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
             assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
             assertThat(
-                accuracyResult.getActualClasses(),
+                accuracyResult.getClasses(),
                 equalTo(
                     List.of(
-                        // 3 out of 5 examples labeled as "cat" were classified correctly
-                        new AccuracyMetric.ActualClass("cat", 5, 0.6),
-                        // 3 out of 4 examples labeled as "dog" were classified correctly
-                        new AccuracyMetric.ActualClass("dog", 4, 0.75),
-                        // no examples labeled as "ant" were classified correctly
-                        new AccuracyMetric.ActualClass("ant", 1, 0.0))));
+                        // 9 out of 10 examples were classified correctly
+                        new AccuracyMetric.PerClassResult("ant", 0.9),
+                        // 6 out of 10 examples were classified correctly
+                        new AccuracyMetric.PerClassResult("cat", 0.6),
+                        // 8 out of 10 examples were classified correctly
+                        new AccuracyMetric.PerClassResult("dog", 0.8))));
             assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6));  // 6 out of 10 examples were classified correctly
         }
         {  // Precision

+ 4 - 4
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java

@@ -19,7 +19,7 @@
 package org.elasticsearch.client.ml.dataframe.evaluation.classification;
 
 import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
-import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.PerClassResult;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
@@ -41,13 +41,13 @@ public class AccuracyMetricResultTests extends AbstractXContentTestCase<Result>
     public static Result randomResult() {
         int numClasses = randomIntBetween(2, 100);
         List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
-        List<ActualClass> actualClasses = new ArrayList<>(numClasses);
+        List<PerClassResult> classes = new ArrayList<>(numClasses);
         for (int i = 0; i < numClasses; i++) {
             double accuracy = randomDoubleBetween(0.0, 1.0, true);
-            actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
+            classes.add(new PerClassResult(classNames.get(i), accuracy));
         }
         double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
-        return new Result(actualClasses, overallAccuracy);
+        return new Result(classes, overallAccuracy);
     }
 
     @Override

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java

@@ -44,5 +44,5 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
      * Gets the evaluation result for this metric.
      * @return {@code Optional.empty()} if the result is not available yet, {@code Optional.of(result)} otherwise
      */
-    Optional<EvaluationMetricResult> getResult();
+    Optional<? extends EvaluationMetricResult> getResult();
 }

+ 131 - 81
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
-import org.elasticsearch.search.aggregations.bucket.terms.Terms;
 import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
@@ -39,22 +39,36 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
 import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
 
 /**
- * {@link Accuracy} is a metric that answers the question:
- *   "What fraction of examples have been classified correctly by the classifier?"
+ * {@link Accuracy} is a metric that answers the following two questions:
  *
- * equation: accuracy = 1/n * Σ(y == y´)
+ *   1. What is the fraction of documents for which predicted class equals the actual class?
+ *
+ *      equation: overall_accuracy = 1/n * Σ(y == y')
+ *      where: n  = total number of documents
+ *             y  = document's actual class
+ *             y' = document's predicted class
+ *
+ *   2. For any given class X, what is the fraction of documents for which either
+ *       a) both actual and predicted class are equal to X (true positives)
+ *      or
+ *       b) both actual and predicted class are not equal to X (true negatives)
+ *
+ *      equation: accuracy(X) = 1/n * (TP(X) + TN(X))
+ *      where: X     = class being examined
+ *             n     = total number of documents
+ *             TP(X) = number of true positives wrt X
+ *             TN(X) = number of true negatives wrt X
  */
 public class Accuracy implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("accuracy");
 
+    static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
+
     private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
-    private static final String CLASSES_AGG_NAME = "classification_classes";
-    private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
-    private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
 
-    private static String buildScript(Object...args) {
-        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
+    private static Script buildScript(Object...args) {
+        return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
     }
 
     private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
@@ -63,11 +77,20 @@ public class Accuracy implements EvaluationMetric {
         return PARSER.apply(parser, null);
     }
 
-    private EvaluationMetricResult result;
+    private static final int MAX_CLASSES_CARDINALITY = 1000;
 
-    public Accuracy() {}
+    private final MulticlassConfusionMatrix matrix;
+    private final SetOnce<String> actualField = new SetOnce<>();
+    private final SetOnce<Double> overallAccuracy = new SetOnce<>();
+    private final SetOnce<Result> result = new SetOnce<>();
 
-    public Accuracy(StreamInput in) throws IOException {}
+    public Accuracy() {
+        this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
+    }
+
+    public Accuracy(StreamInput in) throws IOException {
+        this.matrix = new MulticlassConfusionMatrix(in);
+    }
 
     @Override
     public String getWriteableName() {
@@ -81,43 +104,79 @@ public class Accuracy implements EvaluationMetric {
 
     @Override
     public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
-        if (result != null) {
-            return Tuple.tuple(List.of(), List.of());
+        // Store given {@code actualField} for the purpose of generating error message in {@code process}.
+        this.actualField.trySet(actualField);
+        List<AggregationBuilder> aggs = new ArrayList<>();
+        List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
+        if (overallAccuracy.get() == null) {
+            aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
+        }
+        if (result.get() == null) {
+            Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
+            aggs.addAll(matrixAggs.v1());
+            pipelineAggs.addAll(matrixAggs.v2());
         }
-        Script accuracyScript = new Script(buildScript(actualField, predictedField));
-        return Tuple.tuple(
-            List.of(
-                AggregationBuilders.terms(CLASSES_AGG_NAME)
-                    .field(actualField)
-                    .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
-                AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
-            List.of());
+        return Tuple.tuple(aggs, pipelineAggs);
     }
 
     @Override
     public void process(Aggregations aggs) {
-        if (result != null) {
-            return;
+        if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
+            NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
+            overallAccuracy.set(overallAccuracyAgg.value());
         }
-        Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
-        NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
-        List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
-        for (Terms.Bucket bucket : classesAgg.getBuckets()) {
-            String actualClass = bucket.getKeyAsString();
-            long actualClassDocCount = bucket.getDocCount();
-            NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
-            actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
+        matrix.process(aggs);
+        if (result.get() == null && matrix.getResult().isPresent()) {
+            if (matrix.getResult().get().getOtherActualClassCount() > 0) {
+                // This means there were more than {@code maxClassesCardinality} buckets.
+                // We cannot calculate per-class accuracy accurately, so we fail.
+                throw ExceptionsHelper.badRequestException(
+                    "Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get());
+            }
+            result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get()));
         }
-        result = new Result(actualClasses, overallAccuracyAgg.value());
     }
 
     @Override
-    public Optional<EvaluationMetricResult> getResult() {
-        return Optional.ofNullable(result);
+    public Optional<Result> getResult() {
+        return Optional.ofNullable(result.get());
+    }
+
+    /**
+     * Computes the per-class accuracy results based on multiclass confusion matrix's result.
+     * Time complexity of this method is linear wrt multiclass confusion matrix size, so O(n^2) where n is the matrix dimension.
+     * This method is visible for testing only.
+     */
+    static List<PerClassResult> computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) {
+        assert matrixResult.getOtherActualClassCount() == 0;
+        // Number of actual classes taken into account
+        int n = matrixResult.getConfusionMatrix().size();
+        // Total number of documents taken into account
+        long totalDocCount =
+            matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum();
+        List<PerClassResult> classes = new ArrayList<>(n);
+        for (int i = 0; i < n; ++i) {
+            String className = matrixResult.getConfusionMatrix().get(i).getActualClass();
+            // Start with the assumption that all the docs were predicted correctly.
+            long correctDocCount = totalDocCount;
+            for (int j = 0; j < n; ++j) {
+                if (i != j) {
+                    // Subtract errors (false negatives)
+                    correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount();
+                    // Subtract errors (false positives)
+                    correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount();
+                }
+            }
+            // Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix
+            correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount();
+            classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount));
+        }
+        return classes;
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        matrix.writeTo(out);
     }
 
     @Override
@@ -131,25 +190,26 @@ public class Accuracy implements EvaluationMetric {
     public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
-        return true;
+        Accuracy that = (Accuracy) o;
+        return Objects.equals(this.matrix, that.matrix);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hashCode(NAME.getPreferredName());
+        return Objects.hash(matrix);
     }
 
     public static class Result implements EvaluationMetricResult {
 
-        private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
+        private static final ParseField CLASSES = new ParseField("classes");
         private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
 
         @SuppressWarnings("unchecked")
         private static final ConstructingObjectParser<Result, Void> PARSER =
-            new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
+            new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
 
         static {
-            PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
+            PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
             PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
         }
 
@@ -157,18 +217,18 @@ public class Accuracy implements EvaluationMetric {
             return PARSER.apply(parser, null);
         }
 
-        /** List of actual classes. */
-        private final List<ActualClass> actualClasses;
-        /** Fraction of documents predicted correctly. */
+        /** List of per-class results. */
+        private final List<PerClassResult> classes;
+        /** Fraction of documents for which predicted class equals the actual class. */
         private final double overallAccuracy;
 
-        public Result(List<ActualClass> actualClasses, double overallAccuracy) {
-            this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
+        public Result(List<PerClassResult> classes, double overallAccuracy) {
+            this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
             this.overallAccuracy = overallAccuracy;
         }
 
         public Result(StreamInput in) throws IOException {
-            this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
+            this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
             this.overallAccuracy = in.readDouble();
         }
 
@@ -182,8 +242,8 @@ public class Accuracy implements EvaluationMetric {
             return NAME.getPreferredName();
         }
 
-        public List<ActualClass> getActualClasses() {
-            return actualClasses;
+        public List<PerClassResult> getClasses() {
+            return classes;
         }
 
         public double getOverallAccuracy() {
@@ -192,14 +252,14 @@ public class Accuracy implements EvaluationMetric {
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            out.writeList(actualClasses);
+            out.writeList(classes);
             out.writeDouble(overallAccuracy);
         }
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
-            builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
+            builder.field(CLASSES.getPreferredName(), classes);
             builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
             builder.endObject();
             return builder;
@@ -210,54 +270,47 @@ public class Accuracy implements EvaluationMetric {
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
             Result that = (Result) o;
-            return Objects.equals(this.actualClasses, that.actualClasses)
+            return Objects.equals(this.classes, that.classes)
                 && this.overallAccuracy == that.overallAccuracy;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(actualClasses, overallAccuracy);
+            return Objects.hash(classes, overallAccuracy);
         }
     }
 
-    public static class ActualClass implements ToXContentObject, Writeable {
+    public static class PerClassResult implements ToXContentObject, Writeable {
 
-        private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
-        private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
+        private static final ParseField CLASS_NAME = new ParseField("class_name");
         private static final ParseField ACCURACY = new ParseField("accuracy");
 
         @SuppressWarnings("unchecked")
-        private static final ConstructingObjectParser<ActualClass, Void> PARSER =
-            new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
+        private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
+            new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
 
         static {
-            PARSER.declareString(constructorArg(), ACTUAL_CLASS);
-            PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
+            PARSER.declareString(constructorArg(), CLASS_NAME);
             PARSER.declareDouble(constructorArg(), ACCURACY);
         }
 
-        /** Name of the actual class. */
-        private final String actualClass;
-        /** Number of documents (examples) belonging to the {code actualClass} class. */
-        private final long actualClassDocCount;
-        /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
+        /** Name of the class. */
+        private final String className;
+        /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
         private final double accuracy;
 
-        public ActualClass(
-            String actualClass, long actualClassDocCount, double accuracy) {
-            this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
-            this.actualClassDocCount = actualClassDocCount;
+        public PerClassResult(String className, double accuracy) {
+            this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
             this.accuracy = accuracy;
         }
 
-        public ActualClass(StreamInput in) throws IOException {
-            this.actualClass = in.readString();
-            this.actualClassDocCount = in.readVLong();
+        public PerClassResult(StreamInput in) throws IOException {
+            this.className = in.readString();
             this.accuracy = in.readDouble();
         }
 
-        public String getActualClass() {
-            return actualClass;
+        public String getClassName() {
+            return className;
         }
 
         public double getAccuracy() {
@@ -266,16 +319,14 @@ public class Accuracy implements EvaluationMetric {
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(actualClass);
-            out.writeVLong(actualClassDocCount);
+            out.writeString(className);
             out.writeDouble(accuracy);
         }
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
-            builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
-            builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
+            builder.field(CLASS_NAME.getPreferredName(), className);
             builder.field(ACCURACY.getPreferredName(), accuracy);
             builder.endObject();
             return builder;
@@ -285,15 +336,14 @@ public class Accuracy implements EvaluationMetric {
         public boolean equals(Object o) {
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
-            ActualClass that = (ActualClass) o;
-            return Objects.equals(this.actualClass, that.actualClass)
-                && this.actualClassDocCount == that.actualClassDocCount
+            PerClassResult that = (PerClassResult) o;
+            return Objects.equals(this.className, that.className)
                 && this.accuracy == that.accuracy;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(actualClass, actualClassDocCount, accuracy);
+            return Objects.hash(className, accuracy);
         }
     }
 }

+ 58 - 29
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

@@ -5,6 +5,8 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.collect.Tuple;
@@ -52,13 +54,16 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
     public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
 
     public static final ParseField SIZE = new ParseField("size");
+    public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix");
 
     private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
 
     private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
         ConstructingObjectParser<MulticlassConfusionMatrix, Void>  parser =
-            new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
+            new ConstructingObjectParser<>(
+                NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0], (String) args[1]));
         parser.declareInt(optionalConstructorArg(), SIZE);
+        parser.declareString(optionalConstructorArg(), AGG_NAME_PREFIX);
         return parser;
     }
 
@@ -66,31 +71,39 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
         return PARSER.apply(parser, null);
     }
 
-    private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
-    private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
-    private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
-    private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
+    static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
+    static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
+    static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
+    static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
     private static final String OTHER_BUCKET_KEY = "_other_";
+    private static final String DEFAULT_AGG_NAME_PREFIX = "";
     private static final int DEFAULT_SIZE = 10;
     private static final int MAX_SIZE = 1000;
 
     private final int size;
-    private List<String> topActualClassNames;
-    private Result result;
+    private final String aggNamePrefix;
+    private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
+    private final SetOnce<Result> result = new SetOnce<>();
 
     public MulticlassConfusionMatrix() {
-        this((Integer) null);
+        this(null, null);
     }
 
-    public MulticlassConfusionMatrix(@Nullable Integer size) {
+    public MulticlassConfusionMatrix(@Nullable Integer size, @Nullable String aggNamePrefix) {
         if (size != null && (size <= 0 || size > MAX_SIZE)) {
             throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
         }
         this.size = size != null ? size : DEFAULT_SIZE;
+        this.aggNamePrefix = aggNamePrefix != null ? aggNamePrefix : DEFAULT_AGG_NAME_PREFIX;
     }
 
     public MulticlassConfusionMatrix(StreamInput in) throws IOException {
         this.size = in.readVInt();
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            this.aggNamePrefix = in.readString();
+        } else {
+            this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX;
+        }
     }
 
     @Override
@@ -109,30 +122,30 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
 
     @Override
     public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
-        if (topActualClassNames == null) {  // This is step 1
+        if (topActualClassNames.get() == null) {  // This is step 1
             return Tuple.tuple(
                 List.of(
-                    AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
+                    AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS))
                         .field(actualField)
                         .order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
                         .size(size)),
                 List.of());
         }
-        if (result == null) {  // This is step 2
+        if (result.get() == null) {  // This is step 2
             KeyedFilter[] keyedFiltersActual =
-                topActualClassNames.stream()
+                topActualClassNames.get().stream()
                     .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
                     .toArray(KeyedFilter[]::new);
             KeyedFilter[] keyedFiltersPredicted =
-                topActualClassNames.stream()
+                topActualClassNames.get().stream()
                     .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
                     .toArray(KeyedFilter[]::new);
             return Tuple.tuple(
                 List.of(
-                    AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
+                    AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS))
                         .field(actualField),
-                    AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
-                        .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
+                    AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), keyedFiltersActual)
+                        .subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), keyedFiltersPredicted)
                             .otherBucket(true)
                             .otherBucketKey(OTHER_BUCKET_KEY))),
                 List.of());
@@ -142,18 +155,18 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
 
     @Override
     public void process(Aggregations aggs) {
-        if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
-            Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
-            topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
+        if (topActualClassNames.get() == null && aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
+            Terms termsAgg = aggs.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS));
+            topActualClassNames.set(termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
         }
-        if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
-            Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
-            Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
+        if (result.get() == null && aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
+            Cardinality cardinalityAgg = aggs.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
+            Filters filtersAgg = aggs.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
             List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
             for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
                 String actualClass = bucket.getKeyAsString();
                 long actualClassDocCount = bucket.getDocCount();
-                Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
+                Filters subAgg = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS));
                 List<PredictedClass> predictedClasses = new ArrayList<>();
                 long otherPredictedClassDocCount = 0;
                 for (Filters.Bucket subBucket : subAgg.getBuckets()) {
@@ -168,18 +181,25 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
                 predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
                 actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
             }
-            result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
+            result.set(new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0)));
         }
     }
 
+    private String aggName(String aggNameWithoutPrefix) {
+        return aggNamePrefix + aggNameWithoutPrefix;
+    }
+
     @Override
-    public Optional<EvaluationMetricResult> getResult() {
-        return Optional.ofNullable(result);
+    public Optional<Result> getResult() {
+        return Optional.ofNullable(result.get());
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeVInt(size);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeString(aggNamePrefix);
+        }
     }
 
     @Override
@@ -195,12 +215,13 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
-        return Objects.equals(this.size, that.size);
+        return this.size == that.size
+            && Objects.equals(this.aggNamePrefix, that.aggNamePrefix);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(size);
+        return Objects.hash(size, aggNamePrefix);
     }
 
     public static class Result implements EvaluationMetricResult {
@@ -334,6 +355,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
             return actualClass;
         }
 
+        public long getActualClassDocCount() {
+            return actualClassDocCount;
+        }
+
         public List<PredictedClass> getPredictedClasses() {
             return predictedClasses;
         }
@@ -410,6 +435,10 @@ public class MulticlassConfusionMatrix implements EvaluationMetric {
             return predictedClass;
         }
 
+        public long getCount() {
+            return count;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(predictedClass);

+ 17 - 16
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -76,9 +77,9 @@ public class Precision implements EvaluationMetric {
 
     private static final int MAX_CLASSES_CARDINALITY = 1000;
 
-    private String actualField;
-    private List<String> topActualClassNames;
-    private EvaluationMetricResult result;
+    private final SetOnce<String> actualField = new SetOnce<>();
+    private final SetOnce<List<String>> topActualClassNames = new SetOnce<>();
+    private final SetOnce<Result> result = new SetOnce<>();
 
     public Precision() {}
 
@@ -97,8 +98,8 @@ public class Precision implements EvaluationMetric {
     @Override
     public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
         // Store given {@code actualField} for the purpose of generating error message in {@code process}.
-        this.actualField = actualField;
-        if (topActualClassNames == null) {  // This is step 1
+        this.actualField.trySet(actualField);
+        if (topActualClassNames.get() == null) {  // This is step 1
             return Tuple.tuple(
                 List.of(
                     AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
@@ -107,9 +108,9 @@ public class Precision implements EvaluationMetric {
                         .size(MAX_CLASSES_CARDINALITY)),
                 List.of());
         }
-        if (result == null) {  // This is step 2
+        if (result.get() == null) {  // This is step 2
             KeyedFilter[] keyedFiltersPredicted =
-                topActualClassNames.stream()
+                topActualClassNames.get().stream()
                     .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
                     .toArray(KeyedFilter[]::new);
             Script script = buildScript(actualField, predictedField);
@@ -127,18 +128,18 @@ public class Precision implements EvaluationMetric {
 
     @Override
     public void process(Aggregations aggs) {
-        if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
+        if (topActualClassNames.get() == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
             Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
             if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
-                // This means there were more than {@code maxClassesCardinality} buckets.
+                // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
                 // We cannot calculate average precision accurately, so we fail.
                 throw ExceptionsHelper.badRequestException(
-                    "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
+                    "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField.get());
             }
-            topActualClassNames =
-                topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
+            topActualClassNames.set(
+                topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList()));
         }
-        if (result == null &&
+        if (result.get() == null &&
                 aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
                 aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
             Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
@@ -152,13 +153,13 @@ public class Precision implements EvaluationMetric {
                     classes.add(new PerClassResult(className, precision));
                 }
             }
-            result = new Result(classes, avgPrecisionAgg.value());
+            result.set(new Result(classes, avgPrecisionAgg.value()));
         }
     }
 
     @Override
-    public Optional<EvaluationMetricResult> getResult() {
-        return Optional.ofNullable(result);
+    public Optional<Result> getResult() {
+        return Optional.ofNullable(result.get());
     }
 
     @Override

+ 11 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -70,8 +71,8 @@ public class Recall implements EvaluationMetric {
 
     private static final int MAX_CLASSES_CARDINALITY = 1000;
 
-    private String actualField;
-    private EvaluationMetricResult result;
+    private final SetOnce<String> actualField = new SetOnce<>();
+    private final SetOnce<Result> result = new SetOnce<>();
 
     public Recall() {}
 
@@ -90,8 +91,8 @@ public class Recall implements EvaluationMetric {
     @Override
     public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
         // Store given {@code actualField} for the purpose of generating error message in {@code process}.
-        this.actualField = actualField;
-        if (result != null) {
+        this.actualField.trySet(actualField);
+        if (result.get() != null) {
             return Tuple.tuple(List.of(), List.of());
         }
         Script script = buildScript(actualField, predictedField);
@@ -109,15 +110,15 @@ public class Recall implements EvaluationMetric {
 
     @Override
     public void process(Aggregations aggs) {
-        if (result == null &&
+        if (result.get() == null &&
                 aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
                 aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
             Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
             if (byActualClassAgg.getSumOfOtherDocCounts() > 0) {
-                // This means there were more than {@code maxClassesCardinality} buckets.
+                // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
                 // We cannot calculate average recall accurately, so we fail.
                 throw ExceptionsHelper.badRequestException(
-                    "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
+                    "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField.get());
             }
             NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
             List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
@@ -126,13 +127,13 @@ public class Recall implements EvaluationMetric {
                 NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
                 classes.add(new PerClassResult(className, recallAgg.value()));
             }
-            result = new Result(classes, avgRecallAgg.value());
+            result.set(new Result(classes, avgRecallAgg.value()));
         }
     }
 
     @Override
-    public Optional<EvaluationMetricResult> getResult() {
-        return Optional.ofNullable(result);
+    public Optional<Result> getResult() {
+        return Optional.ofNullable(result.get());
     }
 
     @Override

+ 5 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java

@@ -8,9 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -22,13 +22,13 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result>
     public static Result createRandom() {
         int numClasses = randomIntBetween(2, 100);
         List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
-        List<ActualClass> actualClasses = new ArrayList<>(numClasses);
+        List<PerClassResult> classes = new ArrayList<>(numClasses);
         for (int i = 0; i < numClasses; i++) {
             double accuracy = randomDoubleBetween(0.0, 1.0, true);
-            actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
+            classes.add(new PerClassResult(classNames.get(i), accuracy));
         }
         double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
-        return new Result(actualClasses, overallAccuracy);
+        return new Result(classes, overallAccuracy);
     }
 
     @Override

+ 106 - 41
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java

@@ -5,17 +5,26 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.PerClassResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.List;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockCardinality;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
 import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
 import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 
 public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
@@ -45,53 +54,109 @@ public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
     }
 
     public void testProcess() {
-        Aggregations aggs = new Aggregations(Arrays.asList(
-            mockTerms("classification_classes"),
-            mockSingleValue("classification_overall_accuracy", 0.8123),
-            mockSingleValue("some_other_single_metric_agg", 0.2377)
-        ));
+        Aggregations aggs = new Aggregations(List.of(
+            mockTerms(
+                "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
+                List.of(
+                    mockTermsBucket("dog", new Aggregations(List.of())),
+                    mockTermsBucket("cat", new Aggregations(List.of()))),
+                100L),
+            mockFilters(
+                "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
+                List.of(
+                    mockFiltersBucket(
+                        "dog",
+                        30,
+                        new Aggregations(List.of(mockFilters(
+                            "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
+                            List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
+                    mockFiltersBucket(
+                        "cat",
+                        70,
+                        new Aggregations(List.of(mockFilters(
+                            "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
+                            List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
+            mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1000L),
+            mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
 
         Accuracy accuracy = new Accuracy();
         accuracy.process(aggs);
 
-        assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(List.of(), 0.8123)));
+        assertThat(accuracy.aggs("act", "pred"), isTuple(empty(), empty()));
+
+        Result result = accuracy.getResult().get();
+        assertThat(result.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
+        assertThat(
+            result.getClasses(),
+            equalTo(
+                List.of(
+                    new PerClassResult("dog", 0.5),
+                    new PerClassResult("cat", 0.5))));
+        assertThat(result.getOverallAccuracy(), equalTo(0.5));
+    }
+
+    public void testProcess_GivenCardinalityTooHigh() {
+        Aggregations aggs = new Aggregations(List.of(
+            mockTerms(
+                "accuracy_" + MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
+                List.of(
+                    mockTermsBucket("dog", new Aggregations(List.of())),
+                    mockTermsBucket("cat", new Aggregations(List.of()))),
+                100L),
+            mockFilters(
+                "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
+                List.of(
+                    mockFiltersBucket(
+                        "dog",
+                        30,
+                        new Aggregations(List.of(mockFilters(
+                            "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
+                            List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
+                    mockFiltersBucket(
+                        "cat",
+                        70,
+                        new Aggregations(List.of(mockFilters(
+                            "accuracy_" + MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
+                            List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
+            mockCardinality("accuracy_" + MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 1001L),
+            mockSingleValue(Accuracy.OVERALL_ACCURACY_AGG_NAME, 0.5)));
+
+        Accuracy accuracy = new Accuracy();
+        accuracy.aggs("foo", "bar");
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> accuracy.process(aggs));
+        assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
     }
 
-    public void testProcess_GivenMissingAgg() {
-        {
-            Aggregations aggs = new Aggregations(Arrays.asList(
-                mockTerms("classification_classes"),
-                mockSingleValue("some_other_single_metric_agg", 0.2377)
-            ));
-            Accuracy accuracy = new Accuracy();
-            expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
-        }
-        {
-            Aggregations aggs = new Aggregations(Arrays.asList(
-                mockSingleValue("classification_overall_accuracy", 0.8123),
-                mockSingleValue("some_other_single_metric_agg", 0.2377)
-            ));
-            Accuracy accuracy = new Accuracy();
-            expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
-        }
+    public void testComputePerClassAccuracy() {
+        assertThat(
+            Accuracy.computePerClassAccuracy(
+                new MulticlassConfusionMatrix.Result(
+                    List.of(
+                        new MulticlassConfusionMatrix.ActualClass("A", 14, List.of(
+                            new MulticlassConfusionMatrix.PredictedClass("A", 1),
+                            new MulticlassConfusionMatrix.PredictedClass("B", 6),
+                            new MulticlassConfusionMatrix.PredictedClass("C", 4)
+                        ), 3L),
+                        new MulticlassConfusionMatrix.ActualClass("B", 20, List.of(
+                            new MulticlassConfusionMatrix.PredictedClass("A", 5),
+                            new MulticlassConfusionMatrix.PredictedClass("B", 3),
+                            new MulticlassConfusionMatrix.PredictedClass("C", 9)
+                        ), 3L),
+                        new MulticlassConfusionMatrix.ActualClass("C", 17, List.of(
+                            new MulticlassConfusionMatrix.PredictedClass("A", 8),
+                            new MulticlassConfusionMatrix.PredictedClass("B", 2),
+                            new MulticlassConfusionMatrix.PredictedClass("C", 7)
+                        ), 0L)),
+                    0)),
+            equalTo(
+                List.of(
+                    new Accuracy.PerClassResult("A", 25.0 / 51),  // 13 false positives, 13 false negatives
+                    new Accuracy.PerClassResult("B", 26.0 / 51),  //  8 false positives, 17 false negatives
+                    new Accuracy.PerClassResult("C", 28.0 / 51))) // 13 false positives, 10 false negatives
+        );
     }
 
-    public void testProcess_GivenAggOfWrongType() {
-        {
-            Aggregations aggs = new Aggregations(Arrays.asList(
-                mockTerms("classification_classes"),
-                mockTerms("classification_overall_accuracy")
-            ));
-            Accuracy accuracy = new Accuracy();
-            expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
-        }
-        {
-            Aggregations aggs = new Aggregations(Arrays.asList(
-                mockSingleValue("classification_classes", 1.0),
-                mockSingleValue("classification_overall_accuracy", 0.8123)
-            ));
-            Accuracy accuracy = new Accuracy();
-            expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
-        }
+    public void testComputePerClassAccuracy_OtherActualClassCountIsNonZero() {
+        expectThrows(AssertionError.class, () -> Accuracy.computePerClassAccuracy(new MulticlassConfusionMatrix.Result(List.of(), 1)));
     }
 }

+ 24 - 20
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java

@@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.Result;
 
 import java.io.IOException;
 import java.util.List;
@@ -54,20 +55,23 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
 
     public static MulticlassConfusionMatrix createRandom() {
         Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
-        return new MulticlassConfusionMatrix(size);
+        return new MulticlassConfusionMatrix(size, null);
     }
 
     public void testConstructor_SizeValidationFailures() {
         {
-            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1));
+            ElasticsearchStatusException e =
+                expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1, null));
             assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
         }
         {
-            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0));
+            ElasticsearchStatusException e =
+                expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0, null));
             assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
         }
         {
-            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001));
+            ElasticsearchStatusException e =
+                expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001, null));
             assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
         }
     }
@@ -82,34 +86,34 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
     public void testEvaluate() {
         Aggregations aggs = new Aggregations(List.of(
             mockTerms(
-                "multiclass_confusion_matrix_step_1_by_actual_class",
+                MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
                 List.of(
                     mockTermsBucket("dog", new Aggregations(List.of())),
                     mockTermsBucket("cat", new Aggregations(List.of()))),
                 0L),
             mockFilters(
-                "multiclass_confusion_matrix_step_2_by_actual_class",
+                MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
                 List.of(
                     mockFiltersBucket(
                         "dog",
                         30,
                         new Aggregations(List.of(mockFilters(
-                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
                             List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
                     mockFiltersBucket(
                         "cat",
                         70,
                         new Aggregations(List.of(mockFilters(
-                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
                             List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L)))))))),
-            mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
+            mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 2L)));
 
-        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
+        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
         confusionMatrix.process(aggs);
 
         assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
-        MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
-        assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
+        Result result = confusionMatrix.getResult().get();
+        assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
         assertThat(
             result.getConfusionMatrix(),
             equalTo(
@@ -122,34 +126,34 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
     public void testEvaluate_OtherClassesCountGreaterThanZero() {
         Aggregations aggs = new Aggregations(List.of(
             mockTerms(
-                "multiclass_confusion_matrix_step_1_by_actual_class",
+                MulticlassConfusionMatrix.STEP_1_AGGREGATE_BY_ACTUAL_CLASS,
                 List.of(
                     mockTermsBucket("dog", new Aggregations(List.of())),
                     mockTermsBucket("cat", new Aggregations(List.of()))),
                 100L),
             mockFilters(
-                "multiclass_confusion_matrix_step_2_by_actual_class",
+                MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_ACTUAL_CLASS,
                 List.of(
                     mockFiltersBucket(
                         "dog",
                         30,
                         new Aggregations(List.of(mockFilters(
-                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
                             List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
                     mockFiltersBucket(
                         "cat",
                         85,
                         new Aggregations(List.of(mockFilters(
-                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            MulticlassConfusionMatrix.STEP_2_AGGREGATE_BY_PREDICTED_CLASS,
                             List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L)))))))),
-            mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
+            mockCardinality(MulticlassConfusionMatrix.STEP_2_CARDINALITY_OF_ACTUAL_CLASS, 5L)));
 
-        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
+        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2, null);
         confusionMatrix.process(aggs);
 
         assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
-        MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
-        assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
+        Result result = confusionMatrix.getResult().get();
+        assertThat(result.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
         assertThat(
             result.getConfusionMatrix(),
             equalTo(

+ 44 - 21
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
@@ -21,6 +22,8 @@ import org.junit.Before;
 
 import java.util.List;
 
+import static java.util.stream.Collectors.toList;
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
@@ -52,10 +55,28 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
             evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, null));
 
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
-        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
         assertThat(
-            evaluateDataFrameResponse.getMetrics().get(0).getMetricName(),
-            equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
+            evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
+            contains(MulticlassConfusionMatrix.NAME.getPreferredName()));
+    }
+
+    public void testEvaluate_AllMetrics() {
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            evaluateDataFrame(
+                ANIMALS_DATA_INDEX,
+                new Classification(
+                    ANIMAL_NAME_FIELD,
+                    ANIMAL_NAME_PREDICTION_FIELD,
+                    List.of(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
+        assertThat(
+            evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
+            contains(
+                Accuracy.NAME.getPreferredName(),
+                MulticlassConfusionMatrix.NAME.getPreferredName(),
+                Precision.NAME.getPreferredName(),
+                Recall.NAME.getPreferredName()));
     }
 
     public void testEvaluate_Accuracy_KeywordField() {
@@ -69,14 +90,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
         Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
         assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
         assertThat(
-            accuracyResult.getActualClasses(),
+            accuracyResult.getClasses(),
             equalTo(
                 List.of(
-                    new Accuracy.ActualClass("ant", 15, 1.0 / 15),
-                    new Accuracy.ActualClass("cat", 15, 1.0 / 15),
-                    new Accuracy.ActualClass("dog", 15, 1.0 / 15),
-                    new Accuracy.ActualClass("fox", 15, 1.0 / 15),
-                    new Accuracy.ActualClass("mouse", 15, 1.0 / 15))));
+                    new Accuracy.PerClassResult("ant", 47.0 / 75),
+                    new Accuracy.PerClassResult("cat", 47.0 / 75),
+                    new Accuracy.PerClassResult("dog", 47.0 / 75),
+                    new Accuracy.PerClassResult("fox", 47.0 / 75),
+                    new Accuracy.PerClassResult("mouse", 47.0 / 75))));
         assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
     }
 
@@ -91,13 +112,14 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
         Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
         assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
         assertThat(
-            accuracyResult.getActualClasses(),
-            equalTo(List.of(
-                new Accuracy.ActualClass("1", 15, 1.0 / 15),
-                new Accuracy.ActualClass("2", 15, 2.0 / 15),
-                new Accuracy.ActualClass("3", 15, 3.0 / 15),
-                new Accuracy.ActualClass("4", 15, 4.0 / 15),
-                new Accuracy.ActualClass("5", 15, 5.0 / 15))));
+            accuracyResult.getClasses(),
+            equalTo(
+                List.of(
+                    new Accuracy.PerClassResult("1", 57.0 / 75),
+                    new Accuracy.PerClassResult("2", 54.0 / 75),
+                    new Accuracy.PerClassResult("3", 51.0 / 75),
+                    new Accuracy.PerClassResult("4", 48.0 / 75),
+                    new Accuracy.PerClassResult("5", 45.0 / 75))));
         assertThat(accuracyResult.getOverallAccuracy(), equalTo(15.0 / 75));
     }
 
@@ -112,10 +134,11 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
         Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
         assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
         assertThat(
-            accuracyResult.getActualClasses(),
-            equalTo(List.of(
-                new Accuracy.ActualClass("true", 45, 27.0 / 45),
-                new Accuracy.ActualClass("false", 30, 18.0 / 30))));
+            accuracyResult.getClasses(),
+            equalTo(
+                List.of(
+                    new Accuracy.PerClassResult("false", 18.0 / 30),
+                    new Accuracy.PerClassResult("true", 27.0 / 45))));
         assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
     }
 
@@ -250,7 +273,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
         EvaluateDataFrameAction.Response evaluateDataFrameResponse =
             evaluateDataFrame(
                 ANIMALS_DATA_INDEX,
-                new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3))));
+                new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new MulticlassConfusionMatrix(3, null))));
 
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
         assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));

+ 4 - 6
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -466,12 +466,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         {   // Accuracy
             Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
             assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
-            List<Accuracy.ActualClass> actualClasses = accuracyResult.getActualClasses();
-            assertThat(
-                actualClasses.stream().map(Accuracy.ActualClass::getActualClass).collect(toList()),
-                equalTo(dependentVariableValuesAsStrings));
-            actualClasses.forEach(
-                actualClass -> assertThat(actualClass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
+            for (Accuracy.PerClassResult klass : accuracyResult.getClasses()) {
+                assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
+                assertThat(klass.getAccuracy(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
+            }
         }
 
         {   // MulticlassConfusionMatrix

+ 7 - 10
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -620,16 +620,13 @@ setup:
 
   - match:
       classification.accuracy:
-        actual_classes:
-          - actual_class: "cat"
-            actual_class_doc_count: 3
-            accuracy: 0.6666666666666666  # 2 out of 3
-          - actual_class: "dog"
-            actual_class_doc_count: 3
-            accuracy: 0.6666666666666666  # 2 out of 3
-          - actual_class: "mouse"
-            actual_class_doc_count: 2
-            accuracy: 0.5  # 1 out of 2
+        classes:
+          - class_name: "cat"
+            accuracy: 0.625  # 5 out of 8
+          - class_name: "dog"
+            accuracy: 0.75  # 6 out of 8
+          - class_name: "mouse"
+            accuracy: 0.875  # 7 out of 8
         overall_accuracy: 0.625  # 5 out of 8
 ---
 "Test classification precision":