Browse Source

[ML] Implement AucRoc metric for classification - HLRC (#62304)

Przemysław Witek 5 years ago
parent
commit
a9e54a2d9e
15 changed files with 622 additions and 308 deletions
  1. 43 27
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  2. 264 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java
  3. 35 13
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java
  4. 10 159
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java
  5. 71 53
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  6. 11 9
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  7. 59 37
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  8. 1 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java
  9. 1 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java
  10. 2 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java
  11. 55 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricTests.java
  12. 6 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java
  13. 53 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricTests.java
  14. 1 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java
  15. 10 5
      docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc

+ 43 - 27
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -19,13 +19,13 @@
 package org.elasticsearch.client.ml.dataframe.evaluation;
 
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
@@ -63,34 +63,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             // Evaluation metrics
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
-                new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)),
-                AucRocMetric::fromXContent),
+                new ParseField(
+                    registeredMetricName(
+                        OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
-                new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)),
-                PrecisionMetric::fromXContent),
+                new ParseField(
+                    registeredMetricName(
+                        OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
-                new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)),
-                RecallMetric::fromXContent),
+                new ParseField(
+                    registeredMetricName(
+                        OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
                 new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
                 ConfusionMatrixMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
+                AucRocMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
                 new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
                 AccuracyMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
-                new ParseField(registeredMetricName(
-                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
-                org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
+                new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
+                PrecisionMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
-                new ParseField(registeredMetricName(
-                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
-                org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
+                new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
+                RecallMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
                 new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
@@ -114,34 +122,42 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             // Evaluation metrics results
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
-                new ParseField(registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME)),
-                AucRocMetric.Result::fromXContent),
+                new ParseField(
+                    registeredMetricName(
+                        OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
-                new ParseField(registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME)),
-                PrecisionMetric.Result::fromXContent),
+                new ParseField(
+                    registeredMetricName(
+                        OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
-                new ParseField(registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME)),
-                RecallMetric.Result::fromXContent),
+                new ParseField(
+                    registeredMetricName(
+                        OutlierDetection.NAME, org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
                 new ParseField(registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME)),
                 ConfusionMatrixMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(Classification.NAME, AucRocMetric.NAME)),
+                AucRocMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
                 new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
                 AccuracyMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
-                new ParseField(registeredMetricName(
-                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
-                org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
+                new ParseField(registeredMetricName(Classification.NAME, PrecisionMetric.NAME)),
+                PrecisionMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
-                new ParseField(registeredMetricName(
-                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
-                org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
+                new ParseField(registeredMetricName(Classification.NAME, RecallMetric.NAME)),
+                RecallMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
                 new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),

+ 264 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetric.java

@@ -0,0 +1,264 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Area under the curve (AUC) of the receiver operating characteristic (ROC).
+ * The ROC curve is a plot of the TPR (true positive rate) against
+ * the FPR (false positive rate) over a varying threshold.
+ */
+public class AucRocMetric implements EvaluationMetric {
+
+    public static final String NAME = "auc_roc";
+
+    public static final ParseField CLASS_NAME = new ParseField("class_name");
+    public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
+        new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((String) args[0], (Boolean) args[1]));
+
+    static {
+        PARSER.declareString(constructorArg(), CLASS_NAME);
+        PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
+    }
+
+    public static AucRocMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public static AucRocMetric forClass(String className) {
+        return new AucRocMetric(className, false);
+    }
+
+    public static AucRocMetric forClassWithCurve(String className) {
+        return new AucRocMetric(className, true);
+    }
+
+    private final String className;
+    private final Boolean includeCurve;
+
+    public AucRocMetric(String className, Boolean includeCurve) {
+        this.className = Objects.requireNonNull(className);
+        this.includeCurve = includeCurve;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.field(CLASS_NAME.getPreferredName(), className);
+        if (includeCurve != null) {
+            builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AucRocMetric that = (AucRocMetric) o;
+        return Objects.equals(className, that.className)
+            && Objects.equals(includeCurve, that.includeCurve);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(className, includeCurve);
+    }
+
+    public static class Result implements EvaluationMetric.Result {
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private static final ParseField SCORE = new ParseField("score");
+        private static final ParseField DOC_COUNT = new ParseField("doc_count");
+        private static final ParseField CURVE = new ParseField("curve");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>(
+                "auc_roc_result", true, args -> new Result((double) args[0], (long) args[1], (List<AucRocPoint>) args[2]));
+
+        static {
+            PARSER.declareDouble(constructorArg(), SCORE);
+            PARSER.declareLong(constructorArg(), DOC_COUNT);
+            PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
+        }
+
+        private final double score;
+        private final long docCount;
+        private final List<AucRocPoint> curve;
+
+        public Result(double score, long docCount, @Nullable List<AucRocPoint> curve) {
+            this.score = score;
+            this.docCount = docCount;
+            this.curve = curve;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        public double getScore() {
+            return score;
+        }
+
+        public long getDocCount() {
+            return docCount;
+        }
+
+        public List<AucRocPoint> getCurve() {
+            return curve == null ? null : Collections.unmodifiableList(curve);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+            builder.startObject();
+            builder.field(SCORE.getPreferredName(), score);
+            builder.field(DOC_COUNT.getPreferredName(), docCount);
+            if (curve != null && curve.isEmpty() == false) {
+                builder.field(CURVE.getPreferredName(), curve);
+            }
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return score == that.score
+                && docCount == that.docCount
+                && Objects.equals(curve, that.curve);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(score, docCount, curve);
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+    }
+
+    public static final class AucRocPoint implements ToXContentObject {
+
+        public static AucRocPoint fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private static final ParseField TPR = new ParseField("tpr");
+        private static final ParseField FPR = new ParseField("fpr");
+        private static final ParseField THRESHOLD = new ParseField("threshold");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
+            new ConstructingObjectParser<>(
+                "auc_roc_point",
+                true,
+                args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
+
+        static {
+            PARSER.declareDouble(constructorArg(), TPR);
+            PARSER.declareDouble(constructorArg(), FPR);
+            PARSER.declareDouble(constructorArg(), THRESHOLD);
+        }
+
+        private final double tpr;
+        private final double fpr;
+        private final double threshold;
+
+        public AucRocPoint(double tpr, double fpr, double threshold) {
+            this.tpr = tpr;
+            this.fpr = fpr;
+            this.threshold = threshold;
+        }
+
+        public double getTruePositiveRate() {
+            return tpr;
+        }
+
+        public double getFalsePositiveRate() {
+            return fpr;
+        }
+
+        public double getThreshold() {
+            return threshold;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder
+                .startObject()
+                .field(TPR.getPreferredName(), tpr)
+                .field(FPR.getPreferredName(), fpr)
+                .field(THRESHOLD.getPreferredName(), threshold)
+                .endObject();
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            AucRocPoint that = (AucRocPoint) o;
+            return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(tpr, fpr, threshold);
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+    }
+}

+ 35 - 13
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java

@@ -45,15 +45,20 @@ public class Classification implements Evaluation {
 
     private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
     private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField TOP_CLASSES_FIELD = new ParseField("top_classes_field");
+
     private static final ParseField METRICS = new ParseField("metrics");
 
     @SuppressWarnings("unchecked")
     public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
-        NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
+        NAME,
+        true,
+        a -> new Classification((String) a[0], (String) a[1], (String) a[2], (List<EvaluationMetric>) a[3]));
 
     static {
         PARSER.declareString(constructorArg(), ACTUAL_FIELD);
-        PARSER.declareString(constructorArg(), PREDICTED_FIELD);
+        PARSER.declareString(optionalConstructorArg(), PREDICTED_FIELD);
+        PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_FIELD);
         PARSER.declareNamedObjects(
             optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
     }
@@ -64,32 +69,44 @@ public class Classification implements Evaluation {
 
     /**
      * The field containing the actual value
-     * The value of this field is assumed to be numeric
      */
     private final String actualField;
 
     /**
      * The field containing the predicted value
-     * The value of this field is assumed to be numeric
      */
     private final String predictedField;
 
+    /**
+     * The field containing the array of top classes
+     */
+    private final String topClassesField;
+
     /**
      * The list of metrics to calculate
      */
     private final List<EvaluationMetric> metrics;
 
-    public Classification(String actualField, String predictedField) {
-        this(actualField, predictedField, (List<EvaluationMetric>)null);
+    public Classification(String actualField,
+                          String predictedField,
+                          String topClassesField) {
+        this(actualField, predictedField, topClassesField, (List<EvaluationMetric>)null);
     }
 
-    public Classification(String actualField, String predictedField, EvaluationMetric... metrics) {
-        this(actualField, predictedField, Arrays.asList(metrics));
+    public Classification(String actualField,
+                          String predictedField,
+                          String topClassesField,
+                          EvaluationMetric... metrics) {
+        this(actualField, predictedField, topClassesField, Arrays.asList(metrics));
     }
 
-    public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
+    public Classification(String actualField,
+                          @Nullable String predictedField,
+                          @Nullable String topClassesField,
+                          @Nullable List<EvaluationMetric> metrics) {
         this.actualField = Objects.requireNonNull(actualField);
-        this.predictedField = Objects.requireNonNull(predictedField);
+        this.predictedField = predictedField;
+        this.topClassesField = topClassesField;
         if (metrics != null) {
             metrics.sort(Comparator.comparing(EvaluationMetric::getName));
         }
@@ -105,8 +122,12 @@ public class Classification implements Evaluation {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
-        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
-
+        if (predictedField != null) {
+            builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+        }
+        if (topClassesField != null) {
+            builder.field(TOP_CLASSES_FIELD.getPreferredName(), topClassesField);
+        }
         if (metrics != null) {
            builder.startObject(METRICS.getPreferredName());
            for (EvaluationMetric metric : metrics) {
@@ -126,11 +147,12 @@ public class Classification implements Evaluation {
         Classification that = (Classification) o;
         return Objects.equals(that.actualField, this.actualField)
             && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.topClassesField, this.topClassesField)
             && Objects.equals(that.metrics, this.metrics);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(actualField, predictedField, metrics);
+        return Objects.hash(actualField, predictedField, topClassesField, metrics);
     }
 }

+ 10 - 159
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetric.java

@@ -19,21 +19,14 @@
 package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
 
 import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
-import org.elasticsearch.common.xcontent.ToXContent;
-import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.List;
 import java.util.Objects;
 
-import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
 /**
@@ -49,7 +42,7 @@ public class AucRocMetric implements EvaluationMetric {
 
     @SuppressWarnings("unchecked")
     public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
-        new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0]));
+        new ConstructingObjectParser<>(NAME, true, args -> new AucRocMetric((Boolean) args[0]));
 
     static {
         PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
@@ -63,18 +56,20 @@ public class AucRocMetric implements EvaluationMetric {
         return new AucRocMetric(true);
     }
 
-    private final boolean includeCurve;
+    private final Boolean includeCurve;
 
     public AucRocMetric(Boolean includeCurve) {
-        this.includeCurve = includeCurve == null ? false : includeCurve;
+        this.includeCurve = includeCurve;
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
-        return builder
-            .startObject()
-            .field(INCLUDE_CURVE.getPreferredName(), includeCurve)
-            .endObject();
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (includeCurve != null) {
+            builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
+        }
+        builder.endObject();
+        return builder;
     }
 
     @Override
@@ -94,148 +89,4 @@ public class AucRocMetric implements EvaluationMetric {
     public int hashCode() {
         return Objects.hash(includeCurve);
     }
-
-    public static class Result implements EvaluationMetric.Result {
-
-        public static Result fromXContent(XContentParser parser) {
-            return PARSER.apply(parser, null);
-        }
-
-        private static final ParseField SCORE = new ParseField("score");
-        private static final ParseField CURVE = new ParseField("curve");
-
-        @SuppressWarnings("unchecked")
-        private static final ConstructingObjectParser<Result, Void> PARSER =
-            new ConstructingObjectParser<>("auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
-
-        static {
-            PARSER.declareDouble(constructorArg(), SCORE);
-            PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
-        }
-
-        private final double score;
-        private final List<AucRocPoint> curve;
-
-        public Result(double score, @Nullable List<AucRocPoint> curve) {
-            this.score = score;
-            this.curve = curve;
-        }
-
-        @Override
-        public String getMetricName() {
-            return NAME;
-        }
-
-        public double getScore() {
-            return score;
-        }
-
-        public List<AucRocPoint> getCurve() {
-            return curve == null ? null : Collections.unmodifiableList(curve);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
-            builder.startObject();
-            builder.field(SCORE.getPreferredName(), score);
-            if (curve != null && curve.isEmpty() == false) {
-                builder.field(CURVE.getPreferredName(), curve);
-            }
-            builder.endObject();
-            return builder;
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            Result that = (Result) o;
-            return Objects.equals(score, that.score)
-                && Objects.equals(curve, that.curve);
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(score, curve);
-        }
-
-        @Override
-        public String toString() {
-            return Strings.toString(this);
-        }
-    }
-
-    public static final class AucRocPoint implements ToXContentObject {
-
-        public static AucRocPoint fromXContent(XContentParser parser) {
-            return PARSER.apply(parser, null);
-        }
-
-        private static final ParseField TPR = new ParseField("tpr");
-        private static final ParseField FPR = new ParseField("fpr");
-        private static final ParseField THRESHOLD = new ParseField("threshold");
-
-        @SuppressWarnings("unchecked")
-        private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
-            new ConstructingObjectParser<>(
-                "auc_roc_point",
-                true,
-                args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
-
-        static {
-            PARSER.declareDouble(constructorArg(), TPR);
-            PARSER.declareDouble(constructorArg(), FPR);
-            PARSER.declareDouble(constructorArg(), THRESHOLD);
-        }
-
-        private final double tpr;
-        private final double fpr;
-        private final double threshold;
-
-        public AucRocPoint(double tpr, double fpr, double threshold) {
-            this.tpr = tpr;
-            this.fpr = fpr;
-            this.threshold = threshold;
-        }
-
-        public double getTruePositiveRate() {
-            return tpr;
-        }
-
-        public double getFalsePositiveRate() {
-            return fpr;
-        }
-
-        public double getThreshold() {
-            return threshold;
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            return builder
-                .startObject()
-                .field(TPR.getPreferredName(), tpr)
-                .field(FPR.getPreferredName(), fpr)
-                .field(THRESHOLD.getPreferredName(), threshold)
-                .endObject();
-        }
-
-        @Override
-        public boolean equals(Object o) {
-            if (this == o) return true;
-            if (o == null || getClass() != o.getClass()) return false;
-            AucRocPoint that = (AucRocPoint) o;
-            return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
-        }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(tpr, fpr, threshold);
-        }
-
-        @Override
-        public String toString() {
-            return Strings.toString(this);
-        }
-    }
 }

+ 71 - 53
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -138,13 +138,13 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
 import org.elasticsearch.client.ml.dataframe.PhaseProgress;
 import org.elasticsearch.client.ml.dataframe.QueryConfig;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
@@ -1744,15 +1744,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 new OutlierDetection(
                     actualField,
                     probabilityField,
-                    PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve()));
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
+                    ConfusionMatrixMetric.at(0.5),
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
 
         EvaluateDataFrameResponse evaluateDataFrameResponse =
             execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME));
         assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
 
-        PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
-        assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
+        org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
+            evaluateDataFrameResponse.getMetricByName(
+                org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME);
+        assertThat(
+            precisionResult.getMetricName(),
+            equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME));
         // Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9)
         assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9));
         // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9)
@@ -1761,8 +1768,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9));
         assertNull(precisionResult.getScoreByThreshold("0.1"));
 
-        RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
-        assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
+        org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result recallResult =
+            evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME);
+        assertThat(
+            recallResult.getMetricName(),
+            equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME));
         // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
         assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9));
         // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
@@ -1778,7 +1788,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L));  // docs #5, #6 and #7
         assertNull(confusionMatrixResult.getScoreByThreshold("0.1"));
 
-        AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
+        AucRocMetric.Result aucRocResult =
+            evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME);
         assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
         assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9));
         assertNotNull(aucRocResult.getCurve());
@@ -1890,24 +1901,40 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         createIndex(indexName, mappingForClassification());
         BulkRequest regressionBulk = new BulkRequest()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .add(docForClassification(indexName, "cat", "cat"))
-            .add(docForClassification(indexName, "cat", "cat"))
-            .add(docForClassification(indexName, "cat", "cat"))
-            .add(docForClassification(indexName, "cat", "dog"))
-            .add(docForClassification(indexName, "cat", "fish"))
-            .add(docForClassification(indexName, "dog", "cat"))
-            .add(docForClassification(indexName, "dog", "dog"))
-            .add(docForClassification(indexName, "dog", "dog"))
-            .add(docForClassification(indexName, "dog", "dog"))
-            .add(docForClassification(indexName, "ant", "cat"));
+            .add(docForClassification(indexName, "cat", "cat", 0.9))
+            .add(docForClassification(indexName, "cat", "cat", 0.85))
+            .add(docForClassification(indexName, "cat", "cat", 0.95))
+            .add(docForClassification(indexName, "cat", "dog", 0.4))
+            .add(docForClassification(indexName, "cat", "fish", 0.35))
+            .add(docForClassification(indexName, "dog", "cat", 0.5))
+            .add(docForClassification(indexName, "dog", "dog", 0.4))
+            .add(docForClassification(indexName, "dog", "dog", 0.35))
+            .add(docForClassification(indexName, "dog", "dog", 0.6))
+            .add(docForClassification(indexName, "ant", "cat", 0.1));
         highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
 
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
 
+        {  // AucRoc
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName, null, new Classification(actualClassField, null, topClassesField, AucRocMetric.forClassWithCurve("cat")));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
+            assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
+            assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
+            assertThat(aucRocResult.getDocCount(), equalTo(5L));
+            assertNotNull(aucRocResult.getCurve());
+        }
         {  // Accuracy
             EvaluateDataFrameRequest evaluateDataFrameRequest =
                 new EvaluateDataFrameRequest(
-                    indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
+                    indexName, null, new Classification(actualClassField, predictedClassField, null, new AccuracyMetric()));
 
             EvaluateDataFrameResponse evaluateDataFrameResponse =
                 execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@@ -1931,65 +1958,47 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         {  // Precision
             EvaluateDataFrameRequest evaluateDataFrameRequest =
                 new EvaluateDataFrameRequest(
-                    indexName,
-                    null,
-                    new Classification(
-                        actualClassField,
-                        predictedClassField,
-                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
+                    indexName, null, new Classification(actualClassField, predictedClassField, null, new PrecisionMetric()));
 
             EvaluateDataFrameResponse evaluateDataFrameResponse =
                 execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
             assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
             assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
 
-            org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
-                evaluateDataFrameResponse.getMetricByName(
-                    org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
-            assertThat(
-                precisionResult.getMetricName(),
-                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
+            PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
+            assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
             assertThat(
                 precisionResult.getClasses(),
                 equalTo(
                     List.of(
                         // 3 out of 5 examples labeled as "cat" were classified correctly
-                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6),
+                        new PrecisionMetric.PerClassResult("cat", 0.6),
                         // 3 out of 4 examples labeled as "dog" were classified correctly
-                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75))));
+                        new PrecisionMetric.PerClassResult("dog", 0.75))));
             assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
         }
         {  // Recall
             EvaluateDataFrameRequest evaluateDataFrameRequest =
                 new EvaluateDataFrameRequest(
-                    indexName,
-                    null,
-                    new Classification(
-                        actualClassField,
-                        predictedClassField,
-                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
+                    indexName, null, new Classification(actualClassField, predictedClassField, null, new RecallMetric()));
 
             EvaluateDataFrameResponse evaluateDataFrameResponse =
                 execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
             assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
             assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
 
-            org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
-                evaluateDataFrameResponse.getMetricByName(
-                    org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
-            assertThat(
-                recallResult.getMetricName(),
-                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
+            RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
+            assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
             assertThat(
                 recallResult.getClasses(),
                 equalTo(
                     List.of(
                         // 3 out of 5 examples labeled as "cat" were classified correctly
-                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6),
+                        new RecallMetric.PerClassResult("cat", 0.6),
                         // 3 out of 4 examples labeled as "dog" were classified correctly
-                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75),
+                        new RecallMetric.PerClassResult("dog", 0.75),
                         // no examples labeled as "ant" were classified correctly
-                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0))));
+                        new RecallMetric.PerClassResult("ant", 0.0))));
             assertThat(recallResult.getAvgRecall(), equalTo(0.45));
         }
         {  // No size provided for MulticlassConfusionMatrixMetric, default used instead
@@ -1997,7 +2006,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 new EvaluateDataFrameRequest(
                     indexName,
                     null,
-                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
+                    new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric()));
 
             EvaluateDataFrameResponse evaluateDataFrameResponse =
                 execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@@ -2042,7 +2051,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 new EvaluateDataFrameRequest(
                     indexName,
                     null,
-                    new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
+                    new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric(2)));
 
             EvaluateDataFrameResponse evaluateDataFrameResponse =
                 execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
@@ -2116,6 +2125,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
     private static final String actualClassField = "actual_class";
     private static final String predictedClassField = "predicted_class";
+    private static final String topClassesField = "top_classes";
 
     private static XContentBuilder mappingForClassification() throws IOException {
         return XContentFactory.jsonBuilder().startObject()
@@ -2126,14 +2136,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 .startObject(predictedClassField)
                     .field("type", "keyword")
                 .endObject()
+                .startObject(topClassesField)
+                    .field("type", "nested")
+                .endObject()
             .endObject()
         .endObject();
     }
 
-    private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
+    private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) {
         return new IndexRequest()
             .index(indexName)
-            .source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
+            .source(XContentType.JSON,
+                actualClassField, actualClass,
+                predictedClassField, predictedClass,
+                topClassesField, List.of(
+                    Map.of("class_name", predictedClass, "class_probability", p),
+                    Map.of("class_name", "other", "class_probability", 1 - p)));
     }
 
     private static final String actualRegression = "regression_actual";

+ 11 - 9
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -59,18 +59,18 @@ import org.elasticsearch.client.ilm.UnfollowAction;
 import org.elasticsearch.client.ilm.WaitForSnapshotAction;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
+import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
 import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStats;
 import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
 import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
@@ -707,7 +707,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(73, namedXContents.size());
+        assertEquals(75, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -756,13 +756,14 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(TimeSyncConfig.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
         assertThat(names, hasItems(OutlierDetection.NAME, Classification.NAME, Regression.NAME));
-        assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
+        assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
         assertThat(names,
             hasItems(
                 registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME),
                 registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME),
                 registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME),
                 registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
+                registeredMetricName(Classification.NAME, AucRocMetric.NAME),
                 registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
                 registeredMetricName(
                     Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
@@ -773,13 +774,14 @@ public class RestHighLevelClientTests extends ESTestCase {
                 registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
                 registeredMetricName(Regression.NAME, HuberMetric.NAME),
                 registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
-        assertEquals(Integer.valueOf(12), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
+        assertEquals(Integer.valueOf(13), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
         assertThat(names,
             hasItems(
                 registeredMetricName(OutlierDetection.NAME, AucRocMetric.NAME),
                 registeredMetricName(OutlierDetection.NAME, PrecisionMetric.NAME),
                 registeredMetricName(OutlierDetection.NAME, RecallMetric.NAME),
                 registeredMetricName(OutlierDetection.NAME, ConfusionMatrixMetric.NAME),
+                registeredMetricName(Classification.NAME, AucRocMetric.NAME),
                 registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
                 registeredMetricName(
                     Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),

+ 59 - 37
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -156,15 +156,15 @@ import org.elasticsearch.client.ml.dataframe.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
 import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric.ConfusionMatrix;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
@@ -201,6 +201,7 @@ import org.elasticsearch.client.ml.job.results.CategoryDefinition;
 import org.elasticsearch.client.ml.job.results.Influencer;
 import org.elasticsearch.client.ml.job.results.OverallBucket;
 import org.elasticsearch.client.ml.job.stats.JobStats;
+import org.elasticsearch.common.TriFunction;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
@@ -3326,7 +3327,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             30, TimeUnit.SECONDS);
     }
 
-    public void testEvaluateDataFrame() throws Exception {
+    public void testEvaluateDataFrame_OutlierDetection() throws Exception {
         String indexName = "evaluate-test-index";
         CreateIndexRequest createIndexRequest =
             new CreateIndexRequest(indexName)
@@ -3363,10 +3364,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                     "label", // <2>
                     "p", // <3>
                     // Evaluation metrics // <4>
-                    PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
-                    RecallMetric.at(0.5, 0.7), // <6>
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6), // <5>
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7), // <6>
                     ConfusionMatrixMetric.at(0.5), // <7>
-                    AucRocMetric.withCurve()); // <8>
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()); // <8>
             // end::evaluate-data-frame-evaluation-outlierdetection
 
             // tag::evaluate-data-frame-request
@@ -3386,7 +3387,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             // end::evaluate-data-frame-response
 
             // tag::evaluate-data-frame-results-outlierdetection
-            PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <1>
+            org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
+                response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME); // <1>
             double precision = precisionResult.getScoreByThreshold("0.4"); // <2>
 
             ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <3>
@@ -3395,7 +3397,11 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
 
             assertThat(
                 metrics.stream().map(EvaluationMetric.Result::getMetricName).collect(Collectors.toList()),
-                containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME));
+                containsInAnyOrder(
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME,
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME,
+                    ConfusionMatrixMetric.NAME,
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME));
             assertThat(precision, closeTo(0.6, 1e-9));
             assertThat(confusionMatrix.getTruePositives(), equalTo(2L));  // docs #8 and #9
             assertThat(confusionMatrix.getFalsePositives(), equalTo(1L));  // doc #4
@@ -3409,10 +3415,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 new OutlierDetection(
                     "label",
                     "p",
-                    PrecisionMetric.at(0.4, 0.5, 0.6),
-                    RecallMetric.at(0.5, 0.7),
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
                     ConfusionMatrixMetric.at(0.5),
-                    AucRocMetric.withCurve()));
+                    org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
 
             // tag::evaluate-data-frame-execute-listener
             ActionListener<EvaluateDataFrameResponse> listener = new ActionListener<>() {
@@ -3452,21 +3458,33 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                         .startObject("predicted_class")
                             .field("type", "keyword")
                         .endObject()
+                        .startObject("ml.top_classes")
+                            .field("type", "nested")
+                        .endObject()
                     .endObject()
                     .endObject());
+        TriFunction<String, String, Double, IndexRequest> indexRequest = (actualClass, predictedClass, p) -> {
+            return new IndexRequest()
+                .source(XContentType.JSON,
+                    "actual_class", actualClass,
+                    "predicted_class", predictedClass,
+                    "ml.top_classes", List.of(
+                        Map.of("class_name", predictedClass, "class_probability", p),
+                        Map.of("class_name", "other", "class_probability", 1 - p)));
+        };
         BulkRequest bulkRequest =
             new BulkRequest(indexName)
                 .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #0
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #1
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "cat")) // #2
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "dog")) // #3
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "cat", "predicted_class", "fox")) // #4
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "cat")) // #5
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #6
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #7
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "dog", "predicted_class", "dog")) // #8
-                .add(new IndexRequest().source(XContentType.JSON, "actual_class", "ant", "predicted_class", "cat")); // #9
+                .add(indexRequest.apply("cat", "cat", 0.9)) // #0
+                .add(indexRequest.apply("cat", "cat", 0.9)) // #1
+                .add(indexRequest.apply("cat", "cat", 0.9)) // #2
+                .add(indexRequest.apply("cat", "dog", 0.9)) // #3
+                .add(indexRequest.apply("cat", "fox", 0.9)) // #4
+                .add(indexRequest.apply("dog", "cat", 0.9)) // #5
+                .add(indexRequest.apply("dog", "dog", 0.9)) // #6
+                .add(indexRequest.apply("dog", "dog", 0.9)) // #7
+                .add(indexRequest.apply("dog", "dog", 0.9)) // #8
+                .add(indexRequest.apply("ant", "cat", 0.9)); // #9
         RestHighLevelClient client = highLevelClient();
         client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
         client.bulk(bulkRequest, RequestOptions.DEFAULT);
@@ -3476,11 +3494,13 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 new org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification( // <1>
                     "actual_class", // <2>
                     "predicted_class", // <3>
-                    // Evaluation metrics // <4>
-                    new AccuracyMetric(), // <5>
-                    new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6>
-                    new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7>
-                    new MulticlassConfusionMatrixMetric(3)); // <8>
+                    "ml.top_classes", // <4>
+                    // Evaluation metrics // <5>
+                    new AccuracyMetric(), // <6>
+                    new PrecisionMetric(), // <7>
+                    new RecallMetric(), // <8>
+                    new MulticlassConfusionMatrixMetric(3), // <9>
+                    AucRocMetric.forClass("cat")); // <10>
             // end::evaluate-data-frame-evaluation-classification
 
             EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@@ -3490,12 +3510,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
             double accuracy = accuracyResult.getOverallAccuracy(); // <2>
 
-            org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
-                response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3>
+            PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <3>
             double precision = precisionResult.getAvgPrecision(); // <4>
 
-            org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
-                response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5>
+            RecallMetric.Result recallResult = response.getMetricByName(RecallMetric.NAME); // <5>
             double recall = recallResult.getAvgRecall(); // <6>
 
             MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
@@ -3503,19 +3521,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
 
             List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
             long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
+
+            AucRocMetric.Result aucRocResult = response.getMetricByName(AucRocMetric.NAME); // <10>
+            double aucRocScore = aucRocResult.getScore(); // <11>
+            Long aucRocDocCount = aucRocResult.getDocCount(); // <12>
             // end::evaluate-data-frame-results-classification
 
             assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
             assertThat(accuracy, equalTo(0.6));
 
-            assertThat(
-                precisionResult.getMetricName(),
-                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
+            assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
             assertThat(precision, equalTo(0.675));
 
-            assertThat(
-                recallResult.getMetricName(),
-                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
+            assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
             assertThat(recall, equalTo(0.45));
 
             assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
@@ -3539,6 +3557,10 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                             List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
                             0L))));
             assertThat(otherClassesCount, equalTo(0L));
+
+            assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
+            assertThat(aucRocScore, equalTo(0.2625));
+            assertThat(aucRocDocCount, equalTo(5L));
         }
     }
 

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java

@@ -26,7 +26,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetricResultTests;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetricResultTests;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetricResultTests;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetricResultTests;
 import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetricResultTests;

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricAucRocPointTests.java → client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricAucRocPointTests.java

@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
 
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractXContentTestCase;

+ 2 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricResultTests.java → client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricResultTests.java

@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
 
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractXContentTestCase;
@@ -31,6 +31,7 @@ public class AucRocMetricResultTests extends AbstractXContentTestCase<AucRocMetr
     public static AucRocMetric.Result randomResult() {
         return new AucRocMetric.Result(
             randomDouble(),
+            randomLong(),
             Stream
                 .generate(AucRocMetricAucRocPointTests::randomPoint)
                 .limit(randomIntBetween(1, 10))

+ 55 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/AucRocMetricTests.java

@@ -0,0 +1,55 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    public static AucRocMetric createRandom() {
+        return new AucRocMetric(
+            randomAlphaOfLengthBetween(1, 10),
+            randomBoolean() ? randomBoolean() : null);
+    }
+
+    @Override
+    protected AucRocMetric createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
+        return AucRocMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 6 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java

@@ -40,11 +40,16 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
         List<EvaluationMetric> metrics =
             randomSubsetOf(
                 Arrays.asList(
+                    AucRocMetricTests.createRandom(),
                     AccuracyMetricTests.createRandom(),
                     PrecisionMetricTests.createRandom(),
                     RecallMetricTests.createRandom(),
                     MulticlassConfusionMatrixMetricTests.createRandom()));
-        return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
+        return new Classification(
+            randomAlphaOfLength(10),
+            randomBoolean() ? randomAlphaOfLength(10) : null,
+            randomBoolean() ? randomAlphaOfLength(10) : null,
+            metrics.isEmpty() ? null : metrics);
     }
 
     @Override

+ 53 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/AucRocMetricTests.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class AucRocMetricTests extends AbstractXContentTestCase<AucRocMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    public static AucRocMetric createRandom() {
+        return new AucRocMetric(randomBoolean() ? randomBoolean() : null);
+    }
+
+    @Override
+    protected AucRocMetric createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AucRocMetric doParseInstance(XContentParser parser) throws IOException {
+        return AucRocMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/outlierdetection/OutlierDetectionTests.java

@@ -40,7 +40,7 @@ public class OutlierDetectionTests extends AbstractXContentTestCase<OutlierDetec
     public static OutlierDetection createRandom() {
         List<EvaluationMetric> metrics = new ArrayList<>();
         if (randomBoolean()) {
-            metrics.add(new AucRocMetric(randomBoolean()));
+            metrics.add(AucRocMetricTests.createRandom());
         }
         if (randomBoolean()) {
             metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1,

+ 10 - 5
docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc

@@ -51,11 +51,13 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
 <1> Constructing a new evaluation
 <2> Name of the field in the index. Its value denotes the actual (i.e. ground truth) class the example belongs to.
 <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
-<4> The remaining parameters are the metrics to be calculated based on the two fields described above
-<5> Accuracy
-<6> Precision
-<7> Recall
-<8> Multiclass confusion matrix of size 3
+<4> Name of the field in the index. Its value denotes the array of top classes. Must be nested.
+<5> The remaining parameters are the metrics to be calculated based on the two fields described above
+<6> Accuracy
+<7> Precision
+<8> Recall
+<9> Multiclass confusion matrix of size 3
+<10> {wikipedia}/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated for class "cat" treated as positive and the rest as negative
 
 ===== Regression
 
@@ -115,6 +117,9 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
 <7> Fetching multiclass confusion matrix metric by name
 <8> Fetching the contents of the confusion matrix
 <9> Fetching the number of classes that were not included in the matrix
+<10> Fetching AucRoc metric by name
+<11> Fetching the actual AucRoc score
+<12> Fetching the number of documents that were used in order to calculate AucRoc score
 
 ===== Regression