Browse Source

Implement accuracy metric for multiclass classification (#47772)

Przemysław Witek 5 years ago
parent
commit
94ee36d61e
17 changed files with 908 additions and 77 deletions
  1. 5 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  2. 211 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java
  3. 22 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  4. 6 3
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  5. 12 4
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  6. 63 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java
  7. 53 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricTests.java
  8. 6 5
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java
  9. 6 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java
  10. 7 4
      docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc
  11. 6 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  12. 293 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java
  13. 44 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java
  14. 112 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java
  15. 3 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java
  16. 30 55
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java
  17. 29 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+ 5 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -18,6 +18,7 @@
  */
 package org.elasticsearch.client.ml.dataframe.evaluation;
 
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
@@ -51,6 +52,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
                 new ParseField(MulticlassConfusionMatrixMetric.NAME),
@@ -68,6 +71,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
                 EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
                 new ParseField(MulticlassConfusionMatrixMetric.NAME),

+ 211 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetric.java

@@ -0,0 +1,211 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * {@link AccuracyMetric} is a metric that answers the question:
+ *   "What fraction of examples have been classified correctly by the classifier?"
+ *
+ * equation: accuracy = 1/n * Σ(y == y´)
+ */
+public class AccuracyMetric implements EvaluationMetric {
+
+    public static final String NAME = "accuracy";
+
+    private static final ObjectParser<AccuracyMetric, Void> PARSER = new ObjectParser<>(NAME, true, AccuracyMetric::new);
+
+    public static AccuracyMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public AccuracyMetric() {}
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(NAME);
+    }
+
+    public static class Result implements EvaluationMetric.Result {
+
+        private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
+        private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
+            PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        /** List of actual classes. */
+        private final List<ActualClass> actualClasses;
+        /** Fraction of documents predicted correctly. */
+        private final double overallAccuracy;
+
+        public Result(List<ActualClass> actualClasses, double overallAccuracy) {
+            this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
+            this.overallAccuracy = overallAccuracy;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        public List<ActualClass> getActualClasses() {
+            return actualClasses;
+        }
+
+        public double getOverallAccuracy() {
+            return overallAccuracy;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
+            builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.actualClasses, that.actualClasses)
+                && this.overallAccuracy == that.overallAccuracy;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(actualClasses, overallAccuracy);
+        }
+    }
+
+    public static class ActualClass implements ToXContentObject {
+
+        private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
+        private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
+        private static final ParseField ACCURACY = new ParseField("accuracy");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<ActualClass, Void> PARSER =
+            new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
+
+        static {
+            PARSER.declareString(constructorArg(), ACTUAL_CLASS);
+            PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
+            PARSER.declareDouble(constructorArg(), ACCURACY);
+        }
+
+        /** Name of the actual class. */
+        private final String actualClass;
+        /** Number of documents (examples) belonging to the {code actualClass} class. */
+        private final long actualClassDocCount;
+        /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
+        private final double accuracy;
+
+        public ActualClass(
+            String actualClass, long actualClassDocCount, double accuracy) {
+            this.actualClass = Objects.requireNonNull(actualClass);
+            this.actualClassDocCount = actualClassDocCount;
+            this.accuracy = accuracy;
+        }
+
+        public String getActualClass() {
+            return actualClass;
+        }
+
+        public long getActualClassDocCount() {
+            return actualClassDocCount;
+        }
+
+        public double getAccuracy() {
+            return accuracy;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
+            builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
+            builder.field(ACCURACY.getPreferredName(), accuracy);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ActualClass that = (ActualClass) o;
+            return Objects.equals(this.actualClass, that.actualClass)
+                && this.actualClassDocCount == that.actualClassDocCount
+                && this.accuracy == that.accuracy;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(actualClass, actualClassDocCount, accuracy);
+        }
+    }
+}

+ 22 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -125,6 +125,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
 import org.elasticsearch.client.ml.dataframe.OutlierDetection;
 import org.elasticsearch.client.ml.dataframe.PhaseProgress;
 import org.elasticsearch.client.ml.dataframe.QueryConfig;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
@@ -1783,6 +1784,27 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
 
+        {  // Accuracy
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            AccuracyMetric.Result accuracyResult = evaluateDataFrameResponse.getMetricByName(AccuracyMetric.NAME);
+            assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
+            assertThat(
+                accuracyResult.getActualClasses(),
+                equalTo(
+                    List.of(
+                        new AccuracyMetric.ActualClass("cat", 5, 0.6),  // 3 out of 5 examples labeled as "cat" were classified correctly
+                        new AccuracyMetric.ActualClass("dog", 4, 0.75),  // 3 out of 4 examples labeled as "dog" were classified correctly
+                        new AccuracyMetric.ActualClass("ant", 1, 0.0))));  // no examples labeled as "ant" were classified correctly
+            assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6));  // 6 out of 10 examples were classified correctly
+        }
         {  // No size provided for MulticlassConfusionMatrixMetric, default used instead
             EvaluateDataFrameRequest evaluateDataFrameRequest =
                 new EvaluateDataFrameRequest(

+ 6 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -57,6 +57,7 @@ import org.elasticsearch.client.ilm.ShrinkAction;
 import org.elasticsearch.client.ilm.UnfollowAction;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
 import org.elasticsearch.client.ml.dataframe.OutlierDetection;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
@@ -687,7 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(49, namedXContents.size());
+        assertEquals(51, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -729,21 +730,23 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(TimeSyncConfig.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
         assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
-        assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
+        assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
         assertThat(names,
             hasItems(AucRocMetric.NAME,
                 PrecisionMetric.NAME,
                 RecallMetric.NAME,
                 ConfusionMatrixMetric.NAME,
+                AccuracyMetric.NAME,
                 MulticlassConfusionMatrixMetric.NAME,
                 MeanSquaredErrorMetric.NAME,
                 RSquaredMetric.NAME));
-        assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
+        assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
         assertThat(names,
             hasItems(AucRocMetric.NAME,
                 PrecisionMetric.NAME,
                 RecallMetric.NAME,
                 ConfusionMatrixMetric.NAME,
+                AccuracyMetric.NAME,
                 MulticlassConfusionMatrixMetric.NAME,
                 MeanSquaredErrorMetric.NAME,
                 RSquaredMetric.NAME));

+ 12 - 4
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -141,6 +141,7 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection;
 import org.elasticsearch.client.ml.dataframe.QueryConfig;
 import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
@@ -3347,20 +3348,27 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                     "actual_class", // <2>
                     "predicted_class", // <3>
                     // Evaluation metrics // <4>
-                    new MulticlassConfusionMatrixMetric(3)); // <5>
+                    new AccuracyMetric(), // <5>
+                    new MulticlassConfusionMatrixMetric(3)); // <6>
             // end::evaluate-data-frame-evaluation-classification
 
             EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
             EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT);
 
             // tag::evaluate-data-frame-results-classification
+            AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
+            double accuracy = accuracyResult.getOverallAccuracy(); // <2>
+
             MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
-                response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <1>
+                response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3>
 
-            List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <2>
-            long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <3>
+            List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4>
+            long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5>
             // end::evaluate-data-frame-results-classification
 
+            assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
+            assertThat(accuracy, equalTo(0.6));
+
             assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
             assertThat(
                 confusionMatrix,

+ 63 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricResultTests.java

@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.ActualClass;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric.Result;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class AccuracyMetricResultTests extends AbstractXContentTestCase<AccuracyMetric.Result> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected AccuracyMetric.Result createTestInstance() {
+        int numClasses = randomIntBetween(2, 100);
+        List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
+        List<ActualClass> actualClasses = new ArrayList<>(numClasses);
+        for (int i = 0; i < numClasses; i++) {
+            double accuracy = randomDoubleBetween(0.0, 1.0, true);
+            actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
+        }
+        double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
+        return new Result(actualClasses, overallAccuracy);
+    }
+
+    @Override
+    protected AccuracyMetric.Result doParseInstance(XContentParser parser) throws IOException {
+        return AccuracyMetric.Result.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 53 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AccuracyMetricTests.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class AccuracyMetricTests extends AbstractXContentTestCase<AccuracyMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    static AccuracyMetric createRandom() {
+        return new AccuracyMetric();
+    }
+
+    @Override
+    protected AccuracyMetric createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AccuracyMetric doParseInstance(XContentParser parser) throws IOException {
+        return AccuracyMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 6 - 5
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java

@@ -18,6 +18,7 @@
  */
 package org.elasticsearch.client.ml.dataframe.evaluation.classification;
 
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
@@ -25,6 +26,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.List;
 import java.util.function.Predicate;
 
 public class ClassificationTests extends AbstractXContentTestCase<Classification> {
@@ -34,11 +36,10 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
         return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
     }
 
-    public static Classification createRandom() {
-        return new Classification(
-            randomAlphaOfLength(10),
-            randomAlphaOfLength(10),
-            randomBoolean() ? null : Arrays.asList(new MulticlassConfusionMatrixMetric()));
+    static Classification createRandom() {
+        List<EvaluationMetric> metrics =
+            randomSubsetOf(Arrays.asList(AccuracyMetricTests.createRandom(), MulticlassConfusionMatrixMetricTests.createRandom()));
+        return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
     }
 
     @Override

+ 6 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java

@@ -32,12 +32,16 @@ public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCa
         return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
     }
 
-    @Override
-    protected MulticlassConfusionMatrixMetric createTestInstance() {
+    static MulticlassConfusionMatrixMetric createRandom() {
         Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null;
         return new MulticlassConfusionMatrixMetric(size);
     }
 
+    @Override
+    protected MulticlassConfusionMatrixMetric createTestInstance() {
+        return createRandom();
+    }
+
     @Override
     protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException {
         return MulticlassConfusionMatrixMetric.fromXContent(parser);

+ 7 - 4
docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc

@@ -52,7 +52,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
 <2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
 <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
 <4> The remaining parameters are the metrics to be calculated based on the two fields described above
-<5> Multiclass confusion matrix of size 3
+<5> Accuracy
+<6> Multiclass confusion matrix of size 3
 
 ===== Regression
 
@@ -101,9 +102,11 @@ include-tagged::{doc-tests-file}[{api}-results-softclassification]
 include-tagged::{doc-tests-file}[{api}-results-classification]
 --------------------------------------------------
 
-<1> Fetching multiclass confusion matrix metric by name
-<2> Fetching the contents of the confusion matrix
-<3> Fetching the number of classes that were not included in the matrix
+<1> Fetching accuracy metric by name
+<2> Fetching the actual accuracy value
+<3> Fetching multiclass confusion matrix metric by name
+<4> Fetching the contents of the confusion matrix
+<5> Fetching the number of classes that were not included in the matrix
 
 ===== Regression
 

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.plugins.spi.NamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
@@ -48,6 +49,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         // Classification metrics
         namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
             MulticlassConfusionMatrix::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent));
 
         // Regression metrics
         namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
@@ -78,6 +80,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
             MulticlassConfusionMatrix.NAME.getPreferredName(),
             MulticlassConfusionMatrix::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
             MeanSquaredError.NAME.getPreferredName(),
             MeanSquaredError::new));
@@ -95,6 +98,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
             MulticlassConfusionMatrix.NAME.getPreferredName(),
             MulticlassConfusionMatrix.Result::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+            Accuracy.NAME.getPreferredName(),
+            Accuracy.Result::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
             MeanSquaredError.NAME.getPreferredName(),
             MeanSquaredError.Result::new));

+ 293 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java

@@ -0,0 +1,293 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.text.MessageFormat;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * {@link Accuracy} is a metric that answers the question:
+ *   "What fraction of examples have been classified correctly by the classifier?"
+ *
+ * equation: accuracy = 1/n * Σ(y == y´)
+ */
+public class Accuracy implements ClassificationMetric {
+
+    public static final ParseField NAME = new ParseField("accuracy");
+
+    private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
+    private static final String CLASSES_AGG_NAME = "classification_classes";
+    private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
+    private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
+
+    private static String buildScript(Object...args) {
+        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
+    }
+
+    private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
+
+    public static Accuracy fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private EvaluationMetricResult result;
+
+    public Accuracy() {}
+
+    public Accuracy(StreamInput in) throws IOException {}
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
+        if (result != null) {
+            return List.of();
+        }
+        Script accuracyScript = new Script(buildScript(actualField, predictedField));
+        return List.of(
+            AggregationBuilders.terms(CLASSES_AGG_NAME)
+                .field(actualField)
+                .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
+            AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript));
+    }
+
+    @Override
+    public void process(Aggregations aggs) {
+        if (result != null) {
+            return;
+        }
+        Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
+        NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
+        List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
+        for (Terms.Bucket bucket : classesAgg.getBuckets()) {
+            String actualClass = bucket.getKeyAsString();
+            long actualClassDocCount = bucket.getDocCount();
+            NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
+            actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
+        }
+        result = new Result(actualClasses, overallAccuracyAgg.value());
+    }
+
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(NAME.getPreferredName());
+    }
+
+    public static class Result implements EvaluationMetricResult {
+
+        private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
+        private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
+            PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        /** List of actual classes. */
+        private final List<ActualClass> actualClasses;
+        /** Fraction of documents predicted correctly. */
+        private final double overallAccuracy;
+
+        public Result(List<ActualClass> actualClasses, double overallAccuracy) {
+            this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
+            this.overallAccuracy = overallAccuracy;
+        }
+
+        public Result(StreamInput in) throws IOException {
+            this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
+            this.overallAccuracy = in.readDouble();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME.getPreferredName();
+        }
+
+        public List<ActualClass> getActualClasses() {
+            return actualClasses;
+        }
+
+        public double getOverallAccuracy() {
+            return overallAccuracy;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeList(actualClasses);
+            out.writeDouble(overallAccuracy);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
+            builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.actualClasses, that.actualClasses)
+                && this.overallAccuracy == that.overallAccuracy;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(actualClasses, overallAccuracy);
+        }
+    }
+
+    public static class ActualClass implements ToXContentObject, Writeable {
+
+        private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
+        private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
+        private static final ParseField ACCURACY = new ParseField("accuracy");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<ActualClass, Void> PARSER =
+            new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
+
+        static {
+            PARSER.declareString(constructorArg(), ACTUAL_CLASS);
+            PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
+            PARSER.declareDouble(constructorArg(), ACCURACY);
+        }
+
+        /** Name of the actual class. */
+        private final String actualClass;
+        /** Number of documents (examples) belonging to the {code actualClass} class. */
+        private final long actualClassDocCount;
+        /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
+        private final double accuracy;
+
+        public ActualClass(
+            String actualClass, long actualClassDocCount, double accuracy) {
+            this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
+            this.actualClassDocCount = actualClassDocCount;
+            this.accuracy = accuracy;
+        }
+
+        public ActualClass(StreamInput in) throws IOException {
+            this.actualClass = in.readString();
+            this.actualClassDocCount = in.readVLong();
+            this.accuracy = in.readDouble();
+        }
+
+        public String getActualClass() {
+            return actualClass;
+        }
+
+        public double getAccuracy() {
+            return accuracy;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(actualClass);
+            out.writeVLong(actualClassDocCount);
+            out.writeDouble(accuracy);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
+            builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
+            builder.field(ACCURACY.getPreferredName(), accuracy);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ActualClass that = (ActualClass) o;
+            return Objects.equals(this.actualClass, that.actualClass)
+                && this.actualClassDocCount == that.actualClassDocCount
+                && this.accuracy == that.accuracy;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(actualClass, actualClassDocCount, accuracy);
+        }
+    }
+}

+ 44 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java

@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.ActualClass;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy.Result;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accuracy.Result> {
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+    }
+
+    @Override
+    protected Accuracy.Result createTestInstance() {
+        int numClasses = randomIntBetween(2, 100);
+        List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
+        List<ActualClass> actualClasses = new ArrayList<>(numClasses);
+        for (int i = 0; i < numClasses; i++) {
+            double accuracy = randomDoubleBetween(0.0, 1.0, true);
+            actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), accuracy));
+        }
+        double overallAccuracy = randomDoubleBetween(0.0, 1.0, true);
+        return new Result(actualClasses, overallAccuracy);
+    }
+
+    @Override
+    protected Writeable.Reader<Accuracy.Result> instanceReader() {
+        return Accuracy.Result::new;
+    }
+}

+ 112 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyTests.java

@@ -0,0 +1,112 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AccuracyTests extends AbstractSerializingTestCase<Accuracy> {
+
+    @Override
+    protected Accuracy doParseInstance(XContentParser parser) throws IOException {
+        return Accuracy.fromXContent(parser);
+    }
+
+    @Override
+    protected Accuracy createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Accuracy> instanceReader() {
+        return Accuracy::new;
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    public static Accuracy createRandom() {
+        return new Accuracy();
+    }
+
+    public void testProcess() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            createTermsAgg("classification_classes"),
+            createSingleMetricAgg("classification_overall_accuracy", 0.8123),
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+        ));
+
+        Accuracy accuracy = new Accuracy();
+        accuracy.process(aggs);
+
+        assertThat(accuracy.getResult().get(), equalTo(new Accuracy.Result(List.of(), 0.8123)));
+    }
+
+    public void testProcess_GivenMissingAgg() {
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                createTermsAgg("classification_classes"),
+                createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+            ));
+            Accuracy accuracy = new Accuracy();
+            expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
+        }
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                createSingleMetricAgg("classification_overall_accuracy", 0.8123),
+                createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+            ));
+            Accuracy accuracy = new Accuracy();
+            expectThrows(NullPointerException.class, () -> accuracy.process(aggs));
+        }
+    }
+
+    public void testProcess_GivenAggOfWrongType() {
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                createTermsAgg("classification_classes"),
+                createTermsAgg("classification_overall_accuracy")
+            ));
+            Accuracy accuracy = new Accuracy();
+            expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
+        }
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                createSingleMetricAgg("classification_classes", 1.0),
+                createSingleMetricAgg("classification_overall_accuracy", 0.8123)
+            ));
+            Accuracy accuracy = new Accuracy();
+            expectThrows(ClassCastException.class, () -> accuracy.process(aggs));
+        }
+    }
+
+    private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
+        NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
+        when(agg.getName()).thenReturn(name);
+        when(agg.value()).thenReturn(value);
+        return agg;
+    }
+
+    private static Terms createTermsAgg(String name) {
+        Terms agg = mock(Terms.class);
+        when(agg.getName()).thenReturn(name);
+        return agg;
+    }
+}

+ 3 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java

@@ -51,10 +51,9 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
     }
 
     public static Classification createRandom() {
-        return new Classification(
-            randomAlphaOfLength(10),
-            randomAlphaOfLength(10),
-            randomBoolean() ? null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom()));
+        List<ClassificationMetric> metrics =
+            randomSubsetOf(Arrays.asList(AccuracyTests.createRandom(), MulticlassConfusionMatrixTests.createRandom()));
+        return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
     }
 
     @Override

+ 30 - 55
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

@@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
@@ -20,6 +21,7 @@ import org.junit.Before;
 import java.util.List;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 
 public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
@@ -43,69 +45,42 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
             evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
 
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
-        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
-        MulticlassConfusionMatrix.Result confusionMatrixResult =
-            (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
-        assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
         assertThat(
-            confusionMatrixResult.getConfusionMatrix(),
-            equalTo(List.of(
-                new ActualClass("ant",
-                    15,
-                    List.of(
-                        new PredictedClass("ant", 1L),
-                        new PredictedClass("cat", 4L),
-                        new PredictedClass("dog", 3L),
-                        new PredictedClass("fox", 2L),
-                        new PredictedClass("mouse", 5L)),
-                    0),
-                new ActualClass("cat",
-                    15,
-                    List.of(
-                        new PredictedClass("ant", 3L),
-                        new PredictedClass("cat", 1L),
-                        new PredictedClass("dog", 5L),
-                        new PredictedClass("fox", 4L),
-                        new PredictedClass("mouse", 2L)),
-                    0),
-                new ActualClass("dog",
-                    15,
-                    List.of(
-                        new PredictedClass("ant", 4L),
-                        new PredictedClass("cat", 2L),
-                        new PredictedClass("dog", 1L),
-                        new PredictedClass("fox", 5L),
-                        new PredictedClass("mouse", 3L)),
-                    0),
-                new ActualClass("fox",
-                    15,
-                    List.of(
-                        new PredictedClass("ant", 5L),
-                        new PredictedClass("cat", 3L),
-                        new PredictedClass("dog", 2L),
-                        new PredictedClass("fox", 1L),
-                        new PredictedClass("mouse", 4L)),
-                    0),
-                new ActualClass("mouse",
-                    15,
-                    List.of(
-                        new PredictedClass("ant", 2L),
-                        new PredictedClass("cat", 5L),
-                        new PredictedClass("dog", 4L),
-                        new PredictedClass("fox", 3L),
-                        new PredictedClass("mouse", 1L)),
-                    0))));
-        assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
+            evaluateDataFrameResponse.getMetrics().get(0).getMetricName(),
+            equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
+    }
+
+    public void testEvaluate_MulticlassClassification_Accuracy() {
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            evaluateDataFrame(ANIMALS_DATA_INDEX, new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new Accuracy())));
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
+
+        Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
+        assertThat(accuracyResult.getMetricName(), equalTo(Accuracy.NAME.getPreferredName()));
+        assertThat(
+            accuracyResult.getActualClasses(),
+            equalTo(
+                List.of(
+                    new Accuracy.ActualClass("ant", 15, 1.0 / 15),
+                    new Accuracy.ActualClass("cat", 15, 1.0 / 15),
+                    new Accuracy.ActualClass("dog", 15, 1.0 / 15),
+                    new Accuracy.ActualClass("fox", 15, 1.0 / 15),
+                    new Accuracy.ActualClass("mouse", 15, 1.0 / 15))));
+        assertThat(accuracyResult.getOverallAccuracy(), equalTo(5.0 / 75));
     }
 
-    public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
+    public void testEvaluate_MulticlassClassification_AccuracyAndConfusionMatrixMetricWithDefaultSize() {
         EvaluateDataFrameAction.Response evaluateDataFrameResponse =
             evaluateDataFrame(
                 ANIMALS_DATA_INDEX,
                 new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix())));
 
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
-        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
+
         MulticlassConfusionMatrix.Result confusionMatrixResult =
             (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
         assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
@@ -167,7 +142,7 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
                 new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix(3))));
 
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
-        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
         MulticlassConfusionMatrix.Result confusionMatrixResult =
             (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
         assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));

+ 29 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -603,6 +603,35 @@ setup:
             }
           }
 ---
+"Test classification accuracy":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "classification_field_pred.keyword",
+                "metrics": { "accuracy": {} }
+              }
+            }
+          }
+
+  - match:
+      classification.accuracy:
+        actual_classes:
+          - actual_class: "cat"
+            actual_class_doc_count: 3
+            accuracy: 0.6666666666666666  # 2 out of 3
+          - actual_class: "dog"
+            actual_class_doc_count: 3
+            accuracy: 0.6666666666666666  # 2 out of 3
+          - actual_class: "mouse"
+            actual_class_doc_count: 2
+            accuracy: 0.5  # 1 out of 2
+        overall_accuracy: 0.625  # 5 out of 8
+---
 "Test classification multiclass_confusion_matrix":
   - do:
       ml.evaluate_data_frame: