Jelajahi Sumber

Implement evaluation API for multiclass classification problem (#47126)

Przemysław Witek 6 tahun lalu
induk
melakukan
6c6b4bfedb
18 mengubah file dengan 1853 tambahan dan 41 penghapusan
  1. 13 2
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  2. 132 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java
  3. 164 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java
  4. 126 34
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  5. 9 5
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  6. 64 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java
  7. 74 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java
  8. 50 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java
  9. 16 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  10. 172 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java
  11. 30 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java
  12. 276 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java
  13. 222 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java
  14. 60 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java
  15. 187 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java
  16. 3 0
      x-pack/plugin/ml/qa/ml-with-security/build.gradle
  17. 137 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java
  18. 118 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+ 13 - 2
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -18,7 +18,9 @@
  */
 package org.elasticsearch.client.ml.dataframe.evaluation;
 
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@@ -41,6 +43,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             // Evaluations
             new NamedXContentRegistry.Entry(
                 Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
+            new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
             new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
             // Evaluation metrics
             new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
@@ -48,6 +51,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(MulticlassConfusionMatrixMetric.NAME),
+                MulticlassConfusionMatrixMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
             new NamedXContentRegistry.Entry(
@@ -60,10 +67,14 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent),
+                EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class,
+                new ParseField(MulticlassConfusionMatrixMetric.NAME),
+                MulticlassConfusionMatrixMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
+                EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
     }
 }

+ 132 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java

@@ -0,0 +1,132 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Evaluation of classification results.
+ */
+public class Classification implements Evaluation {
+
+    public static final String NAME = "classification";
+
+    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
+    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField METRICS = new ParseField("metrics");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
+        NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
+    }
+
+    public static Classification fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * The field containing the actual value
+     * The value of this field is assumed to be numeric
+     */
+    private final String actualField;
+
+    /**
+     * The field containing the predicted value
+     * The value of this field is assumed to be numeric
+     */
+    private final String predictedField;
+
+    /**
+     * The list of metrics to calculate
+     */
+    private final List<EvaluationMetric> metrics;
+
+    public Classification(String actualField, String predictedField) {
+        this(actualField, predictedField, (List<EvaluationMetric>)null);
+    }
+
+    public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
+        this(actualField, predictedField, Arrays.asList(metrics));
+    }
+
+    public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
+        this.actualField = Objects.requireNonNull(actualField);
+        this.predictedField = Objects.requireNonNull(predictedField);
+        if (metrics != null) {
+            metrics.sort(Comparator.comparing(EvaluationMetric::getName));
+        }
+        this.metrics = metrics;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
+        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+
+        if (metrics != null) {
+           builder.startObject(METRICS.getPreferredName());
+           for (EvaluationMetric metric : metrics) {
+               builder.field(metric.getName(), metric);
+           }
+           builder.endObject();
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Classification that = (Classification) o;
+        return Objects.equals(that.actualField, this.actualField)
+            && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.metrics, this.metrics);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(actualField, predictedField, metrics);
+    }
+}

+ 164 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java

@@ -0,0 +1,164 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Calculates the multiclass confusion matrix.
+ */
+public class MulticlassConfusionMatrixMetric implements EvaluationMetric {
+
+    public static final String NAME = "multiclass_confusion_matrix";
+
+    public static final ParseField SIZE = new ParseField("size");
+
+    private static final ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> PARSER = createParser();
+
+    private static ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void> createParser() {
+        ConstructingObjectParser<MulticlassConfusionMatrixMetric, Void>  parser =
+            new ConstructingObjectParser<>(NAME, true, args -> new MulticlassConfusionMatrixMetric((Integer) args[0]));
+        parser.declareInt(optionalConstructorArg(), SIZE);
+        return parser;
+    }
+
+    public static MulticlassConfusionMatrixMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final Integer size;
+
+    public MulticlassConfusionMatrixMetric() {
+        this(null);
+    }
+
+    public MulticlassConfusionMatrixMetric(@Nullable Integer size) {
+        this.size = size;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (size != null) {
+            builder.field(SIZE.getPreferredName(), size);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MulticlassConfusionMatrixMetric that = (MulticlassConfusionMatrixMetric) o;
+        return Objects.equals(this.size, that.size);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(size);
+    }
+
+    public static class Result implements EvaluationMetric.Result {
+
+        private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
+        private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>(
+                "multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
+
+        static {
+            PARSER.declareObject(
+                constructorArg(),
+                (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
+                CONFUSION_MATRIX);
+            PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        // Immutable
+        private final Map<String, Map<String, Long>> confusionMatrix;
+        private final long otherClassesCount;
+
+        public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
+            this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
+            this.otherClassesCount = otherClassesCount;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        public Map<String, Map<String, Long>> getConfusionMatrix() {
+            return confusionMatrix;
+        }
+
+        public long getOtherClassesCount() {
+            return otherClassesCount;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
+            builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.confusionMatrix, that.confusionMatrix)
+                && this.otherClassesCount == that.otherClassesCount;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(confusionMatrix, otherClassesCount);
+        }
+    }
+}

+ 126 - 34
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -125,7 +125,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
 import org.elasticsearch.client.ml.dataframe.OutlierDetection;
 import org.elasticsearch.client.ml.dataframe.PhaseProgress;
 import org.elasticsearch.client.ml.dataframe.QueryConfig;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -1573,19 +1575,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
     public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
         String indexName = "evaluate-test-index";
-        createIndex(indexName, mappingForClassification());
+        createIndex(indexName, mappingForSoftClassification());
         BulkRequest bulk = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .add(docForClassification(indexName, "blue", false, 0.1))  // #0
-            .add(docForClassification(indexName, "blue", false, 0.2))  // #1
-            .add(docForClassification(indexName, "blue", false, 0.3))  // #2
-            .add(docForClassification(indexName, "blue", false, 0.4))  // #3
-            .add(docForClassification(indexName, "blue", false, 0.7))  // #4
-            .add(docForClassification(indexName, "blue", true, 0.2))  // #5
-            .add(docForClassification(indexName, "green", true, 0.3))  // #6
-            .add(docForClassification(indexName, "green", true, 0.4))  // #7
-            .add(docForClassification(indexName, "green", true, 0.8))  // #8
-            .add(docForClassification(indexName, "green", true, 0.9));  // #9
+            .add(docForSoftClassification(indexName, "blue", false, 0.1))  // #0
+            .add(docForSoftClassification(indexName, "blue", false, 0.2))  // #1
+            .add(docForSoftClassification(indexName, "blue", false, 0.3))  // #2
+            .add(docForSoftClassification(indexName, "blue", false, 0.4))  // #3
+            .add(docForSoftClassification(indexName, "blue", false, 0.7))  // #4
+            .add(docForSoftClassification(indexName, "blue", true, 0.2))  // #5
+            .add(docForSoftClassification(indexName, "green", true, 0.3))  // #6
+            .add(docForSoftClassification(indexName, "green", true, 0.4))  // #7
+            .add(docForSoftClassification(indexName, "green", true, 0.8))  // #8
+            .add(docForSoftClassification(indexName, "green", true, 0.9));  // #9
         highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
 
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@@ -1647,19 +1649,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
     public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
         String indexName = "evaluate-with-query-test-index";
-        createIndex(indexName, mappingForClassification());
+        createIndex(indexName, mappingForSoftClassification());
         BulkRequest bulk = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .add(docForClassification(indexName, "blue", true, 1.0))  // #0
-            .add(docForClassification(indexName, "blue", true, 1.0))  // #1
-            .add(docForClassification(indexName, "blue", true, 1.0))  // #2
-            .add(docForClassification(indexName, "blue", true, 1.0))  // #3
-            .add(docForClassification(indexName, "blue", true, 0.0))  // #4
-            .add(docForClassification(indexName, "blue", true, 0.0))  // #5
-            .add(docForClassification(indexName, "green", true, 0.0))  // #6
-            .add(docForClassification(indexName, "green", true, 0.0))  // #7
-            .add(docForClassification(indexName, "green", true, 0.0))  // #8
-            .add(docForClassification(indexName, "green", true, 1.0));  // #9
+            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #0
+            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #1
+            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #2
+            .add(docForSoftClassification(indexName, "blue", true, 1.0))  // #3
+            .add(docForSoftClassification(indexName, "blue", true, 0.0))  // #4
+            .add(docForSoftClassification(indexName, "blue", true, 0.0))  // #5
+            .add(docForSoftClassification(indexName, "green", true, 0.0))  // #6
+            .add(docForSoftClassification(indexName, "green", true, 0.0))  // #7
+            .add(docForSoftClassification(indexName, "green", true, 0.0))  // #8
+            .add(docForSoftClassification(indexName, "green", true, 1.0));  // #9
         highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
 
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
@@ -1722,6 +1724,74 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
     }
 
+    public void testEvaluateDataFrame_Classification() throws IOException {
+        String indexName = "evaluate-classification-test-index";
+        createIndex(indexName, mappingForClassification());
+        BulkRequest regressionBulk = new BulkRequest()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .add(docForClassification(indexName, "cat", "cat"))
+            .add(docForClassification(indexName, "cat", "cat"))
+            .add(docForClassification(indexName, "cat", "cat"))
+            .add(docForClassification(indexName, "cat", "dog"))
+            .add(docForClassification(indexName, "cat", "fish"))
+            .add(docForClassification(indexName, "dog", "cat"))
+            .add(docForClassification(indexName, "dog", "dog"))
+            .add(docForClassification(indexName, "dog", "dog"))
+            .add(docForClassification(indexName, "dog", "dog"))
+            .add(docForClassification(indexName, "horse", "cat"));
+        highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
+
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+
+        {  // No size provided for MulticlassConfusionMatrixMetric, default used instead
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName,
+                    null,
+                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            MulticlassConfusionMatrixMetric.Result mcmResult =
+                evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
+            assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
+            assertThat(
+                mcmResult.getConfusionMatrix(),
+                equalTo(
+                    Map.of(
+                        "cat", Map.of("cat", 3L, "dog", 1L, "horse", 0L, "_other_", 1L),
+                        "dog", Map.of("cat", 1L, "dog", 3L, "horse", 0L),
+                        "horse", Map.of("cat", 1L, "dog", 0L, "horse", 0L))));
+            assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
+        }
+        {  // Explicit size provided for MulticlassConfusionMatrixMetric metric
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName,
+                    null,
+                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            MulticlassConfusionMatrixMetric.Result mcmResult =
+                evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
+            assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
+            assertThat(
+                mcmResult.getConfusionMatrix(),
+                equalTo(
+                    Map.of(
+                        "cat", Map.of("cat", 3L, "dog", 1L, "_other_", 1L),
+                        "dog", Map.of("cat", 1L, "dog", 3L))));
+            assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
+        }
+    }
+
     private static XContentBuilder defaultMappingForTest() throws IOException {
         return XContentFactory.jsonBuilder().startObject()
             .startObject("properties")
@@ -1739,7 +1809,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
     private static final String actualField = "label";
     private static final String probabilityField = "p";
 
-    private static XContentBuilder mappingForClassification() throws IOException {
+    private static XContentBuilder mappingForSoftClassification() throws IOException {
         return XContentFactory.jsonBuilder().startObject()
             .startObject("properties")
                 .startObject(datasetField)
@@ -1755,26 +1825,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         .endObject();
     }
 
-    private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
+    private static IndexRequest docForSoftClassification(String indexName, String dataset, boolean isTrue, double p) {
         return new IndexRequest()
             .index(indexName)
             .source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
     }
 
+    private static final String actualClassField = "actual_class";
+    private static final String predictedClassField = "predicted_class";
+
+    private static XContentBuilder mappingForClassification() throws IOException {
+        return XContentFactory.jsonBuilder().startObject()
+            .startObject("properties")
+                .startObject(actualClassField)
+                    .field("type", "keyword")
+                .endObject()
+                .startObject(predictedClassField)
+                    .field("type", "keyword")
+                .endObject()
+            .endObject()
+        .endObject();
+    }
+
+    private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
+        return new IndexRequest()
+            .index(indexName)
+            .source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
+    }
+
     private static final String actualRegression = "regression_actual";
     private static final String probabilityRegression = "regression_prob";
 
     private static XContentBuilder mappingForRegression() throws IOException {
         return XContentFactory.jsonBuilder().startObject()
             .startObject("properties")
-            .startObject(actualRegression)
-            .field("type", "double")
-            .endObject()
-            .startObject(probabilityRegression)
-            .field("type", "double")
-            .endObject()
+                .startObject(actualRegression)
+                    .field("type", "double")
+                .endObject()
+                .startObject(probabilityRegression)
+                    .field("type", "double")
+                .endObject()
             .endObject()
-            .endObject();
+        .endObject();
     }
 
     private static IndexRequest docForRegression(String indexName, double act, double p) {
@@ -1789,11 +1881,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
     public void testEstimateMemoryUsage() throws IOException {
         String indexName = "estimate-test-index";
-        createIndex(indexName, mappingForClassification());
+        createIndex(indexName, mappingForSoftClassification());
         BulkRequest bulk1 = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
         for (int i = 0; i < 10; ++i) {
-            bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
+            bulk1.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
         }
         highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
 
@@ -1819,7 +1911,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         BulkRequest bulk2 = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
         for (int i = 10; i < 100; ++i) {
-            bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
+            bulk2.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
         }
         highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
 

+ 9 - 5
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -57,7 +57,9 @@ import org.elasticsearch.client.ilm.ShrinkAction;
 import org.elasticsearch.client.ilm.UnfollowAction;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
 import org.elasticsearch.client.ml.dataframe.OutlierDetection;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -681,7 +683,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(41, namedXContents.size());
+        assertEquals(44, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -720,22 +722,24 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(org.elasticsearch.client.ml.dataframe.Regression.NAME.getPreferredName()));
         assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
         assertTrue(names.contains(TimeSyncConfig.NAME));
-        assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
-        assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
-        assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
+        assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
+        assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
+        assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
         assertThat(names,
             hasItems(AucRocMetric.NAME,
                 PrecisionMetric.NAME,
                 RecallMetric.NAME,
                 ConfusionMatrixMetric.NAME,
+                MulticlassConfusionMatrixMetric.NAME,
                 MeanSquaredErrorMetric.NAME,
                 RSquaredMetric.NAME));
-        assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
+        assertEquals(Integer.valueOf(7), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
         assertThat(names,
             hasItems(AucRocMetric.NAME,
                 PrecisionMetric.NAME,
                 RecallMetric.NAME,
                 ConfusionMatrixMetric.NAME,
+                MulticlassConfusionMatrixMetric.NAME,
                 MeanSquaredErrorMetric.NAME,
                 RSquaredMetric.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));

+ 64 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java

@@ -0,0 +1,64 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.function.Predicate;
+
+public class ClassificationTests extends AbstractXContentTestCase<Classification> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    public static Classification createRandom() {
+        return new Classification(
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomBoolean() ? null : Arrays.asList(new MulticlassConfusionMatrixMetric()));
+    }
+
+    @Override
+    protected Classification createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Classification doParseInstance(XContentParser parser) throws IOException {
+        return Classification.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in the root of the object only
+        return field -> !field.isEmpty();
+    }
+}

+ 74 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java

@@ -0,0 +1,74 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class MulticlassConfusionMatrixMetricResultTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric.Result> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected MulticlassConfusionMatrixMetric.Result createTestInstance() {
+        int numClasses = randomIntBetween(2, 100);
+        List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
+        Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
+        for (int i = 0; i < numClasses; i++) {
+            Map<String, Long> row = new TreeMap<>();
+            confusionMatrix.put(classNames.get(i), row);
+            for (int j = 0; j < numClasses; j++) {
+                if (randomBoolean()) {
+                    row.put(classNames.get(i), randomNonNegativeLong());
+                }
+            }
+        }
+        long otherClassesCount = randomNonNegativeLong();
+        return new MulticlassConfusionMatrixMetric.Result(confusionMatrix, otherClassesCount);
+    }
+
+    @Override
+    protected MulticlassConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException {
+        return MulticlassConfusionMatrixMetric.Result.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in the root of the object only
+        return field -> !field.isEmpty();
+    }
+}

+ 50 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricTests.java

@@ -0,0 +1,50 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class MulticlassConfusionMatrixMetricTests extends AbstractXContentTestCase<MulticlassConfusionMatrixMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected MulticlassConfusionMatrixMetric createTestInstance() {
+        Integer size = randomBoolean() ? randomIntBetween(1, 1000) : null;
+        return new MulticlassConfusionMatrixMetric(size);
+    }
+
+    @Override
+    protected MulticlassConfusionMatrixMetric doParseInstance(XContentParser parser) throws IOException {
+        return MulticlassConfusionMatrixMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 16 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -8,7 +8,10 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.plugins.spi.NamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
@@ -32,6 +35,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         // Evaluations
         namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
             BinarySoftClassification::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
 
         // Soft classification metrics
@@ -41,6 +45,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
             ConfusionMatrix::fromXContent));
 
+        // Classification metrics
+        namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
+            MulticlassConfusionMatrix::fromXContent));
+
         // Regression metrics
         namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
@@ -54,6 +62,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         // Evaluations
         namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
             BinarySoftClassification::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(),
+            Classification::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
 
         // Evaluation Metrics
@@ -65,6 +75,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             Recall::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
             ConfusionMatrix::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
+            MulticlassConfusionMatrix.NAME.getPreferredName(),
+            MulticlassConfusionMatrix::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
             MeanSquaredError.NAME.getPreferredName(),
             MeanSquaredError::new));
@@ -79,6 +92,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             ScoreByThresholdResult::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
             ConfusionMatrix.Result::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+            MulticlassConfusionMatrix.NAME.getPreferredName(),
+            MulticlassConfusionMatrix.Result::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
             MeanSquaredError.NAME.getPreferredName(),
             MeanSquaredError.Result::new));

+ 172 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java

@@ -0,0 +1,172 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Evaluation of classification results.
+ */
+public class Classification implements Evaluation {
+
+    public static final ParseField NAME = new ParseField("classification");
+
+    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
+    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField METRICS = new ParseField("metrics");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
+        NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<ClassificationMetric>) a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS);
+    }
+
+    public static Classification fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * The field containing the actual value
+     * The value of this field is assumed to be numeric
+     */
+    private final String actualField;
+
+    /**
+     * The field containing the predicted value
+     * The value of this field is assumed to be numeric
+     */
+    private final String predictedField;
+
+    /**
+     * The list of metrics to calculate
+     */
+    private final List<ClassificationMetric> metrics;
+
+    public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
+        this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
+        this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
+        this.metrics = initMetrics(metrics);
+    }
+
+    public Classification(StreamInput in) throws IOException {
+        this.actualField = in.readString();
+        this.predictedField = in.readString();
+        this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
+    }
+
+    private static List<ClassificationMetric> initMetrics(@Nullable List<ClassificationMetric> parsedMetrics) {
+        List<ClassificationMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
+        if (metrics.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
+        }
+        Collections.sort(metrics, Comparator.comparing(ClassificationMetric::getName));
+        return metrics;
+    }
+
+    private static List<ClassificationMetric> defaultMetrics() {
+        return Arrays.asList(new MulticlassConfusionMatrix());
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public List<ClassificationMetric> getMetrics() {
+        return metrics;
+    }
+
+    @Override
+    public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
+        ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
+        SearchSourceBuilder searchSourceBuilder = newSearchSourceBuilder(List.of(actualField, predictedField), userProvidedQueryBuilder);
+        for (ClassificationMetric metric : metrics) {
+            List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
+            aggs.forEach(searchSourceBuilder::aggregation);
+        }
+        return searchSourceBuilder;
+    }
+
+    @Override
+    public void process(SearchResponse searchResponse) {
+        ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
+        if (searchResponse.getHits().getTotalHits().value == 0) {
+            throw ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedField);
+        }
+        for (ClassificationMetric metric : metrics) {
+            metric.process(searchResponse.getAggregations());
+        }
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(actualField);
+        out.writeString(predictedField);
+        out.writeNamedWriteableList(metrics);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
+        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+
+        builder.startObject(METRICS.getPreferredName());
+        for (ClassificationMetric metric : metrics) {
+            builder.field(metric.getWriteableName(), metric);
+        }
+        builder.endObject();
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Classification that = (Classification) o;
+        return Objects.equals(that.actualField, this.actualField)
+            && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.metrics, this.metrics);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(actualField, predictedField, metrics);
+    }
+}

+ 30 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java

@@ -0,0 +1,30 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
+
+import java.util.List;
+
+public interface ClassificationMetric extends EvaluationMetric {
+
+    /**
+     * Builds the aggregation that collect required data to compute the metric
+     * @param actualField the field that stores the actual value
+     * @param predictedField the field that stores the predicted value
+     * @return the aggregations required to compute the metric
+     */
+    List<AggregationBuilder> aggs(String actualField, String predictedField);
+
+    /**
+     * Processes given aggregations as a step towards computing result
+     * @param aggs aggregations from {@link SearchResponse}
+     */
+    void process(Aggregations aggs);
+}

+ 276 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

@@ -0,0 +1,276 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.BucketOrder;
+import org.elasticsearch.search.aggregations.bucket.filter.Filters;
+import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.Cardinality;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * {@link MulticlassConfusionMatrix} is a metric that answers the question:
+ *   "How many examples belonging to class X were classified as Y by the classifier?"
+ * for all the possible class pairs {X, Y}.
+ */
+public class MulticlassConfusionMatrix implements ClassificationMetric {
+
+    public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
+
+    public static final ParseField SIZE = new ParseField("size");
+
+    private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
+
+    private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
+        ConstructingObjectParser<MulticlassConfusionMatrix, Void>  parser =
+            new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
+        parser.declareInt(optionalConstructorArg(), SIZE);
+        return parser;
+    }
+
+    public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
+    private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
+    private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
+    private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
+    private static final String OTHER_BUCKET_KEY = "_other_";
+    private static final int DEFAULT_SIZE = 10;
+    private static final int MAX_SIZE = 1000;
+
+    private final int size;
+    private List<String> topActualClassNames;
+    private Result result;
+
+    public MulticlassConfusionMatrix() {
+        this((Integer) null);
+    }
+
+    public MulticlassConfusionMatrix(@Nullable Integer size) {
+        if (size != null && (size <= 0 || size > MAX_SIZE)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
+        }
+        this.size = size != null ? size : DEFAULT_SIZE;
+    }
+
+    public MulticlassConfusionMatrix(StreamInput in) throws IOException {
+        this.size = in.readVInt();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    public int getSize() {
+        return size;
+    }
+
+    @Override
+    public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
+        if (topActualClassNames == null) {  // This is step 1
+            return List.of(
+                AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
+                    .field(actualField)
+                    .order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
+                    .size(size));
+        }
+        if (result == null) {  // This is step 2
+            KeyedFilter[] keyedFilters =
+                topActualClassNames.stream()
+                    .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
+                    .toArray(KeyedFilter[]::new);
+            return List.of(
+                AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
+                    .field(actualField),
+                AggregationBuilders.terms(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)
+                    .field(actualField)
+                    .order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
+                    .size(size)
+                    .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFilters)
+                        .otherBucket(true)
+                        .otherBucketKey(OTHER_BUCKET_KEY)));
+        }
+        return List.of();
+    }
+
+    @Override
+    public void process(Aggregations aggs) {
+        if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
+            Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
+            topActualClassNames = termsAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).collect(Collectors.toList());
+        }
+        if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
+            Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
+            Terms termsAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
+            Map<String, Map<String, Long>> counts = new TreeMap<>();
+            for (Terms.Bucket bucket : termsAgg.getBuckets()) {
+                String actualClass = bucket.getKeyAsString();
+                Map<String, Long> subCounts = new TreeMap<>();
+                counts.put(actualClass, subCounts);
+                Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
+                for (Filters.Bucket subBucket : subAgg.getBuckets()) {
+                    String predictedClass = subBucket.getKeyAsString();
+                    Long docCount = subBucket.getDocCount();
+                    if ((OTHER_BUCKET_KEY.equals(predictedClass) && docCount == 0L) == false) {
+                        subCounts.put(predictedClass, docCount);
+                    }
+                }
+            }
+            result = new Result(counts, termsAgg.getSumOfOtherDocCounts() == 0 ? 0 : cardinalityAgg.getValue() - size);
+        }
+    }
+
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeVInt(size);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(SIZE.getPreferredName(), size);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
+        return Objects.equals(this.size, that.size);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(size);
+    }
+
+    public static class Result implements EvaluationMetricResult {
+
+        private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
+        private static final ParseField OTHER_CLASSES_COUNT = new ParseField("_other_");
+
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>(
+                "multiclass_confusion_matrix_result", true, a -> new Result((Map<String, Map<String, Long>>) a[0], (long) a[1]));
+
+        static {
+            PARSER.declareObject(
+                constructorArg(),
+                (p, c) -> p.map(TreeMap::new, p2 -> p2.map(TreeMap::new, XContentParser::longValue)),
+                CONFUSION_MATRIX);
+            PARSER.declareLong(constructorArg(), OTHER_CLASSES_COUNT);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        // Immutable
+        private final Map<String, Map<String, Long>> confusionMatrix;
+        private final long otherClassesCount;
+
+        public Result(Map<String, Map<String, Long>> confusionMatrix, long otherClassesCount) {
+            this.confusionMatrix = Collections.unmodifiableMap(Objects.requireNonNull(confusionMatrix));
+            this.otherClassesCount = otherClassesCount;
+        }
+
+        public Result(StreamInput in) throws IOException {
+            this.confusionMatrix = Collections.unmodifiableMap(
+                in.readMap(StreamInput::readString, in2 -> in2.readMap(StreamInput::readString, StreamInput::readLong)));
+            this.otherClassesCount = in.readLong();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME.getPreferredName();
+        }
+
+        public Map<String, Map<String, Long>> getConfusionMatrix() {
+            return confusionMatrix;
+        }
+
+        public long getOtherClassesCount() {
+            return otherClassesCount;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeMap(
+                confusionMatrix,
+                StreamOutput::writeString,
+                (out2, row) -> out2.writeMap(row, StreamOutput::writeString, StreamOutput::writeLong));
+            out.writeLong(otherClassesCount);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
+            builder.field(OTHER_CLASSES_COUNT.getPreferredName(), otherClassesCount);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.confusionMatrix, that.confusionMatrix)
+                && this.otherClassesCount == that.otherClassesCount;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(confusionMatrix, otherClassesCount);
+        }
+    }
+}

+ 222 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java

@@ -0,0 +1,222 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.apache.lucene.search.TotalHits;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
+import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    public static Classification createRandom() {
+        return new Classification(
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomBoolean() ? null : Arrays.asList(MulticlassConfusionMatrixTests.createRandom()));
+    }
+
+    @Override
+    protected Classification doParseInstance(XContentParser parser) throws IOException {
+        return Classification.fromXContent(parser);
+    }
+
+    @Override
+    protected Classification createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Classification> instanceReader() {
+        return Classification::new;
+    }
+
+    public void testConstructor_GivenEmptyMetrics() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Classification("foo", "bar", Collections.emptyList()));
+        assertThat(e.getMessage(), equalTo("[classification] must have one or more metrics"));
+    }
+
+    public void testBuildSearch() {
+        QueryBuilder userProvidedQuery =
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.termQuery("field_A", "some-value"))
+                .filter(QueryBuilders.termQuery("field_B", "some-other-value"));
+        QueryBuilder expectedSearchQuery =
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.existsQuery("act"))
+                .filter(QueryBuilders.existsQuery("pred"))
+                .filter(QueryBuilders.boolQuery()
+                    .filter(QueryBuilders.termQuery("field_A", "some-value"))
+                    .filter(QueryBuilders.termQuery("field_B", "some-other-value")));
+
+        Classification evaluation = new Classification("act", "pred", Arrays.asList(new MulticlassConfusionMatrix()));
+
+        SearchSourceBuilder searchSourceBuilder = evaluation.buildSearch(userProvidedQuery);
+        assertThat(searchSourceBuilder.query(), equalTo(expectedSearchQuery));
+        assertThat(searchSourceBuilder.aggregations().count(), greaterThan(0));
+    }
+
+    public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() {
+        ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
+        ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
+        ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
+        ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
+
+        Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
+        assertThat(metric1.getResult(), isEmpty());
+        assertThat(metric2.getResult(), isEmpty());
+        assertThat(metric3.getResult(), isEmpty());
+        assertThat(metric4.getResult(), isEmpty());
+        assertThat(evaluation.hasAllResults(), is(false));
+
+        evaluation.process(mockSearchResponseWithNonZeroTotalHits());
+        assertThat(metric1.getResult(), isEmpty());
+        assertThat(metric2.getResult(), isEmpty());
+        assertThat(metric3.getResult(), isEmpty());
+        assertThat(metric4.getResult(), isEmpty());
+        assertThat(evaluation.hasAllResults(), is(false));
+
+        evaluation.process(mockSearchResponseWithNonZeroTotalHits());
+        assertThat(metric1.getResult(), isPresent());
+        assertThat(metric2.getResult(), isEmpty());
+        assertThat(metric3.getResult(), isEmpty());
+        assertThat(metric4.getResult(), isEmpty());
+        assertThat(evaluation.hasAllResults(), is(false));
+
+        evaluation.process(mockSearchResponseWithNonZeroTotalHits());
+        assertThat(metric1.getResult(), isPresent());
+        assertThat(metric2.getResult(), isPresent());
+        assertThat(metric3.getResult(), isEmpty());
+        assertThat(metric4.getResult(), isEmpty());
+        assertThat(evaluation.hasAllResults(), is(false));
+
+        evaluation.process(mockSearchResponseWithNonZeroTotalHits());
+        assertThat(metric1.getResult(), isPresent());
+        assertThat(metric2.getResult(), isPresent());
+        assertThat(metric3.getResult(), isPresent());
+        assertThat(metric4.getResult(), isEmpty());
+        assertThat(evaluation.hasAllResults(), is(false));
+
+        evaluation.process(mockSearchResponseWithNonZeroTotalHits());
+        assertThat(metric1.getResult(), isPresent());
+        assertThat(metric2.getResult(), isPresent());
+        assertThat(metric3.getResult(), isPresent());
+        assertThat(metric4.getResult(), isPresent());
+        assertThat(evaluation.hasAllResults(), is(true));
+
+        evaluation.process(mockSearchResponseWithNonZeroTotalHits());
+        assertThat(metric1.getResult(), isPresent());
+        assertThat(metric2.getResult(), isPresent());
+        assertThat(metric3.getResult(), isPresent());
+        assertThat(metric4.getResult(), isPresent());
+        assertThat(evaluation.hasAllResults(), is(true));
+    }
+
+    private static SearchResponse mockSearchResponseWithNonZeroTotalHits() {
+        SearchResponse searchResponse = mock(SearchResponse.class);
+        SearchHits hits = new SearchHits(SearchHits.EMPTY, new TotalHits(10, TotalHits.Relation.EQUAL_TO), 0);
+        when(searchResponse.getHits()).thenReturn(hits);
+        return searchResponse;
+    }
+
+    /**
+     * Metric which iterates through its steps in {@link #process} method.
+     * Number of steps is configurable.
+     * Upon reaching the last step, the result is produced.
+     */
+    private static class FakeClassificationMetric implements ClassificationMetric {
+
+        private final String name;
+        private final int numSteps;
+        private int currentStepIndex;
+        private EvaluationMetricResult result;
+
+        FakeClassificationMetric(String name, int numSteps) {
+            this.name = name;
+            this.numSteps = numSteps;
+        }
+
+        @Override
+        public String getName() {
+            return name;
+        }
+
+        @Override
+        public String getWriteableName() {
+            return name;
+        }
+
+        @Override
+        public List<AggregationBuilder> aggs(String actualField, String predictedField) {
+            return List.of();
+        }
+
+        @Override
+        public void process(Aggregations aggs) {
+            if (result != null) {
+                return;
+            }
+            currentStepIndex++;
+            if (currentStepIndex == numSteps) {
+                // This is the last step, time to write evaluation result
+                result = mock(EvaluationMetricResult.class);
+            }
+        }
+
+        @Override
+        public Optional<EvaluationMetricResult> getResult() {
+            return Optional.ofNullable(result);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) {
+            return builder;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) {
+        }
+    }
+}

+ 60 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java

@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class MulticlassConfusionMatrixResultTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix.Result> {
+
+    @Override
+    protected MulticlassConfusionMatrix.Result doParseInstance(XContentParser parser) throws IOException {
+        return MulticlassConfusionMatrix.Result.fromXContent(parser);
+    }
+
+    @Override
+    protected MulticlassConfusionMatrix.Result createTestInstance() {
+        int numClasses = randomIntBetween(2, 100);
+        List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
+        Map<String, Map<String, Long>> confusionMatrix = new TreeMap<>();
+        for (int i = 0; i < numClasses; i++) {
+            Map<String, Long> row = new TreeMap<>();
+            confusionMatrix.put(classNames.get(i), row);
+            for (int j = 0; j < numClasses; j++) {
+                if (randomBoolean()) {
+                    row.put(classNames.get(i), randomNonNegativeLong());
+                }
+            }
+        }
+        long otherClassesCount = randomNonNegativeLong();
+        return new MulticlassConfusionMatrix.Result(confusionMatrix, otherClassesCount);
+    }
+
+    @Override
+    protected Writeable.Reader<MulticlassConfusionMatrix.Result> instanceReader() {
+        return MulticlassConfusionMatrix.Result::new;
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in the root of the object only
+        return field -> !field.isEmpty();
+    }
+}

+ 187 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java

@@ -0,0 +1,187 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.filter.Filters;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.Cardinality;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
+
+    @Override
+    protected MulticlassConfusionMatrix doParseInstance(XContentParser parser) throws IOException {
+        return MulticlassConfusionMatrix.fromXContent(parser);
+    }
+
+    @Override
+    protected MulticlassConfusionMatrix createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<MulticlassConfusionMatrix> instanceReader() {
+        return MulticlassConfusionMatrix::new;
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    public static MulticlassConfusionMatrix createRandom() {
+        Integer size = randomBoolean() ? null : randomIntBetween(1, 1000);
+        return new MulticlassConfusionMatrix(size);
+    }
+
+    public void testConstructor_SizeValidationFailures() {
+        {
+            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(-1));
+            assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
+        }
+        {
+            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(0));
+            assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
+        }
+        {
+            ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> new MulticlassConfusionMatrix(1001));
+            assertThat(e.getMessage(), equalTo("[size] must be an integer in [1, 1000]"));
+        }
+    }
+
+    public void testAggs() {
+        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
+        List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
+        assertThat(aggs, is(not(empty())));
+        assertThat(confusionMatrix.getResult(), equalTo(Optional.empty()));
+    }
+
+    public void testEvaluate() {
+        Aggregations aggs = new Aggregations(List.of(
+            mockTerms(
+                "multiclass_confusion_matrix_step_1_by_actual_class",
+                List.of(
+                    mockTermsBucket("dog", new Aggregations(List.of())),
+                    mockTermsBucket("cat", new Aggregations(List.of()))),
+                0L),
+            mockTerms(
+                "multiclass_confusion_matrix_step_2_by_actual_class",
+                List.of(
+                    mockTermsBucket(
+                        "dog",
+                        new Aggregations(List.of(mockFilters(
+                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
+                    mockTermsBucket(
+                        "cat",
+                        new Aggregations(List.of(mockFilters(
+                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 0L))))))),
+                0L),
+            mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 2L)));
+
+        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
+        confusionMatrix.process(aggs);
+
+        assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
+        MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
+        assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
+        assertThat(
+            result.getConfusionMatrix(),
+            equalTo(Map.of("dog", Map.of("cat", 10L, "dog", 20L), "cat", Map.of("cat", 30L, "dog", 40L))));
+        assertThat(result.getOtherClassesCount(), equalTo(0L));
+    }
+
+    public void testEvaluate_OtherClassesCountGreaterThanZero() {
+        Aggregations aggs = new Aggregations(List.of(
+            mockTerms(
+                "multiclass_confusion_matrix_step_1_by_actual_class",
+                List.of(
+                    mockTermsBucket("dog", new Aggregations(List.of())),
+                    mockTermsBucket("cat", new Aggregations(List.of()))),
+                100L),
+            mockTerms(
+                "multiclass_confusion_matrix_step_2_by_actual_class",
+                List.of(
+                    mockTermsBucket(
+                        "dog",
+                        new Aggregations(List.of(mockFilters(
+                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            List.of(mockFiltersBucket("cat", 10L), mockFiltersBucket("dog", 20L), mockFiltersBucket("_other_", 0L)))))),
+                    mockTermsBucket(
+                        "cat",
+                        new Aggregations(List.of(mockFilters(
+                            "multiclass_confusion_matrix_step_2_by_predicted_class",
+                            List.of(mockFiltersBucket("cat", 30L), mockFiltersBucket("dog", 40L), mockFiltersBucket("_other_", 15L))))))),
+                100L),
+            mockCardinality("multiclass_confusion_matrix_step_2_cardinality_of_actual_class", 5L)));
+
+        MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
+        confusionMatrix.process(aggs);
+
+        assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
+        MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
+        assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
+        assertThat(
+            result.getConfusionMatrix(),
+            equalTo(Map.of("dog", Map.of("cat", 10L, "dog", 20L), "cat", Map.of("cat", 30L, "dog", 40L, "_other_", 15L))));
+        assertThat(result.getOtherClassesCount(), equalTo(3L));
+    }
+
+    private static Terms mockTerms(String name, List<Terms.Bucket> buckets, long sumOfOtherDocCounts) {
+        Terms aggregation = mock(Terms.class);
+        when(aggregation.getName()).thenReturn(name);
+        doReturn(buckets).when(aggregation).getBuckets();
+        when(aggregation.getSumOfOtherDocCounts()).thenReturn(sumOfOtherDocCounts);
+        return aggregation;
+    }
+
+    private static Terms.Bucket mockTermsBucket(String actualClass, Aggregations subAggs) {
+        Terms.Bucket bucket = mock(Terms.Bucket.class);
+        when(bucket.getKeyAsString()).thenReturn(actualClass);
+        when(bucket.getAggregations()).thenReturn(subAggs);
+        return bucket;
+    }
+
+    private static Filters mockFilters(String name, List<Filters.Bucket> buckets) {
+        Filters aggregation = mock(Filters.class);
+        when(aggregation.getName()).thenReturn(name);
+        doReturn(buckets).when(aggregation).getBuckets();
+        return aggregation;
+    }
+
+    private static Filters.Bucket mockFiltersBucket(String predictedClass, long docCount) {
+        Filters.Bucket bucket = mock(Filters.Bucket.class);
+        when(bucket.getKeyAsString()).thenReturn(predictedClass);
+        when(bucket.getDocCount()).thenReturn(docCount);
+        return bucket;
+    }
+
+    private static Cardinality mockCardinality(String name, long value) {
+        Cardinality aggregation = mock(Cardinality.class);
+        when(aggregation.getName()).thenReturn(name);
+        when(aggregation.getValue()).thenReturn(value);
+        return aggregation;
+    }
+}

+ 3 - 0
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -90,6 +90,9 @@ integTest.runner  {
     'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds',
     'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
     'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
+    'ml/evaluate_data_frame/Test classification given evaluation with empty metrics',
+    'ml/evaluate_data_frame/Test classification given missing actual_field',
+    'ml/evaluate_data_frame/Test classification given missing predicted_field',
     'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
     'ml/evaluate_data_frame/Test regression given missing actual_field',
     'ml/evaluate_data_frame/Test regression given missing predicted_field',

+ 137 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
+import org.elasticsearch.action.bulk.BulkResponse;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
+
+    private static final String ANIMALS_DATA_INDEX = "test-evaluate-animals-index";
+
+    private static final String ACTUAL_CLASS_FIELD = "actual_class_field";
+    private static final String PREDICTED_CLASS_FIELD = "predicted_class_field";
+
+    @Before
+    public void setup() {
+        indexAnimalsData(ANIMALS_DATA_INDEX);
+    }
+
+    @After
+    public void cleanup() {
+        cleanUp();
+    }
+
+    public void testEvaluate_MulticlassClassification_DefaultMetrics() {
+        EvaluateDataFrameAction.Request evaluateDataFrameRequest =
+            new EvaluateDataFrameAction.Request()
+                .setIndices(List.of(ANIMALS_DATA_INDEX))
+                .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, null));
+
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+        MulticlassConfusionMatrix.Result confusionMatrixResult =
+            (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
+        assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
+        assertThat(
+            confusionMatrixResult.getConfusionMatrix(),
+            equalTo(Map.of(
+                "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "fox", 2L, "mouse", 5L),
+                "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "fox", 4L, "mouse", 2L),
+                "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "fox", 5L, "mouse", 3L),
+                "fox", Map.of("ant", 5L, "cat", 3L, "dog", 2L, "fox", 1L, "mouse", 4L),
+                "mouse", Map.of("ant", 2L, "cat", 5L, "dog", 4L, "fox", 3L, "mouse", 1L))));
+        assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
+    }
+
+    public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithDefaultSize() {
+        EvaluateDataFrameAction.Request evaluateDataFrameRequest =
+            new EvaluateDataFrameAction.Request()
+                .setIndices(List.of(ANIMALS_DATA_INDEX))
+                .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix())));
+
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+        MulticlassConfusionMatrix.Result confusionMatrixResult =
+            (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
+        assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
+        assertThat(
+            confusionMatrixResult.getConfusionMatrix(),
+            equalTo(Map.of(
+                "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "fox", 2L, "mouse", 5L),
+                "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "fox", 4L, "mouse", 2L),
+                "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "fox", 5L, "mouse", 3L),
+                "fox", Map.of("ant", 5L, "cat", 3L, "dog", 2L, "fox", 1L, "mouse", 4L),
+                "mouse", Map.of("ant", 2L, "cat", 5L, "dog", 4L, "fox", 3L, "mouse", 1L))));
+        assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(0L));
+    }
+
+    public void testEvaluate_MulticlassClassification_ConfusionMatrixMetricWithUserProvidedSize() {
+        EvaluateDataFrameAction.Request evaluateDataFrameRequest =
+            new EvaluateDataFrameAction.Request()
+                .setIndices(List.of(ANIMALS_DATA_INDEX))
+                .setEvaluation(new Classification(ACTUAL_CLASS_FIELD, PREDICTED_CLASS_FIELD, List.of(new MulticlassConfusionMatrix(3))));
+
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            client().execute(EvaluateDataFrameAction.INSTANCE, evaluateDataFrameRequest).actionGet();
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+        MulticlassConfusionMatrix.Result confusionMatrixResult =
+            (MulticlassConfusionMatrix.Result) evaluateDataFrameResponse.getMetrics().get(0);
+        assertThat(confusionMatrixResult.getMetricName(), equalTo(MulticlassConfusionMatrix.NAME.getPreferredName()));
+        assertThat(
+            confusionMatrixResult.getConfusionMatrix(),
+            equalTo(Map.of(
+                "ant", Map.of("ant", 1L, "cat", 4L, "dog", 3L, "_other_", 7L),
+                "cat", Map.of("ant", 3L, "cat", 1L, "dog", 5L, "_other_", 6L),
+                "dog", Map.of("ant", 4L, "cat", 2L, "dog", 1L, "_other_", 8L))));
+        assertThat(confusionMatrixResult.getOtherClassesCount(), equalTo(2L));
+    }
+
+    private static void indexAnimalsData(String indexName) {
+        client().admin().indices().prepareCreate(indexName)
+            .addMapping("_doc", ACTUAL_CLASS_FIELD, "type=keyword", PREDICTED_CLASS_FIELD, "type=keyword")
+            .get();
+
+        List<String> classNames = List.of("dog", "cat", "mouse", "ant", "fox");
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        for (int i = 0; i < classNames.size(); i++) {
+            for (int j = 0; j < classNames.size(); j++) {
+                for (int k = 0; k < j + 1; k++) {
+                    bulkRequestBuilder.add(
+                        new IndexRequest(indexName)
+                            .source(
+                                ACTUAL_CLASS_FIELD, classNames.get(i),
+                                PREDICTED_CLASS_FIELD, classNames.get((i + j) % classNames.size())));
+                }
+            }
+        }
+        BulkResponse bulkResponse = bulkRequestBuilder.get();
+        if (bulkResponse.hasFailures()) {
+            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
+        }
+    }
+}

+ 118 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -11,6 +11,8 @@ setup:
             "outlier_score": 0.0,
             "regression_field_act": 10.9,
             "regression_field_pred": 10.9,
+            "classification_field_act": "dog",
+            "classification_field_pred": "dog",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -26,6 +28,8 @@ setup:
             "outlier_score": 0.2,
             "regression_field_act": 12.0,
             "regression_field_pred": 9.9,
+            "classification_field_act": "cat",
+            "classification_field_pred": "cat",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -41,6 +45,8 @@ setup:
             "outlier_score": 0.3,
             "regression_field_act": 20.9,
             "regression_field_pred": 5.9,
+            "classification_field_act": "mouse",
+            "classification_field_pred": "mouse",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -56,6 +62,8 @@ setup:
             "outlier_score": 0.3,
             "regression_field_act": 11.9,
             "regression_field_pred": 11.9,
+            "classification_field_act": "dog",
+            "classification_field_pred": "cat",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -71,6 +79,8 @@ setup:
             "outlier_score": 0.4,
             "regression_field_act": 42.9,
             "regression_field_pred": 42.9,
+            "classification_field_act": "cat",
+            "classification_field_pred": "dog",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -86,6 +96,8 @@ setup:
             "outlier_score": 0.5,
             "regression_field_act": 0.42,
             "regression_field_pred": 0.42,
+            "classification_field_act": "dog",
+            "classification_field_pred": "dog",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -101,6 +113,8 @@ setup:
             "outlier_score": 0.9,
             "regression_field_act": 1.1235813,
             "regression_field_pred": 1.12358,
+            "classification_field_act": "cat",
+            "classification_field_pred": "cat",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -116,6 +130,8 @@ setup:
             "outlier_score": 0.95,
             "regression_field_act": -5.20,
             "regression_field_pred": -5.1,
+            "classification_field_act": "mouse",
+            "classification_field_pred": "cat",
             "all_true_field": true,
             "all_false_field": false
           }
@@ -569,6 +585,108 @@ setup:
               }
             }
           }
+
+---
+"Test classification given evaluation with empty metrics":
+  - do:
+      catch: /\[classification\] must have one or more metrics/
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "classification_field_pred.keyword",
+                "metrics": { }
+              }
+            }
+          }
+---
+"Test classification multiclass_confusion_matrix":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "classification_field_pred.keyword",
+                "metrics": { "multiclass_confusion_matrix": {} }
+              }
+            }
+          }
+
+  - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
+  - match: { classification.multiclass_confusion_matrix._other_: 0 }
+---
+"Test classification multiclass_confusion_matrix with explicit size":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "classification_field_pred.keyword",
+                "metrics": { "multiclass_confusion_matrix": { "size": 2 } }
+              }
+            }
+          }
+
+  - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1}, dog: {cat: 1, dog: 2} } }
+  - match: { classification.multiclass_confusion_matrix._other_: 1 }
+---
+"Test classification with null metrics":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "classification_field_pred.keyword"
+              }
+            }
+          }
+
+  - match: { classification.multiclass_confusion_matrix.confusion_matrix: {cat: {cat: 2, dog: 1, mouse: 0}, dog: {cat: 1, dog: 2, mouse: 0}, mouse: {cat: 1, dog: 0, mouse: 1} } }
+  - match: { classification.multiclass_confusion_matrix._other_: 0 }
+---
+"Test classification given missing actual_field":
+  - do:
+      catch: /No documents found containing both \[missing, classification_field_pred.keyword\] fields/
+      ml.evaluate_data_frame:
+        body:  >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "missing",
+                "predicted_field": "classification_field_pred.keyword"
+              }
+            }
+          }
+
+---
+"Test classification given missing predicted_field":
+  - do:
+      catch: /No documents found containing both \[classification_field_act.keyword, missing\] fields/
+      ml.evaluate_data_frame:
+        body:  >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "missing"
+              }
+            }
+          }
+
 ---
 "Test regression given evaluation with empty metrics":
   - do: