Pārlūkot izejas kodu

Implement `precision` and `recall` metrics for classification evaluation (#49671)

Przemysław Witek 5 gadi atpakaļ
vecāks
revīzija
786ead630a
54 mainīti faili ar 2483 papildinājumiem un 368 dzēšanām
  1. 4 3
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java
  2. 85 21
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  3. 8 4
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java
  4. 201 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java
  5. 201 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java
  6. 8 4
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java
  7. 3 1
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java
  8. 64 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  9. 30 19
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  10. 24 4
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  11. 2 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java
  12. 2 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java
  13. 67 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java
  14. 53 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricTests.java
  15. 67 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java
  16. 53 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricTests.java
  17. 11 5
      docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc
  18. 10 24
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java
  19. 5 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java
  20. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java
  21. 146 76
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  22. 16 10
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java
  23. 12 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java
  24. 0 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java
  25. 28 20
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java
  26. 345 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java
  27. 319 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java
  28. 13 6
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java
  29. 15 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java
  30. 12 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
  31. 0 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java
  32. 8 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java
  33. 13 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java
  34. 17 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java
  35. 4 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java
  36. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java
  37. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java
  38. 5 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java
  39. 0 17
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java
  40. 5 6
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java
  41. 8 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java
  42. 14 10
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java
  43. 14 9
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java
  44. 7 5
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java
  45. 48 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java
  46. 119 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java
  47. 48 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java
  48. 118 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java
  49. 49 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/TupleMatchers.java
  50. 3 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java
  51. 3 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java
  52. 111 38
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java
  53. 24 2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java
  54. 52 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+ 4 - 3
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java

@@ -34,6 +34,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 
 public class EvaluateDataFrameResponse implements ToXContentObject {
@@ -46,7 +47,7 @@ public class EvaluateDataFrameResponse implements ToXContentObject {
         ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
         String evaluationName = parser.currentName();
         parser.nextToken();
-        Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric);
+        Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, p -> parseMetric(evaluationName, p));
         List<EvaluationMetric.Result> knownMetrics =
             metrics.values().stream()
                 .filter(Objects::nonNull)  // Filter out null values returned by {@link EvaluateDataFrameResponse::parseMetric}.
@@ -55,10 +56,10 @@ public class EvaluateDataFrameResponse implements ToXContentObject {
         return new EvaluateDataFrameResponse(evaluationName, knownMetrics);
     }
 
-    private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException {
+    private static EvaluationMetric.Result parseMetric(String evaluationName, XContentParser parser) throws IOException {
         String metricName = parser.currentName();
         try {
-            return parser.namedObject(EvaluationMetric.Result.class, metricName, null);
+            return parser.namedObject(EvaluationMetric.Result.class, registeredMetricName(evaluationName, metricName), null);
         } catch (NamedObjectNotFoundException e) {
             parser.skipChildren();
             // Metric name not recognized. Return {@code null} value here and filter it out later.

+ 85 - 21
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -20,24 +20,36 @@ package org.elasticsearch.client.ml.dataframe.evaluation;
 
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
-import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
-import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
-import org.elasticsearch.common.ParseField;
-import org.elasticsearch.common.xcontent.NamedXContentRegistry;
-import org.elasticsearch.plugins.spi.NamedXContentProvider;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.plugins.spi.NamedXContentProvider;
 
 import java.util.Arrays;
 import java.util.List;
 
 public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
 
+    /**
+     * Constructs the name under which a metric (or metric result) is registered.
+     * The name is prefixed with the evaluation name so that registered names are unique.
+     *
+     * @param evaluationName name of the evaluation
+     * @param metricName name of the metric
+     * @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
+     */
+    public static String registeredMetricName(String evaluationName, String metricName) {
+        return evaluationName + "." +  metricName;
+    }
+
     @Override
     public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
         return Arrays.asList(
@@ -47,39 +59,91 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Classification.NAME), Classification::fromXContent),
             new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
             // Evaluation metrics
-            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
-            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
-            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
+                AucRocMetric::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric::fromXContent),
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
+                PrecisionMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
-                new ParseField(MulticlassConfusionMatrixMetric.NAME),
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
+                RecallMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
+                ConfusionMatrixMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
+                AccuracyMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
                 MulticlassConfusionMatrixMetric::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
+                MeanSquaredErrorMetric::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent),
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
+                RSquaredMetric::fromXContent),
             // Evaluation metrics results
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME)),
+                AucRocMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME)),
+                PrecisionMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME)),
+                RecallMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME)),
+                ConfusionMatrixMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent),
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(Classification.NAME, AccuracyMetric.NAME)),
+                AccuracyMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(AccuracyMetric.NAME), AccuracyMetric.Result::fromXContent),
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
-                new ParseField(MulticlassConfusionMatrixMetric.NAME),
+                new ParseField(registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME)),
+                org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME)),
                 MulticlassConfusionMatrixMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
+                MeanSquaredErrorMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
-                EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent));
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
+                RSquaredMetric.Result::fromXContent)
+        );
     }
 }

+ 8 - 4
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/Classification.java

@@ -32,6 +32,10 @@ import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
 
+import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
 /**
  * Evaluation of classification results.
  */
@@ -48,10 +52,10 @@ public class Classification implements Evaluation {
         NAME, true, a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
 
     static {
-        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
-        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
-        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
-            (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
+        PARSER.declareString(constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(
+            optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
     }
 
     public static Classification fromXContent(XContentParser parser) {

+ 201 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetric.java

@@ -0,0 +1,201 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * {@link PrecisionMetric} is a metric that answers the question:
+ *   "What fraction of documents classified as X actually belongs to X?"
+ * for any given class X
+ *
+ * equation: precision(X) = TP(X) / (TP(X) + FP(X))
+ * where: TP(X) - number of true positives wrt X
+ *        FP(X) - number of false positives wrt X
+ */
+public class PrecisionMetric implements EvaluationMetric {
+
+    public static final String NAME = "precision";
+
+    private static final ObjectParser<PrecisionMetric, Void> PARSER = new ObjectParser<>(NAME, true, PrecisionMetric::new);
+
+    public static PrecisionMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public PrecisionMetric() {}
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(NAME);
+    }
+
+    public static class Result implements EvaluationMetric.Result {
+
+        private static final ParseField CLASSES = new ParseField("classes");
+        private static final ParseField AVG_PRECISION = new ParseField("avg_precision");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("precision_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
+            PARSER.declareDouble(constructorArg(), AVG_PRECISION);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        /** List of per-class results. */
+        private final List<PerClassResult> classes;
+        /** Average of per-class precisions. */
+        private final double avgPrecision;
+
+        public Result(List<PerClassResult> classes, double avgPrecision) {
+            this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
+            this.avgPrecision = avgPrecision;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        public List<PerClassResult> getClasses() {
+            return classes;
+        }
+
+        public double getAvgPrecision() {
+            return avgPrecision;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASSES.getPreferredName(), classes);
+            builder.field(AVG_PRECISION.getPreferredName(), avgPrecision);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.classes, that.classes)
+                && this.avgPrecision == that.avgPrecision;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(classes, avgPrecision);
+        }
+    }
+
+    public static class PerClassResult implements ToXContentObject {
+
+        private static final ParseField CLASS_NAME = new ParseField("class_name");
+        private static final ParseField PRECISION = new ParseField("precision");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
+            new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareString(constructorArg(), CLASS_NAME);
+            PARSER.declareDouble(constructorArg(), PRECISION);
+        }
+
+        /** Name of the class. */
+        private final String className;
+        /** Fraction of documents predicted as belonging to the {@code className} class that were predicted correctly. */
+        private final double precision;
+
+        public PerClassResult(String className, double precision) {
+            this.className = Objects.requireNonNull(className);
+            this.precision = precision;
+        }
+
+        public String getClassName() {
+            return className;
+        }
+
+        public double getPrecision() {
+            return precision;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASS_NAME.getPreferredName(), className);
+            builder.field(PRECISION.getPreferredName(), precision);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            PerClassResult that = (PerClassResult) o;
+            return Objects.equals(this.className, that.className)
+                && this.precision == that.precision;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, precision);
+        }
+    }
+}

+ 201 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetric.java

@@ -0,0 +1,201 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * {@link RecallMetric} is a metric that answers the question:
+ *   "What fraction of documents belonging to X have been predicted as X by the classifier?"
+ * for any given class X
+ *
+ * equation: recall(X) = TP(X) / (TP(X) + FN(X))
+ * where: TP(X) - number of true positives wrt X
+ *        FN(X) - number of false negatives wrt X
+ */
+public class RecallMetric implements EvaluationMetric {
+
+    public static final String NAME = "recall";
+
+    private static final ObjectParser<RecallMetric, Void> PARSER = new ObjectParser<>(NAME, true, RecallMetric::new);
+
+    public static RecallMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public RecallMetric() {}
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(NAME);
+    }
+
+    public static class Result implements EvaluationMetric.Result {
+
+        private static final ParseField CLASSES = new ParseField("classes");
+        private static final ParseField AVG_RECALL = new ParseField("avg_recall");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("recall_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
+            PARSER.declareDouble(constructorArg(), AVG_RECALL);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        /** List of per-class results. */
+        private final List<PerClassResult> classes;
+        /** Average of per-class recalls. */
+        private final double avgRecall;
+
+        public Result(List<PerClassResult> classes, double avgRecall) {
+            this.classes = Collections.unmodifiableList(Objects.requireNonNull(classes));
+            this.avgRecall = avgRecall;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        public List<PerClassResult> getClasses() {
+            return classes;
+        }
+
+        public double getAvgRecall() {
+            return avgRecall;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASSES.getPreferredName(), classes);
+            builder.field(AVG_RECALL.getPreferredName(), avgRecall);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.classes, that.classes)
+                && this.avgRecall == that.avgRecall;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(classes, avgRecall);
+        }
+    }
+
+    public static class PerClassResult implements ToXContentObject {
+
+        private static final ParseField CLASS_NAME = new ParseField("class_name");
+        private static final ParseField RECALL = new ParseField("recall");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
+            new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareString(constructorArg(), CLASS_NAME);
+            PARSER.declareDouble(constructorArg(), RECALL);
+        }
+
+        /** Name of the class. */
+        private final String className;
+        /** Fraction of documents actually belonging to the {@code actualClass} class predicted correctly. */
+        private final double recall;
+
+        public PerClassResult(String className, double recall) {
+            this.className = Objects.requireNonNull(className);
+            this.recall = recall;
+        }
+
+        public String getClassName() {
+            return className;
+        }
+
+        public double getRecall() {
+            return recall;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASS_NAME.getPreferredName(), className);
+            builder.field(RECALL.getPreferredName(), recall);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            PerClassResult that = (PerClassResult) o;
+            return Objects.equals(this.className, that.className)
+                && this.recall == that.recall;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, recall);
+        }
+    }
+}

+ 8 - 4
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java

@@ -33,6 +33,10 @@ import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
 
+import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
 /**
  * Evaluation of regression results.
  */
@@ -49,10 +53,10 @@ public class Regression implements Evaluation {
         NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
 
     static {
-        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
-        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
-        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
-            (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
+        PARSER.declareString(constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(
+            optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), c), METRICS);
     }
 
     public static Regression fromXContent(XContentParser parser) {

+ 3 - 1
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java

@@ -33,6 +33,7 @@ import java.util.Comparator;
 import java.util.List;
 import java.util.Objects;
 
+import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
@@ -59,7 +60,8 @@ public class BinarySoftClassification implements Evaluation {
     static {
         PARSER.declareString(constructorArg(), ACTUAL_FIELD);
         PARSER.declareString(constructorArg(), PREDICTED_PROBABILITY_FIELD);
-        PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, n, null), METRICS);
+        PARSER.declareNamedObjects(
+            optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME, n), null), METRICS);
     }
 
     public static BinarySoftClassification fromXContent(XContentParser parser) {

+ 64 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -1830,6 +1830,70 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                         new AccuracyMetric.ActualClass("ant", 1, 0.0))));
             assertThat(accuracyResult.getOverallAccuracy(), equalTo(0.6));  // 6 out of 10 examples were classified correctly
         }
+        {  // Precision
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName,
+                    null,
+                    new Classification(
+                        actualClassField,
+                        predictedClassField,
+                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
+                evaluateDataFrameResponse.getMetricByName(
+                    org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
+            assertThat(
+                precisionResult.getMetricName(),
+                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
+            assertThat(
+                precisionResult.getClasses(),
+                equalTo(
+                    List.of(
+                        // 3 out of 5 examples predicted as "cat" actually belong to class "cat"
+                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6),
+                        // 3 out of 4 examples predicted as "dog" actually belong to class "dog"
+                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75))));
+            assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
+        }
+        {  // Recall
+            EvaluateDataFrameRequest evaluateDataFrameRequest =
+                new EvaluateDataFrameRequest(
+                    indexName,
+                    null,
+                    new Classification(
+                        actualClassField,
+                        predictedClassField,
+                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
+
+            EvaluateDataFrameResponse evaluateDataFrameResponse =
+                execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+            assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
+            assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+            org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
+                evaluateDataFrameResponse.getMetricByName(
+                    org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
+            assertThat(
+                recallResult.getMetricName(),
+                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
+            assertThat(
+                recallResult.getClasses(),
+                equalTo(
+                    List.of(
+                        // 3 out of 5 examples labeled as "cat" were classified correctly
+                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6),
+                        // 3 out of 4 examples labeled as "dog" were classified correctly
+                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75),
+                        // no examples labeled as "ant" were classified correctly
+                        new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0))));
+            assertThat(recallResult.getAvgRecall(), equalTo(0.45));
+        }
         {  // No size provided for MulticlassConfusionMatrixMetric, default used instead
             EvaluateDataFrameRequest evaluateDataFrameRequest =
                 new EvaluateDataFrameRequest(

+ 30 - 19
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -128,6 +128,7 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.hamcrest.CoreMatchers.endsWith;
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -688,7 +689,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(51, namedXContents.size());
+        assertEquals(55, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -730,26 +731,36 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(TimeSyncConfig.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
         assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
-        assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
+        assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
         assertThat(names,
-            hasItems(AucRocMetric.NAME,
-                PrecisionMetric.NAME,
-                RecallMetric.NAME,
-                ConfusionMatrixMetric.NAME,
-                AccuracyMetric.NAME,
-                MulticlassConfusionMatrixMetric.NAME,
-                MeanSquaredErrorMetric.NAME,
-                RSquaredMetric.NAME));
-        assertEquals(Integer.valueOf(8), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
+            hasItems(
+                registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
+                registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME),
+                registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME),
+                registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME),
+                registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
+                registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
+                registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
+                registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
+        assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
         assertThat(names,
-            hasItems(AucRocMetric.NAME,
-                PrecisionMetric.NAME,
-                RecallMetric.NAME,
-                ConfusionMatrixMetric.NAME,
-                AccuracyMetric.NAME,
-                MulticlassConfusionMatrixMetric.NAME,
-                MeanSquaredErrorMetric.NAME,
-                RSquaredMetric.NAME));
+            hasItems(
+                registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
+                registeredMetricName(BinarySoftClassification.NAME, PrecisionMetric.NAME),
+                registeredMetricName(BinarySoftClassification.NAME, RecallMetric.NAME),
+                registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrixMetric.NAME),
+                registeredMetricName(Classification.NAME, AccuracyMetric.NAME),
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME),
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
+                registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
+                registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
+                registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
         assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
         assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));

+ 24 - 4
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -3372,7 +3372,9 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                     "predicted_class", // <3>
                     // Evaluation metrics // <4>
                     new AccuracyMetric(), // <5>
-                    new MulticlassConfusionMatrixMetric(3)); // <6>
+                    new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric(), // <6>
+                    new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric(), // <7>
+                    new MulticlassConfusionMatrixMetric(3)); // <8>
             // end::evaluate-data-frame-evaluation-classification
 
             EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@@ -3382,16 +3384,34 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             AccuracyMetric.Result accuracyResult = response.getMetricByName(AccuracyMetric.NAME); // <1>
             double accuracy = accuracyResult.getOverallAccuracy(); // <2>
 
+            org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
+                response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME); // <3>
+            double precision = precisionResult.getAvgPrecision(); // <4>
+
+            org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
+                response.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME); // <5>
+            double recall = recallResult.getAvgRecall(); // <6>
+
             MulticlassConfusionMatrixMetric.Result multiclassConfusionMatrix =
-                response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <3>
+                response.getMetricByName(MulticlassConfusionMatrixMetric.NAME); // <7>
 
-            List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <4>
-            long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <5>
+            List<ActualClass> confusionMatrix = multiclassConfusionMatrix.getConfusionMatrix(); // <8>
+            long otherClassesCount = multiclassConfusionMatrix.getOtherActualClassCount(); // <9>
             // end::evaluate-data-frame-results-classification
 
             assertThat(accuracyResult.getMetricName(), equalTo(AccuracyMetric.NAME));
             assertThat(accuracy, equalTo(0.6));
 
+            assertThat(
+                precisionResult.getMetricName(),
+                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
+            assertThat(precision, equalTo(0.675));
+
+            assertThat(
+                recallResult.getMetricName(),
+                equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
+            assertThat(recall, equalTo(0.45));
+
             assertThat(multiclassConfusionMatrix.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
             assertThat(
                 confusionMatrix,

+ 2 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java

@@ -64,6 +64,8 @@ public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<Eva
                 metrics = randomSubsetOf(
                     Arrays.asList(
                         AccuracyMetricResultTests.randomResult(),
+                        org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetricResultTests.randomResult(),
+                        org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetricResultTests.randomResult(),
                         MulticlassConfusionMatrixMetricResultTests.randomResult()));
                 break;
             default:

+ 2 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/ClassificationTests.java

@@ -41,6 +41,8 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
             randomSubsetOf(
                 Arrays.asList(
                     AccuracyMetricTests.createRandom(),
+                    PrecisionMetricTests.createRandom(),
+                    RecallMetricTests.createRandom(),
                     MulticlassConfusionMatrixMetricTests.createRandom()));
         return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
     }

+ 67 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricResultTests.java

@@ -0,0 +1,67 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class PrecisionMetricResultTests extends AbstractXContentTestCase<Result> {
+
+    /**
+     * Builds a random {@link Result} holding between 2 and 100 per-class entries.
+     * All class names are drawn first and the precision values afterwards, followed by the average.
+     */
+    public static Result randomResult() {
+        final int classCount = randomIntBetween(2, 100);
+        final List<String> names = Stream.generate(() -> randomAlphaOfLength(10)).limit(classCount).collect(Collectors.toList());
+        final List<PerClassResult> perClass = new ArrayList<>(classCount);
+        for (String name : names) {
+            perClass.add(new PerClassResult(name, randomDoubleBetween(0.0, 1.0, true)));
+        }
+        return new Result(perClass, randomDoubleBetween(0.0, 1.0, true));
+    }
+
+    /** Registry containing the ML evaluation named-xcontent parsers. */
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        final MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
+        return new NamedXContentRegistry(provider.getNamedXContentParsers());
+    }
+
+    @Override
+    protected Result createTestInstance() {
+        return randomResult();
+    }
+
+    @Override
+    protected Result doParseInstance(XContentParser parser) throws IOException {
+        return Result.fromXContent(parser);
+    }
+
+    /** The result parser is lenient, so round-trip tests may inject unknown fields. */
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 53 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/PrecisionMetricTests.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+/**
+ * XContent round-trip tests for the classification {@link PrecisionMetric}.
+ */
+public class PrecisionMetricTests extends AbstractXContentTestCase<PrecisionMetric> {
+
+    /** Registry containing the ML evaluation named-xcontent parsers. */
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    /** Returns a new metric instance; the constructor takes no arguments, so nothing is randomized. */
+    static PrecisionMetric createRandom() {
+        return new PrecisionMetric();
+    }
+
+    @Override
+    protected PrecisionMetric createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected PrecisionMetric doParseInstance(XContentParser parser) throws IOException {
+        return PrecisionMetric.fromXContent(parser);
+    }
+
+    /** Tells the base test case that unknown fields may be injected during round-trip parsing. */
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 67 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricResultTests.java

@@ -0,0 +1,67 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult;
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class RecallMetricResultTests extends AbstractXContentTestCase<Result> {
+
+    /**
+     * Builds a random {@link Result} holding between 2 and 100 per-class entries.
+     * All class names are drawn first and the recall values afterwards, followed by the average.
+     */
+    public static Result randomResult() {
+        final int classCount = randomIntBetween(2, 100);
+        final List<String> names = Stream.generate(() -> randomAlphaOfLength(10)).limit(classCount).collect(Collectors.toList());
+        final List<PerClassResult> perClass = new ArrayList<>(classCount);
+        for (String name : names) {
+            perClass.add(new PerClassResult(name, randomDoubleBetween(0.0, 1.0, true)));
+        }
+        return new Result(perClass, randomDoubleBetween(0.0, 1.0, true));
+    }
+
+    /** Registry containing the ML evaluation named-xcontent parsers. */
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        final MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
+        return new NamedXContentRegistry(provider.getNamedXContentParsers());
+    }
+
+    @Override
+    protected Result createTestInstance() {
+        return randomResult();
+    }
+
+    @Override
+    protected Result doParseInstance(XContentParser parser) throws IOException {
+        return Result.fromXContent(parser);
+    }
+
+    /** The result parser is lenient, so round-trip tests may inject unknown fields. */
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 53 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/RecallMetricTests.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class RecallMetricTests extends AbstractXContentTestCase<RecallMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    static RecallMetric createRandom() {
+        return new RecallMetric();
+    }
+
+    @Override
+    protected RecallMetric createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected RecallMetric doParseInstance(XContentParser parser) throws IOException {
+        return RecallMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 11 - 5
docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc

@@ -53,7 +53,9 @@ include-tagged::{doc-tests-file}[{api}-evaluation-classification]
 <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) class of the example.
 <4> The remaining parameters are the metrics to be calculated based on the two fields described above
 <5> Accuracy
-<6> Multiclass confusion matrix of size 3
+<6> Precision
+<7> Recall
+<8> Multiclass confusion matrix of size 3
 
 ===== Regression
 
@@ -104,9 +106,13 @@ include-tagged::{doc-tests-file}[{api}-results-classification]
 
 <1> Fetching accuracy metric by name
 <2> Fetching the actual accuracy value
-<3> Fetching multiclass confusion matrix metric by name
-<4> Fetching the contents of the confusion matrix
-<5> Fetching the number of classes that were not included in the matrix
+<3> Fetching precision metric by name
+<4> Fetching the actual precision value
+<5> Fetching recall metric by name
+<6> Fetching the actual recall value
+<7> Fetching multiclass confusion matrix metric by name
+<8> Fetching the contents of the confusion matrix
+<9> Fetching the number of classes that were not included in the matrix
 
 ===== Regression
 
@@ -118,4 +124,4 @@ include-tagged::{doc-tests-file}[{api}-results-regression]
 <1> Fetching mean squared error metric by name
 <2> Fetching the actual mean squared error value
 <3> Fetching R squared metric by name
-<4> Fetching the actual R squared value
+<4> Fetching the actual R squared value

+ 10 - 24
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

@@ -68,7 +68,6 @@ import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
-import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
@@ -80,6 +79,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
+import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
 import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction;
 import org.elasticsearch.xpack.core.ml.action.FlushJobAction;
@@ -134,15 +134,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
@@ -245,6 +237,9 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
+import java.util.stream.Stream;
+
+import static java.util.stream.Collectors.toList;
 
 // TODO: merge this into XPackPlugin
 public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPlugin {
@@ -426,7 +421,8 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
 
     @Override
     public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
-        return Arrays.asList(
+        return Stream.concat(
+            Arrays.asList(
                 // graph
                 new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.GRAPH, GraphFeatureSetUsage::new),
                 // logstash
@@ -454,18 +450,6 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
                 new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new),
                 new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Regression.NAME.getPreferredName(), Regression::new),
                 new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, Classification.NAME.getPreferredName(), Classification::new),
-                // ML - Data frame evaluation
-                new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
-                        BinarySoftClassification::new),
-                new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new),
-                new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new),
-                new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new),
-                new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
-                        ConfusionMatrix::new),
-                new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), AucRoc.Result::new),
-                new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new),
-                new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
-                        ConfusionMatrix.Result::new),
                 // ML - Inference preprocessing
                 new NamedWriteableRegistry.Entry(PreProcessor.class, FrequencyEncoding.NAME.getPreferredName(), FrequencyEncoding::new),
                 new NamedWriteableRegistry.Entry(PreProcessor.class, OneHotEncoding.NAME.getPreferredName(), OneHotEncoding::new),
@@ -568,7 +552,9 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
                 new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.SPATIAL, SpatialFeatureSetUsage::new),
                 // data science
                 new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.ANALYTICS, AnalyticsFeatureSetUsage::new)
-        );
+            ).stream(),
+            MlEvaluationNamedXContentProvider.getNamedWriteables().stream()
+        ).collect(toList());
     }
 
     @Override

+ 5 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java

@@ -7,12 +7,14 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -76,8 +78,9 @@ public interface Evaluation extends ToXContentObject, NamedWriteable {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
         for (EvaluationMetric metric : getMetrics()) {
             // Fetch aggregations requested by individual metrics
-            List<AggregationBuilder> aggs = metric.aggs(getActualField(), getPredictedField());
-            aggs.forEach(searchSourceBuilder::aggregation);
+            Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = metric.aggs(getActualField(), getPredictedField());
+            aggs.v1().forEach(searchSourceBuilder::aggregation);
+            aggs.v2().forEach(searchSourceBuilder::aggregation);
         }
         return searchSourceBuilder;
     }

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetric.java

@@ -6,10 +6,12 @@
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 
 import java.util.List;
 import java.util.Optional;
@@ -30,7 +32,7 @@ public interface EvaluationMetric extends ToXContentObject, NamedWriteable {
      * @param predictedField the field that stores the predicted value (class name or probability)
      * @return the aggregations required to compute the metric
      */
-    List<AggregationBuilder> aggs(String actualField, String predictedField);
+    Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField);
 
     /**
      * Processes given aggregations as a step towards computing result

+ 146 - 76
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -5,109 +5,179 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 
+import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.plugins.spi.NamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
 
-import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
 
-    @Override
-    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
-        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+    /**
+     * Constructs the name under which a metric (or metric result) is registered.
+     * The name is prefixed with evaluation name so that registered names are unique.
+     *
+     * @param evaluationName name of the evaluation
+     * @param metricName name of the metric
+     * @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
+     */
+    public static String registeredMetricName(ParseField evaluationName, ParseField metricName) {
+        return registeredMetricName(evaluationName.getPreferredName(), metricName.getPreferredName());
+    }
 
-        // Evaluations
-        namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
-            BinarySoftClassification::fromXContent));
-        namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent));
-        namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
+    /**
+     * Builds the registry name of a metric (or metric result).
+     * The evaluation name is used as a prefix, separated from the metric name by a dot, so that
+     * registered names stay unique across evaluations.
+     *
+     * @param evaluationName name of the evaluation
+     * @param metricName name of the metric
+     * @return name of the form {@code <evaluationName>.<metricName>}, used when registering a metric
+     *         (or metric result) in {@link NamedXContentRegistry} and {@link NamedWriteableRegistry}
+     */
+    public static String registeredMetricName(String evaluationName, String metricName) {
+        return String.join(".", evaluationName, metricName);
+    }
 
-        // Soft classification metrics
-        namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent));
-        namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent));
-        namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent));
-        namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
-            ConfusionMatrix::fromXContent));
+    @Override
+    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
+        return Arrays.asList(
+            // Evaluations
+            new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent),
+            new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent),
+            new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent),
 
-        // Classification metrics
-        namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
-            MulticlassConfusionMatrix::fromXContent));
-        namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent));
+            // Soft classification metrics
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME)),
+                AucRoc::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, Precision.NAME)),
+                Precision::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, Recall.NAME)),
+                Recall::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME)),
+                ConfusionMatrix::fromXContent),
 
-        // Regression metrics
-        namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
-        namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
+            // Classification metrics
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)),
+                MulticlassConfusionMatrix::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)),
+                Accuracy::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(
+                    registeredMetricName(
+                        Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)),
+                org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(
+                    registeredMetricName(
+                        Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)),
+                org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent),
 
-        return namedXContent;
+            // Regression metrics
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)),
+                MeanSquaredError::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
+                RSquared::fromXContent)
+        );
     }
 
-    public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
-        List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
-
-        // Evaluations
-        namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
-            BinarySoftClassification::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(),
-            Classification::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
-
-        // Evaluation Metrics
-        namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(),
-            AucRoc::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(),
-            Precision::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(),
-            Recall::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
-            ConfusionMatrix::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
-            MulticlassConfusionMatrix.NAME.getPreferredName(),
-            MulticlassConfusionMatrix::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
-            MeanSquaredError.NAME.getPreferredName(),
-            MeanSquaredError::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
-            RSquared.NAME.getPreferredName(),
-            RSquared::new));
+    public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
+        return Arrays.asList(
+            // Evaluations
+            new NamedWriteableRegistry.Entry(Evaluation.class,
+                BinarySoftClassification.NAME.getPreferredName(),
+                BinarySoftClassification::new),
+            new NamedWriteableRegistry.Entry(Evaluation.class,
+                Classification.NAME.getPreferredName(),
+                Classification::new),
+            new NamedWriteableRegistry.Entry(Evaluation.class,
+                Regression.NAME.getPreferredName(),
+                Regression::new),
 
-        // Evaluation Metrics Results
-        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
-            AucRoc.Result::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME,
-            ScoreByThresholdResult::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
-            ConfusionMatrix.Result::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
-            MulticlassConfusionMatrix.NAME.getPreferredName(),
-            MulticlassConfusionMatrix.Result::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
-            Accuracy.NAME.getPreferredName(),
-            Accuracy.Result::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
-            MeanSquaredError.NAME.getPreferredName(),
-            MeanSquaredError.Result::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
-            RSquared.NAME.getPreferredName(),
-            RSquared.Result::new));
+            // Evaluation metrics
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
+                AucRoc::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(BinarySoftClassification.NAME, Precision.NAME),
+                Precision::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(BinarySoftClassification.NAME, Recall.NAME),
+                Recall::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
+                ConfusionMatrix::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
+                MulticlassConfusionMatrix::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(Classification.NAME, Accuracy.NAME),
+                Accuracy::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
+                org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
+                org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
+                MeanSquaredError::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(Regression.NAME, RSquared.NAME),
+                RSquared::new),
 
-        return namedWriteables;
+            // Evaluation metrics results
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
+                AucRoc.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(BinarySoftClassification.NAME, ScoreByThresholdResult.NAME),
+                ScoreByThresholdResult::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
+                ConfusionMatrix.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
+                MulticlassConfusionMatrix.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(Classification.NAME, Accuracy.NAME),
+                Accuracy.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
+                org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(
+                    Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
+                org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
+                MeanSquaredError.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(Regression.NAME, RSquared.NAME),
+                RSquared.Result::new)
+        );
     }
 }

+ 16 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Accuracy.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -18,8 +19,10 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.terms.Terms;
 import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -33,6 +36,7 @@ import java.util.Objects;
 import java.util.Optional;
 
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
 
 /**
  * {@link Accuracy} is a metric that answers the question:
@@ -40,7 +44,7 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
  *
  * equation: accuracy = 1/n * Σ(y == y´)
  */
-public class Accuracy implements ClassificationMetric {
+public class Accuracy implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("accuracy");
 
@@ -67,7 +71,7 @@ public class Accuracy implements ClassificationMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(Classification.NAME, NAME);
     }
 
     @Override
@@ -76,16 +80,18 @@ public class Accuracy implements ClassificationMetric {
     }
 
     @Override
-    public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
+    public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
         if (result != null) {
-            return List.of();
+            return Tuple.tuple(List.of(), List.of());
         }
         Script accuracyScript = new Script(buildScript(actualField, predictedField));
-        return List.of(
-            AggregationBuilders.terms(CLASSES_AGG_NAME)
-                .field(actualField)
-                .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
-            AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript));
+        return Tuple.tuple(
+            List.of(
+                AggregationBuilders.terms(CLASSES_AGG_NAME)
+                    .field(actualField)
+                    .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
+                AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
+            List.of());
     }
 
     @Override
@@ -168,7 +174,7 @@ public class Accuracy implements ClassificationMetric {
 
         @Override
         public String getWriteableName() {
-            return NAME.getPreferredName();
+            return registeredMetricName(Classification.NAME, NAME);
         }
 
         @Override

+ 12 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Classification.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -20,6 +21,8 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 /**
  * Evaluation of classification results.
  */
@@ -33,13 +36,13 @@ public class Classification implements Evaluation {
 
     @SuppressWarnings("unchecked")
     public static final ConstructingObjectParser<Classification, Void> PARSER = new ConstructingObjectParser<>(
-        NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<ClassificationMetric>) a[2]));
+        NAME.getPreferredName(), a -> new Classification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
 
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
         PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
         PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
-            (p, c, n) -> p.namedObject(ClassificationMetric.class, n, c), METRICS);
+            (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
     }
 
     public static Classification fromXContent(XContentParser parser) {
@@ -61,22 +64,22 @@ public class Classification implements Evaluation {
     /**
      * The list of metrics to calculate
      */
-    private final List<ClassificationMetric> metrics;
+    private final List<EvaluationMetric> metrics;
 
-    public Classification(String actualField, String predictedField, @Nullable List<ClassificationMetric> metrics) {
+    public Classification(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
         this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
         this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
         this.metrics = initMetrics(metrics, Classification::defaultMetrics);
     }
 
-    private static List<ClassificationMetric> defaultMetrics() {
+    private static List<EvaluationMetric> defaultMetrics() {
         return Arrays.asList(new MulticlassConfusionMatrix());
     }
 
     public Classification(StreamInput in) throws IOException {
         this.actualField = in.readString();
         this.predictedField = in.readString();
-        this.metrics = in.readNamedWriteableList(ClassificationMetric.class);
+        this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
     }
 
     @Override
@@ -95,7 +98,7 @@ public class Classification implements Evaluation {
     }
 
     @Override
-    public List<ClassificationMetric> getMetrics() {
+    public List<EvaluationMetric> getMetrics() {
         return metrics;
     }
 
@@ -118,8 +121,8 @@ public class Classification implements Evaluation {
         builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
 
         builder.startObject(METRICS.getPreferredName());
-        for (ClassificationMetric metric : metrics) {
-            builder.field(metric.getWriteableName(), metric);
+        for (EvaluationMetric metric : metrics) {
+            builder.field(metric.getName(), metric);
         }
         builder.endObject();
 

+ 0 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationMetric.java

@@ -1,11 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
-
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
-
-public interface ClassificationMetric extends EvaluationMetric {
-}

+ 28 - 20
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -19,10 +20,12 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.search.aggregations.BucketOrder;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.filter.Filters;
 import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
 import org.elasticsearch.search.aggregations.bucket.terms.Terms;
 import org.elasticsearch.search.aggregations.metrics.Cardinality;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -37,13 +40,14 @@ import java.util.stream.Collectors;
 import static java.util.Comparator.comparing;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
 
 /**
  * {@link MulticlassConfusionMatrix} is a metric that answers the question:
- *   "How many examples belonging to class X were classified as Y by the classifier?"
+ *   "How many documents belonging to class X were classified as Y by the classifier?"
  * for all the possible class pairs {X, Y}.
  */
-public class MulticlassConfusionMatrix implements ClassificationMetric {
+public class MulticlassConfusionMatrix implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
 
@@ -91,7 +95,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(Classification.NAME, NAME);
     }
 
     @Override
@@ -104,13 +108,15 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
     }
 
     @Override
-    public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
+    public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
         if (topActualClassNames == null) {  // This is step 1
-            return List.of(
-                AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
-                    .field(actualField)
-                    .order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
-                    .size(size));
+            return Tuple.tuple(
+                List.of(
+                    AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
+                        .field(actualField)
+                        .order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
+                        .size(size)),
+                List.of());
         }
         if (result == null) {  // This is step 2
             KeyedFilter[] keyedFiltersActual =
@@ -121,15 +127,17 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
                 topActualClassNames.stream()
                     .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
                     .toArray(KeyedFilter[]::new);
-            return List.of(
-                AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
-                    .field(actualField),
-                AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
-                    .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
-                        .otherBucket(true)
-                        .otherBucketKey(OTHER_BUCKET_KEY)));
-        }
-        return List.of();
+            return Tuple.tuple(
+                List.of(
+                    AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)
+                        .field(actualField),
+                    AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
+                        .subAggregation(AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
+                            .otherBucket(true)
+                            .otherBucketKey(OTHER_BUCKET_KEY))),
+                List.of());
+        }
+        return Tuple.tuple(List.of(), List.of());
     }
 
     @Override
@@ -231,7 +239,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
 
         @Override
         public String getWriteableName() {
-            return NAME.getPreferredName();
+            return registeredMetricName(Classification.NAME, NAME);
         }
 
         @Override
@@ -300,7 +308,7 @@ public class MulticlassConfusionMatrix implements ClassificationMetric {
 
         /** Name of the actual class. */
         private final String actualClass;
-        /** Number of documents (examples) belonging to the {code actualClass} class. */
+        /** Number of documents belonging to the {@code actualClass} class. */
         private final long actualClassDocCount;
         /** List of predicted classes. */
         private final List<PredictedClass> predictedClasses;

+ 345 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Precision.java

@@ -0,0 +1,345 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.BucketOrder;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
+import org.elasticsearch.search.aggregations.bucket.filter.Filters;
+import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.text.MessageFormat;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
+/**
+ * {@link Precision} is a metric that answers the question:
+ *   "What fraction of documents classified as X actually belongs to X?"
+ * for any given class X.
+ *
+ * equation: precision(X) = TP(X) / (TP(X) + FP(X))
+ * where: TP(X) - number of true positives wrt X
+ *        FP(X) - number of false positives wrt X
+ */
+public class Precision implements EvaluationMetric {
+
+    public static final ParseField NAME = new ParseField("precision");
+
+    private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
+    private static final String AGG_NAME_PREFIX = "classification_precision_";
+    static final String ACTUAL_CLASSES_NAMES_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
+    static final String BY_PREDICTED_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_predicted_class";
+    static final String PER_PREDICTED_CLASS_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "per_predicted_class_precision";
+    static final String AVG_PRECISION_AGG_NAME = AGG_NAME_PREFIX + "avg_precision";
+
+    private static Script buildScript(Object...args) {
+        return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
+    }
+
+    private static final ObjectParser<Precision, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Precision::new);
+
+    public static Precision fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
+
+    private final int maxClassesCardinality;
+    private String actualField;
+    private List<String> topActualClassNames;
+    private EvaluationMetricResult result;
+
+    public Precision() {
+        this((Integer) null);
+    }
+
+    // Visible for testing
+    public Precision(@Nullable Integer maxClassesCardinality) {
+        this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
+    }
+
+    public Precision(StreamInput in) throws IOException {
+        this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return registeredMetricName(Classification.NAME, NAME);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
+        // Store given {@code actualField} for the purpose of generating error message in {@code process}.
+        this.actualField = actualField;
+        if (topActualClassNames == null) {  // This is step 1
+            return Tuple.tuple(
+                List.of(
+                    AggregationBuilders.terms(ACTUAL_CLASSES_NAMES_AGG_NAME)
+                        .field(actualField)
+                        .order(List.of(BucketOrder.count(false), BucketOrder.key(true)))
+                        .size(maxClassesCardinality)),
+                List.of());
+        }
+        if (result == null) {  // This is step 2
+            KeyedFilter[] keyedFiltersPredicted =
+                topActualClassNames.stream()
+                    .map(className -> new KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
+                    .toArray(KeyedFilter[]::new);
+            Script script = buildScript(actualField, predictedField);
+            return Tuple.tuple(
+                List.of(
+                    AggregationBuilders.filters(BY_PREDICTED_CLASS_AGG_NAME, keyedFiltersPredicted)
+                        .subAggregation(AggregationBuilders.avg(PER_PREDICTED_CLASS_PRECISION_AGG_NAME).script(script))),
+                List.of(
+                    PipelineAggregatorBuilders.avgBucket(
+                        AVG_PRECISION_AGG_NAME,
+                        BY_PREDICTED_CLASS_AGG_NAME + ">" + PER_PREDICTED_CLASS_PRECISION_AGG_NAME)));
+        }
+        return Tuple.tuple(List.of(), List.of());
+    }
+
+    @Override
+    public void process(Aggregations aggs) {
+        if (topActualClassNames == null && aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME) instanceof Terms) {
+            Terms topActualClassesAgg = aggs.get(ACTUAL_CLASSES_NAMES_AGG_NAME);
+            if (topActualClassesAgg.getSumOfOtherDocCounts() > 0) {
+                // This means there were more than {@code maxClassesCardinality} buckets.
+                // We cannot calculate average precision accurately, so we fail.
+                throw ExceptionsHelper.badRequestException(
+                    "Cannot calculate average precision. Cardinality of field [{}] is too high", actualField);
+            }
+            topActualClassNames =
+                topActualClassesAgg.getBuckets().stream().map(Terms.Bucket::getKeyAsString).sorted().collect(Collectors.toList());
+        }
+        if (result == null &&
+                aggs.get(BY_PREDICTED_CLASS_AGG_NAME) instanceof Filters &&
+                aggs.get(AVG_PRECISION_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
+            Filters byPredictedClassAgg = aggs.get(BY_PREDICTED_CLASS_AGG_NAME);
+            NumericMetricsAggregation.SingleValue avgPrecisionAgg = aggs.get(AVG_PRECISION_AGG_NAME);
+            List<PerClassResult> classes = new ArrayList<>(byPredictedClassAgg.getBuckets().size());
+            for (Filters.Bucket bucket : byPredictedClassAgg.getBuckets()) {
+                String className = bucket.getKeyAsString();
+                NumericMetricsAggregation.SingleValue precisionAgg = bucket.getAggregations().get(PER_PREDICTED_CLASS_PRECISION_AGG_NAME);
+                double precision = precisionAgg.value();
+                if (Double.isFinite(precision)) {
+                    classes.add(new PerClassResult(className, precision));
+                }
+            }
+            result = new Result(classes, avgPrecisionAgg.value());
+        }
+    }
+
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(NAME.getPreferredName());
+    }
+
+    public static class Result implements EvaluationMetricResult {
+
+        private static final ParseField CLASSES = new ParseField("classes");
+        private static final ParseField AVG_PRECISION = new ParseField("avg_precision");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("precision_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
+            PARSER.declareDouble(constructorArg(), AVG_PRECISION);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        /** List of per-class results. */
+        private final List<PerClassResult> classes;
+        /** Average of per-class precisions. */
+        private final double avgPrecision;
+
+        public Result(List<PerClassResult> classes, double avgPrecision) {
+            this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
+            this.avgPrecision = avgPrecision;
+        }
+
+        public Result(StreamInput in) throws IOException {
+            this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
+            this.avgPrecision = in.readDouble();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return registeredMetricName(Classification.NAME, NAME);
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME.getPreferredName();
+        }
+
+        public List<PerClassResult> getClasses() {
+            return classes;
+        }
+
+        public double getAvgPrecision() {
+            return avgPrecision;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeList(classes);
+            out.writeDouble(avgPrecision);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASSES.getPreferredName(), classes);
+            builder.field(AVG_PRECISION.getPreferredName(), avgPrecision);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.classes, that.classes)
+                && this.avgPrecision == that.avgPrecision;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(classes, avgPrecision);
+        }
+    }
+
+    public static class PerClassResult implements ToXContentObject, Writeable {
+
+        private static final ParseField CLASS_NAME = new ParseField("class_name");
+        private static final ParseField PRECISION = new ParseField("precision");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
+            new ConstructingObjectParser<>("precision_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareString(constructorArg(), CLASS_NAME);
+            PARSER.declareDouble(constructorArg(), PRECISION);
+        }
+
+        /** Name of the class. */
+        private final String className;
+        /** Fraction of documents predicted as belonging to the {@code className} class that were predicted correctly. */
+        private final double precision;
+
+        public PerClassResult(String className, double precision) {
+            this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
+            this.precision = precision;
+        }
+
+        public PerClassResult(StreamInput in) throws IOException {
+            this.className = in.readString();
+            this.precision = in.readDouble();
+        }
+
+        public String getClassName() {
+            return className;
+        }
+
+        public double getPrecision() {
+            return precision;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(className);
+            out.writeDouble(precision);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASS_NAME.getPreferredName(), className);
+            builder.field(PRECISION.getPreferredName(), precision);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            PerClassResult that = (PerClassResult) o;
+            return Objects.equals(this.className, that.className)
+                && this.precision == that.precision;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, precision);
+        }
+    }
+}

+ 319 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/Recall.java

@@ -0,0 +1,319 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.PipelineAggregatorBuilders;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.text.MessageFormat;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
+/**
+ * {@link Recall} is a metric that answers the question:
+ *   "What fraction of documents belonging to X have been predicted as X by the classifier?"
+ * for any given class X
+ *
+ * equation: recall(X) = TP(X) / (TP(X) + FN(X))
+ * where: TP(X) - number of true positives wrt X
+ *        FN(X) - number of false negatives wrt X
+ *
+ * The metric is requested as a terms aggregation on the actual field with a scripted average sub-aggregation
+ * (per-class recall), plus an average-bucket pipeline aggregation over those buckets (average recall).
+ */
+public class Recall implements EvaluationMetric {
+
+    public static final ParseField NAME = new ParseField("recall");
+
+    /**
+     * {@link MessageFormat} pattern of the script that compares the values of two fields.
+     * The doubled single quotes are the {@link MessageFormat} escape for one literal single quote.
+     */
+    private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
+    private static final String AGG_NAME_PREFIX = "classification_recall_";
+    static final String BY_ACTUAL_CLASS_AGG_NAME = AGG_NAME_PREFIX + "by_actual_class";
+    static final String PER_ACTUAL_CLASS_RECALL_AGG_NAME = AGG_NAME_PREFIX + "per_actual_class_recall";
+    static final String AVG_RECALL_AGG_NAME = AGG_NAME_PREFIX + "avg_recall";
+
+    /**
+     * Builds the comparison script by substituting {@code args} into {@link #PAINLESS_TEMPLATE}.
+     *
+     * @param args values for the template placeholders: {0} and {1} are the two field names
+     * @return script built from the formatted template
+     */
+    private static Script buildScript(Object...args) {
+        return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
+    }
+
+    private static final ObjectParser<Recall, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Recall::new);
+
+    /** Parses a {@link Recall} metric; the parser declares no fields, so the default constructor is used. */
+    public static Recall fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /** Number of terms buckets requested when no explicit value is given. */
+    private static final int DEFAULT_MAX_CLASSES_CARDINALITY = 1000;
+
+    /** Size of the terms aggregation; not serialized and not part of {@code equals}/{@code hashCode}. */
+    private final int maxClassesCardinality;
+    /** Actual field passed to the last {@code aggs} call; only used to build the error message in {@code process}. */
+    private String actualField;
+    /** Computed result, or {@code null} until {@code process} has seen the expected aggregations. */
+    private EvaluationMetricResult result;
+
+    /** Creates the metric with the default maximum number of classes. */
+    public Recall() {
+        this((Integer) null);
+    }
+
+    /**
+     * Creates the metric with the given maximum number of classes.
+     *
+     * @param maxClassesCardinality size of the terms aggregation, or {@code null} to use the default
+     */
+    // Visible for testing
+    public Recall(@Nullable Integer maxClassesCardinality) {
+        this.maxClassesCardinality = maxClassesCardinality != null ? maxClassesCardinality : DEFAULT_MAX_CLASSES_CARDINALITY;
+    }
+
+    /**
+     * Stream constructor. Nothing is read from the stream (matching the empty {@code writeTo});
+     * the default maximum number of classes is used.
+     */
+    public Recall(StreamInput in) throws IOException {
+        this.maxClassesCardinality = DEFAULT_MAX_CLASSES_CARDINALITY;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return registeredMetricName(Classification.NAME, NAME);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    /**
+     * Builds the aggregations needed to compute recall.
+     *
+     * @param actualField    field holding the actual class; terms buckets are created on it
+     * @param predictedField field holding the predicted class; compared with the actual field by the script
+     * @return a tuple of (aggregations, pipeline aggregations); both lists are empty once a result is available
+     */
+    @Override
+    public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
+        // Store given {@code actualField} for the purpose of generating error message in {@code process}.
+        this.actualField = actualField;
+        if (result != null) {
+            return Tuple.tuple(List.of(), List.of());
+        }
+        Script script = buildScript(actualField, predictedField);
+        return Tuple.tuple(
+            List.of(
+                AggregationBuilders.terms(BY_ACTUAL_CLASS_AGG_NAME)
+                    .field(actualField)
+                    .size(maxClassesCardinality)
+                    .subAggregation(AggregationBuilders.avg(PER_ACTUAL_CLASS_RECALL_AGG_NAME).script(script))),
+            List.of(
+                PipelineAggregatorBuilders.avgBucket(
+                    AVG_RECALL_AGG_NAME,
+                    BY_ACTUAL_CLASS_AGG_NAME + ">" + PER_ACTUAL_CLASS_RECALL_AGG_NAME)));
+    }
+
+    /**
+     * Extracts the result from the given aggregations. Does nothing if a result has already been computed or if
+     * the expected aggregations are absent or of an unexpected type.
+     *
+     * A bad request exception (built by {@code ExceptionsHelper.badRequestException}) is thrown when the terms
+     * aggregation reports documents outside of the returned buckets.
+     *
+     * @param aggs aggregations returned for the builders produced by {@code aggs}
+     */
+    @Override
+    public void process(Aggregations aggs) {
+        if (result == null &&
+                aggs.get(BY_ACTUAL_CLASS_AGG_NAME) instanceof Terms &&
+                aggs.get(AVG_RECALL_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
+            Terms byActualClassAgg = aggs.get(BY_ACTUAL_CLASS_AGG_NAME);
+            if (byActualClassAgg.getSumOfOtherDocCounts() > 0) {
+                // This means there were more than {@code maxClassesCardinality} buckets.
+                // We cannot calculate average recall accurately, so we fail.
+                throw ExceptionsHelper.badRequestException(
+                    "Cannot calculate average recall. Cardinality of field [{}] is too high", actualField);
+            }
+            NumericMetricsAggregation.SingleValue avgRecallAgg = aggs.get(AVG_RECALL_AGG_NAME);
+            List<PerClassResult> classes = new ArrayList<>(byActualClassAgg.getBuckets().size());
+            for (Terms.Bucket bucket : byActualClassAgg.getBuckets()) {
+                String className = bucket.getKeyAsString();
+                NumericMetricsAggregation.SingleValue recallAgg = bucket.getAggregations().get(PER_ACTUAL_CLASS_RECALL_AGG_NAME);
+                classes.add(new PerClassResult(className, recallAgg.value()));
+            }
+            result = new Result(classes, avgRecallAgg.value());
+        }
+    }
+
+    /** @return the computed result, or an empty {@link Optional} if {@code process} has not produced one yet */
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
+    /** Intentionally writes nothing: the metric has no serialized state (see the stream constructor). */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+    }
+
+    /** Renders the metric as an empty object. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    /**
+     * All instances of this class are considered equal: the only configuration, {@code maxClassesCardinality},
+     * is neither serialized nor rendered, so it is deliberately left out of the comparison.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(NAME.getPreferredName());
+    }
+
+    /** Result of the {@link Recall} metric: per-class recalls and their average. */
+    public static class Result implements EvaluationMetricResult {
+
+        private static final ParseField CLASSES = new ParseField("classes");
+        private static final ParseField AVG_RECALL = new ParseField("avg_recall");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("recall_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
+            PARSER.declareDouble(constructorArg(), AVG_RECALL);
+        }
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        /** List of per-class results. */
+        private final List<PerClassResult> classes;
+        /** Average of per-class recalls. */
+        private final double avgRecall;
+
+        /**
+         * @param classes   per-class results; validated with {@code ExceptionsHelper.requireNonNull} and exposed
+         *                  through an unmodifiable view
+         * @param avgRecall average of per-class recalls
+         */
+        public Result(List<PerClassResult> classes, double avgRecall) {
+            this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
+            this.avgRecall = avgRecall;
+        }
+
+        /** Reads the list of per-class results followed by the average recall, mirroring {@code writeTo}. */
+        public Result(StreamInput in) throws IOException {
+            this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
+            this.avgRecall = in.readDouble();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return registeredMetricName(Classification.NAME, NAME);
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME.getPreferredName();
+        }
+
+        public List<PerClassResult> getClasses() {
+            return classes;
+        }
+
+        public double getAvgRecall() {
+            return avgRecall;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeList(classes);
+            out.writeDouble(avgRecall);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASSES.getPreferredName(), classes);
+            builder.field(AVG_RECALL.getPreferredName(), avgRecall);
+            builder.endObject();
+            return builder;
+        }
+
+        /**
+         * The average recall is compared with {@link Double#compare(double, double)} rather than {@code ==} so that
+         * the comparison agrees with {@code hashCode} (which hashes the boxed {@code Double}): {@code NaN} equals
+         * {@code NaN}, and {@code 0.0} differs from {@code -0.0}.
+         */
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(this.classes, that.classes)
+                && Double.compare(this.avgRecall, that.avgRecall) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(classes, avgRecall);
+        }
+    }
+
+    /** Recall computed for a single class. */
+    public static class PerClassResult implements ToXContentObject, Writeable {
+
+        private static final ParseField CLASS_NAME = new ParseField("class_name");
+        private static final ParseField RECALL = new ParseField("recall");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
+            new ConstructingObjectParser<>("recall_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
+
+        static {
+            PARSER.declareString(constructorArg(), CLASS_NAME);
+            PARSER.declareDouble(constructorArg(), RECALL);
+        }
+
+        /** Name of the class. */
+        private final String className;
+        /** Fraction of documents actually belonging to the class {@code className} that were predicted correctly. */
+        private final double recall;
+
+        /**
+         * @param className name of the class; validated with {@code ExceptionsHelper.requireNonNull}
+         * @param recall    recall value recorded for that class
+         */
+        public PerClassResult(String className, double recall) {
+            this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
+            this.recall = recall;
+        }
+
+        /** Reads the class name followed by the recall, mirroring {@code writeTo}. */
+        public PerClassResult(StreamInput in) throws IOException {
+            this.className = in.readString();
+            this.recall = in.readDouble();
+        }
+
+        public String getClassName() {
+            return className;
+        }
+
+        public double getRecall() {
+            return recall;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(className);
+            out.writeDouble(recall);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASS_NAME.getPreferredName(), className);
+            builder.field(RECALL.getPreferredName(), recall);
+            builder.endObject();
+            return builder;
+        }
+
+        /**
+         * The recall is compared with {@link Double#compare(double, double)} rather than {@code ==} so that the
+         * comparison agrees with {@code hashCode} (which hashes the boxed {@code Double}): {@code NaN} equals
+         * {@code NaN}, and {@code 0.0} differs from {@code -0.0}.
+         */
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            PerClassResult that = (PerClassResult) o;
+            return Objects.equals(this.className, that.className)
+                && Double.compare(this.recall, that.recall) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, recall);
+        }
+    }
+}

+ 13 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
 
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ObjectParser;
@@ -15,7 +16,9 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 
 import java.io.IOException;
@@ -25,12 +28,14 @@ import java.util.Locale;
 import java.util.Objects;
 import java.util.Optional;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 /**
  * Calculates the mean squared error between two known numerical fields.
  *
  * equation: mse = 1/n * Σ(y - y´)^2
  */
-public class MeanSquaredError implements RegressionMetric {
+public class MeanSquaredError implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("mean_squared_error");
 
@@ -60,11 +65,13 @@ public class MeanSquaredError implements RegressionMetric {
     }
 
     @Override
-    public List<AggregationBuilder> aggs(String actualField, String predictedField) {
+    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
         if (result != null) {
-            return List.of();
+            return Tuple.tuple(List.of(), List.of());
         }
-        return List.of(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
+        return Tuple.tuple(
+            List.of(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField)))),
+            List.of());
     }
 
     @Override
@@ -80,7 +87,7 @@ public class MeanSquaredError implements RegressionMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(Regression.NAME, NAME);
     }
 
     @Override
@@ -123,7 +130,7 @@ public class MeanSquaredError implements RegressionMetric {
 
         @Override
         public String getWriteableName() {
-            return NAME.getPreferredName();
+            return registeredMetricName(Regression.NAME, NAME);
         }
 
         @Override

+ 15 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
 
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ObjectParser;
@@ -15,9 +16,11 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
 import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 
 import java.io.IOException;
@@ -27,6 +30,8 @@ import java.util.Locale;
 import java.util.Objects;
 import java.util.Optional;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 /**
  * Calculates R-Squared between two known numerical fields.
  *
@@ -35,7 +40,7 @@ import java.util.Optional;
  * SSres = Σ(y - y´)^2, The residual sum of squares
  * SStot =  Σ(y - y_mean)^2, The total sum of squares
  */
-public class RSquared implements RegressionMetric {
+public class RSquared implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("r_squared");
 
@@ -65,13 +70,15 @@ public class RSquared implements RegressionMetric {
     }
 
     @Override
-    public List<AggregationBuilder> aggs(String actualField, String predictedField) {
+    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
         if (result != null) {
-            return List.of();
+            return Tuple.tuple(List.of(), List.of());
         }
-        return List.of(
-            AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
-            AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField));
+        return Tuple.tuple(
+            List.of(
+                AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
+                AggregationBuilders.extendedStats(ExtendedStatsAggregationBuilder.NAME + "_actual").field(actualField)),
+            List.of());
     }
 
     @Override
@@ -95,7 +102,7 @@ public class RSquared implements RegressionMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(Regression.NAME, NAME);
     }
 
     @Override
@@ -138,7 +145,7 @@ public class RSquared implements RegressionMetric {
 
         @Override
         public String getWriteableName() {
-            return NAME.getPreferredName();
+            return registeredMetricName(Regression.NAME, NAME);
         }
 
         @Override

+ 12 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -20,6 +21,8 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 /**
  * Evaluation of regression results.
  */
@@ -33,13 +36,13 @@ public class Regression implements Evaluation {
 
     @SuppressWarnings("unchecked")
     public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
-        NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<RegressionMetric>) a[2]));
+        NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
 
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
         PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
         PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
-            (p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS);
+            (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
     }
 
     public static Regression fromXContent(XContentParser parser) {
@@ -61,22 +64,22 @@ public class Regression implements Evaluation {
     /**
      * The list of metrics to calculate
      */
-    private final List<RegressionMetric> metrics;
+    private final List<EvaluationMetric> metrics;
 
-    public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
+    public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
         this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
         this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
         this.metrics = initMetrics(metrics, Regression::defaultMetrics);
     }
 
-    private static List<RegressionMetric> defaultMetrics() {
+    private static List<EvaluationMetric> defaultMetrics() {
         return Arrays.asList(new MeanSquaredError(), new RSquared());
     }
 
     public Regression(StreamInput in) throws IOException {
         this.actualField = in.readString();
         this.predictedField = in.readString();
-        this.metrics = in.readNamedWriteableList(RegressionMetric.class);
+        this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
     }
 
     @Override
@@ -95,7 +98,7 @@ public class Regression implements Evaluation {
     }
 
     @Override
-    public List<RegressionMetric> getMetrics() {
+    public List<EvaluationMetric> getMetrics() {
         return metrics;
     }
 
@@ -118,8 +121,8 @@ public class Regression implements Evaluation {
         builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
 
         builder.startObject(METRICS.getPreferredName());
-        for (RegressionMetric metric : metrics) {
-            builder.field(metric.getWriteableName(), metric);
+        for (EvaluationMetric metric : metrics) {
+            builder.field(metric.getName(), metric);
         }
         builder.endObject();
 

+ 0 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java

@@ -1,11 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
-
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
-
-public interface RegressionMetric extends EvaluationMetric {
-}

+ 8 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
 
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -15,6 +16,8 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -22,9 +25,9 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Optional;
 
-import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery;
 
-abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
+abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
 
     public static final ParseField AT = new ParseField("at");
 
@@ -62,11 +65,11 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
     }
 
     @Override
-    public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
+    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
         if (result != null) {
-            return List.of();
+            return Tuple.tuple(List.of(), List.of());
         }
-        return aggsAt(actualField, predictedProbabilityField);
+        return Tuple.tuple(aggsAt(actualField, predictedProbabilityField), List.of());
     }
 
     @Override

+ 13 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
 
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -18,8 +19,10 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.filter.Filter;
 import org.elasticsearch.search.aggregations.metrics.Percentiles;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -33,7 +36,8 @@ import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.IntStream;
 
-import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification.actualIsTrueQuery;
 
 /**
  * Area under the curve (AUC) of the receiver operating characteristic (ROC).
@@ -53,7 +57,7 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassific
  * When this is used for multi-class classification, it will calculate the ROC
  * curve of each class versus the rest.
  */
-public class AucRoc implements SoftClassificationMetric {
+public class AucRoc implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("auc_roc");
 
@@ -88,7 +92,7 @@ public class AucRoc implements SoftClassificationMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(BinarySoftClassification.NAME, NAME);
     }
 
     @Override
@@ -123,9 +127,9 @@ public class AucRoc implements SoftClassificationMetric {
     }
 
     @Override
-    public List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
+    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedProbabilityField) {
         if (result != null) {
-            return List.of();
+            return Tuple.tuple(List.of(), List.of());
         }
         double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray();
         AggregationBuilder percentilesForClassValueAgg =
@@ -138,7 +142,9 @@ public class AucRoc implements SoftClassificationMetric {
                 .filter(NON_TRUE_AGG_NAME, QueryBuilders.boolQuery().mustNot(actualIsTrueQuery(actualField)))
                 .subAggregation(
                     AggregationBuilders.percentiles(PERCENTILES).field(predictedProbabilityField).percentiles(percentiles));
-        return List.of(percentilesForClassValueAgg, percentilesForRestAgg);
+        return Tuple.tuple(
+            List.of(percentilesForClassValueAgg, percentilesForRestAgg),
+            List.of());
     }
 
     @Override
@@ -330,7 +336,7 @@ public class AucRoc implements SoftClassificationMetric {
 
         @Override
         public String getWriteableName() {
-            return NAME.getPreferredName();
+            return registeredMetricName(BinarySoftClassification.NAME, NAME);
         }
 
         @Override

+ 17 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java

@@ -12,7 +12,10 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -20,6 +23,8 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 /**
  * Evaluation of binary soft classification methods, e.g. outlier detection.
  * This is useful to evaluate problems where a model outputs a probability of whether
@@ -34,19 +39,23 @@ public class BinarySoftClassification implements Evaluation {
     private static final ParseField METRICS = new ParseField("metrics");
 
     public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER = new ConstructingObjectParser<>(
-        NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List<SoftClassificationMetric>) a[2]));
+        NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
 
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
         PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD);
         PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
-            (p, c, n) -> p.namedObject(SoftClassificationMetric.class, n, null), METRICS);
+            (p, c, n) -> p.namedObject(EvaluationMetric.class, registeredMetricName(NAME.getPreferredName(), n), c), METRICS);
     }
 
     public static BinarySoftClassification fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
+    static QueryBuilder actualIsTrueQuery(String actualField) { // query-string query on actualField for the values 1 or true
+        return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
+    }
+
     /**
      * The field where the actual class is marked up.
      * The value of this field is assumed to either be 1 or 0, or true or false.
@@ -61,16 +70,16 @@ public class BinarySoftClassification implements Evaluation {
     /**
      * The list of metrics to calculate
      */
-    private final List<SoftClassificationMetric> metrics;
+    private final List<EvaluationMetric> metrics;
 
     public BinarySoftClassification(String actualField, String predictedProbabilityField,
-                                    @Nullable List<SoftClassificationMetric> metrics) {
+                                    @Nullable List<EvaluationMetric> metrics) {
         this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
         this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD);
         this.metrics = initMetrics(metrics, BinarySoftClassification::defaultMetrics);
     }
 
-    private static List<SoftClassificationMetric> defaultMetrics() {
+    private static List<EvaluationMetric> defaultMetrics() {
         return Arrays.asList(
             new AucRoc(false),
             new Precision(Arrays.asList(0.25, 0.5, 0.75)),
@@ -81,7 +90,7 @@ public class BinarySoftClassification implements Evaluation {
     public BinarySoftClassification(StreamInput in) throws IOException {
         this.actualField = in.readString();
         this.predictedProbabilityField = in.readString();
-        this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class);
+        this.metrics = in.readNamedWriteableList(EvaluationMetric.class);
     }
 
     @Override
@@ -100,7 +109,7 @@ public class BinarySoftClassification implements Evaluation {
     }
 
     @Override
-    public List<SoftClassificationMetric> getMetrics() {
+    public List<EvaluationMetric> getMetrics() {
         return metrics;
     }
 
@@ -123,7 +132,7 @@ public class BinarySoftClassification implements Evaluation {
         builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);
 
         builder.startObject(METRICS.getPreferredName());
-        for (SoftClassificationMetric metric : metrics) {
+        for (EvaluationMetric metric : metrics) {
             builder.field(metric.getName(), metric);
         }
         builder.endObject();

+ 4 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java

@@ -21,6 +21,8 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
 
     public static final ParseField NAME = new ParseField("confusion_matrix");
@@ -46,7 +48,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(BinarySoftClassification.NAME, NAME);
     }
 
     @Override
@@ -129,7 +131,7 @@ public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
 
         @Override
         public String getWriteableName() {
-            return NAME.getPreferredName();
+            return registeredMetricName(BinarySoftClassification.NAME, NAME);
         }
 
         @Override

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java

@@ -19,6 +19,8 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 public class Precision extends AbstractConfusionMatrixMetric {
 
     public static final ParseField NAME = new ParseField("precision");
@@ -44,7 +46,7 @@ public class Precision extends AbstractConfusionMatrixMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(BinarySoftClassification.NAME, NAME);
     }
 
     @Override

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java

@@ -19,6 +19,8 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 public class Recall extends AbstractConfusionMatrixMetric {
 
     public static final ParseField NAME = new ParseField("recall");
@@ -44,7 +46,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
 
     @Override
     public String getWriteableName() {
-        return NAME.getPreferredName();
+        return registeredMetricName(BinarySoftClassification.NAME, NAME);
     }
 
     @Override

+ 5 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
 
+import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -13,9 +14,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
 import java.io.IOException;
 import java.util.Objects;
 
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
 public class ScoreByThresholdResult implements EvaluationMetricResult {
 
-    public static final String NAME = "score_by_threshold_result";
+    public static final ParseField NAME = new ParseField("score_by_threshold_result");
 
     private final String name;
     private final double[] thresholds;
@@ -36,7 +39,7 @@ public class ScoreByThresholdResult implements EvaluationMetricResult {
 
     @Override
     public String getWriteableName() {
-        return NAME;
+        return registeredMetricName(BinarySoftClassification.NAME, NAME);
     }
 
     @Override

+ 0 - 17
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java

@@ -1,17 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
-
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
-
-public interface SoftClassificationMetric extends EvaluationMetric {
-
-    static QueryBuilder actualIsTrueQuery(String actualField) {
-        return QueryBuilders.queryStringQuery(actualField + ": (1 OR true)");
-    }
-}

+ 5 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java

@@ -31,7 +31,7 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
         List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
-        namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+        namedWriteables.addAll(MlEvaluationNamedXContentProvider.getNamedWriteables());
         namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
         return new NamedWriteableRegistry(namedWriteables);
     }
@@ -46,13 +46,11 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
 
     @Override
     protected Request createTestInstance() {
-        Request request = new Request();
         int indicesCount = randomIntBetween(1, 5);
         List<String> indices = new ArrayList<>(indicesCount);
         for (int i = 0; i < indicesCount; i++) {
             indices.add(randomAlphaOfLength(10));
         }
-        request.setIndices(indices);
         QueryProvider queryProvider = null;
         if (randomBoolean()) {
             try {
@@ -62,10 +60,11 @@ public class EvaluateDataFrameActionRequestTests extends AbstractSerializingTest
                 throw new UncheckedIOException(e);
             }
         }
-        request.setQueryProvider(queryProvider);
         Evaluation evaluation = randomBoolean() ? BinarySoftClassificationTests.createRandom() : RegressionTests.createRandom();
-        request.setEvaluation(evaluation);
-        return request;
+        return new Request()
+            .setIndices(indices)
+            .setQueryProvider(queryProvider)
+            .setEvaluation(evaluation);
     }
 
     @Override

+ 8 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java

@@ -11,7 +11,10 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Response;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.AccuracyResultTests;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixResultTests;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
 
@@ -21,7 +24,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
 
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
     }
 
     @Override
@@ -29,11 +32,13 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
         String evaluationName = randomAlphaOfLength(10);
         List<EvaluationMetricResult> metrics =
             List.of(
+                AccuracyResultTests.createRandom(),
+                PrecisionResultTests.createRandom(),
+                RecallResultTests.createRandom(),
                 MulticlassConfusionMatrixResultTests.createRandom(),
                 new MeanSquaredError.Result(randomDouble()),
                 new RSquared.Result(randomDouble()));
-        int numMetrics = randomIntBetween(0, metrics.size());
-        return new Response(evaluationName, metrics.subList(0, numMetrics));
+        return new Response(evaluationName, randomSubsetOf(metrics));
     }
 
     @Override

+ 14 - 10
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/AccuracyResultTests.java

@@ -17,15 +17,9 @@ import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
-public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accuracy.Result> {
+public class AccuracyResultTests extends AbstractWireSerializingTestCase<Result> {
 
-    @Override
-    protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
-    }
-
-    @Override
-    protected Accuracy.Result createTestInstance() {
+    public static Result createRandom() {
         int numClasses = randomIntBetween(2, 100);
         List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
         List<ActualClass> actualClasses = new ArrayList<>(numClasses);
@@ -38,7 +32,17 @@ public class AccuracyResultTests extends AbstractWireSerializingTestCase<Accurac
     }
 
     @Override
-    protected Writeable.Reader<Accuracy.Result> instanceReader() {
-        return Accuracy.Result::new;
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
+    }
+
+    @Override
+    protected Result createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Result> instanceReader() {
+        return Result::new;
     }
 }

+ 14 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/ClassificationTests.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -19,8 +20,10 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 
@@ -42,7 +45,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
 
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
     }
 
     @Override
@@ -51,10 +54,12 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
     }
 
     public static Classification createRandom() {
-        List<ClassificationMetric> metrics =
+        List<EvaluationMetric> metrics =
             randomSubsetOf(
                 Arrays.asList(
                     AccuracyTests.createRandom(),
+                    PrecisionTests.createRandom(),
+                    RecallTests.createRandom(),
                     MulticlassConfusionMatrixTests.createRandom()));
         return new Classification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? null : metrics);
     }
@@ -101,10 +106,10 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
     }
 
     public void testProcess_MultipleMetricsWithDifferentNumberOfSteps() {
-        ClassificationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
-        ClassificationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
-        ClassificationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
-        ClassificationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
+        EvaluationMetric metric1 = new FakeClassificationMetric("fake_metric_1", 2);
+        EvaluationMetric metric2 = new FakeClassificationMetric("fake_metric_2", 3);
+        EvaluationMetric metric3 = new FakeClassificationMetric("fake_metric_3", 4);
+        EvaluationMetric metric4 = new FakeClassificationMetric("fake_metric_4", 5);
 
         Classification evaluation = new Classification("act", "pred", Arrays.asList(metric1, metric2, metric3, metric4));
         assertThat(metric1.getResult(), isEmpty());
@@ -168,7 +173,7 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
      * Number of steps is configurable.
      * Upon reaching the last step, the result is produced.
      */
-    private static class FakeClassificationMetric implements ClassificationMetric {
+    private static class FakeClassificationMetric implements EvaluationMetric {
 
         private final String name;
         private final int numSteps;
@@ -191,8 +196,8 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
         }
 
         @Override
-        public List<AggregationBuilder> aggs(String actualField, String predictedField) {
-            return List.of();
+        public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
+            return Tuple.tuple(List.of(), List.of());
         }
 
         @Override

+ 7 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixTests.java

@@ -6,10 +6,12 @@
 package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
 
 import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
@@ -23,9 +25,9 @@ import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregati
 import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFiltersBucket;
 import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
 import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTermsBucket;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 
 public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<MulticlassConfusionMatrix> {
@@ -72,8 +74,8 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
 
     public void testAggs() {
         MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix();
-        List<AggregationBuilder> aggs = confusionMatrix.aggs("act", "pred");
-        assertThat(aggs, is(not(empty())));
+        Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs = confusionMatrix.aggs("act", "pred");
+        assertThat(aggs, isTuple(not(empty()), empty()));
         assertThat(confusionMatrix.getResult(), isEmpty());
     }
 
@@ -105,7 +107,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
         MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
         confusionMatrix.process(aggs);
 
-        assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
+        assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
         MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
         assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
         assertThat(
@@ -145,7 +147,7 @@ public class MulticlassConfusionMatrixTests extends AbstractSerializingTestCase<
         MulticlassConfusionMatrix confusionMatrix = new MulticlassConfusionMatrix(2);
         confusionMatrix.process(aggs);
 
-        assertThat(confusionMatrix.aggs("act", "pred"), is(empty()));
+        assertThat(confusionMatrix.aggs("act", "pred"), isTuple(empty(), empty()));
         MulticlassConfusionMatrix.Result result = (MulticlassConfusionMatrix.Result) confusionMatrix.getResult().get();
         assertThat(result.getMetricName(), equalTo("multiclass_confusion_matrix"));
         assertThat(

+ 48 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionResultTests.java

@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.PerClassResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class PrecisionResultTests extends AbstractWireSerializingTestCase<Result> { // wire-serialization tests for Precision.Result
+
+    public static Result createRandom() { // shared factory; also used by EvaluateDataFrameActionResponseTests
+        int numClasses = randomIntBetween(2, 100); // how many per-class entries to generate
+        List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
+        List<PerClassResult> classes = new ArrayList<>(numClasses);
+        for (int i = 0; i < numClasses; i++) { // one PerClassResult per generated class name
+            double precision = randomDoubleBetween(0.0, 1.0, true);
+            classes.add(new PerClassResult(classNames.get(i), precision));
+        }
+        double avgPrecision = randomDoubleBetween(0.0, 1.0, true); // drawn independently, not derived from the per-class values
+        return new Result(classes, avgPrecision);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() { // registry built from the ML evaluation provider's entries
+        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
+    }
+
+    @Override
+    protected Result createTestInstance() { // delegates to the shared static factory above
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Result> instanceReader() { // instances are read back through a Result constructor reference
+        return Result::new;
+    }
+}

+ 119 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/PrecisionTests.java

@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockFilters;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+
+public class PrecisionTests extends AbstractSerializingTestCase<Precision> { // serialization and process() tests for Precision
+
+    @Override
+    protected Precision doParseInstance(XContentParser parser) throws IOException { // parses via Precision.fromXContent
+        return Precision.fromXContent(parser);
+    }
+
+    @Override
+    protected Precision createTestInstance() { // delegates to the static factory below
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Precision> instanceReader() { // instances are read back through a Precision constructor reference
+        return Precision::new;
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() { // NOTE(review): presumably lets the base class inject unknown fields when parsing - confirm
+        return true;
+    }
+
+    public static Precision createRandom() { // no-arg constructor only, so every instance is built the same way
+        return new Precision();
+    }
+
+    public void testProcess() { // all expected aggs present: result is produced and no further aggs are requested
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME),
+            mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
+            mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123),
+            mockSingleValue("some_other_single_metric_agg", 0.2377) // unrelated agg; the expected result below does not reflect it
+        ));
+
+        Precision precision = new Precision();
+        precision.process(aggs);
+
+        assertThat(precision.aggs("act", "pred"), isTuple(empty(), empty())); // both lists of the returned tuple are empty after process()
+        assertThat(precision.getResult().get(), equalTo(new Precision.Result(List.of(), 0.8123))); // no per-class entries, avg from the mock
+    }
+
+    public void testProcess_GivenMissingAgg() { // each sub-case omits at least one expected agg and expects no result
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList( // AVG_PRECISION_AGG_NAME (and the terms agg) not supplied
+                mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
+                mockSingleValue("some_other_single_metric_agg", 0.2377)
+            ));
+            Precision precision = new Precision();
+            precision.process(aggs);
+            assertThat(precision.getResult(), isEmpty());
+        }
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList( // BY_PREDICTED_CLASS_AGG_NAME (and the terms agg) not supplied
+                mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123),
+                mockSingleValue("some_other_single_metric_agg", 0.2377)
+            ));
+            Precision precision = new Precision();
+            precision.process(aggs);
+            assertThat(precision.getResult(), isEmpty());
+        }
+    }
+
+    public void testProcess_GivenAggOfWrongType() { // agg names are right but one mock has the other mock type; expects no result
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                mockFilters(Precision.BY_PREDICTED_CLASS_AGG_NAME),
+                mockFilters(Precision.AVG_PRECISION_AGG_NAME) // filters mock where testProcess supplies a single-value mock
+            ));
+            Precision precision = new Precision();
+            precision.process(aggs);
+            assertThat(precision.getResult(), isEmpty());
+        }
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                mockSingleValue(Precision.BY_PREDICTED_CLASS_AGG_NAME, 1.0), // single-value mock where testProcess supplies a filters mock
+                mockSingleValue(Precision.AVG_PRECISION_AGG_NAME, 0.8123)
+            ));
+            Precision precision = new Precision();
+            precision.process(aggs);
+            assertThat(precision.getResult(), isEmpty());
+        }
+    }
+
+    public void testProcess_GivenCardinalityTooHigh() { // process() must throw with a message naming the actual field
+        Aggregations aggs = // NOTE(review): the third mockTerms argument (1) is presumably the sum of other doc counts - confirm
+            new Aggregations(Collections.singletonList(mockTerms(Precision.ACTUAL_CLASSES_NAMES_AGG_NAME, Collections.emptyList(), 1)));
+        Precision precision = new Precision();
+        precision.aggs("foo", "bar"); // called before process(); presumably records the field name "foo" used in the error - confirm
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> precision.process(aggs));
+        assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
+    }
+}

+ 48 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallResultTests.java

@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.PerClassResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class RecallResultTests extends AbstractWireSerializingTestCase<Result> { // wire-serialization tests for Recall.Result
+
+    public static Result createRandom() { // shared factory; also used by EvaluateDataFrameActionResponseTests
+        int numClasses = randomIntBetween(2, 100); // how many per-class entries to generate
+        List<String> classNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numClasses).collect(Collectors.toList());
+        List<PerClassResult> classes = new ArrayList<>(numClasses);
+        for (int i = 0; i < numClasses; i++) { // one PerClassResult per generated class name
+            double recall = randomDoubleBetween(0.0, 1.0, true);
+            classes.add(new PerClassResult(classNames.get(i), recall));
+        }
+        double avgRecall = randomDoubleBetween(0.0, 1.0, true); // drawn independently, not derived from the per-class values
+        return new Result(classes, avgRecall);
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() { // registry built from the ML evaluation provider's entries
+        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
+    }
+
+    @Override
+    protected Result createTestInstance() { // delegates to the shared static factory above
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Result> instanceReader() { // instances are read back through a Result constructor reference
+        return Result::new;
+    }
+}

+ 118 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/RecallTests.java

@@ -0,0 +1,118 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockTerms;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.TupleMatchers.isTuple;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+
+public class RecallTests extends AbstractSerializingTestCase<Recall> {
+
+    @Override
+    protected Recall doParseInstance(XContentParser parser) throws IOException {
+        return Recall.fromXContent(parser);
+    }
+
+    @Override
+    protected Recall createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<Recall> instanceReader() {
+        return Recall::new;
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    public static Recall createRandom() {
+        return new Recall();
+    }
+
+    public void testProcess() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
+            mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123),
+            mockSingleValue("some_other_single_metric_agg", 0.2377)
+        ));
+
+        Recall recall = new Recall();
+        recall.process(aggs);
+
+        assertThat(recall.aggs("act", "pred"), isTuple(empty(), empty()));
+        assertThat(recall.getResult().get(), equalTo(new Recall.Result(List.of(), 0.8123)));
+    }
+
+    public void testProcess_GivenMissingAgg() {
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
+                mockSingleValue("some_other_single_metric_agg", 0.2377)
+            ));
+            Recall recall = new Recall();
+            recall.process(aggs);
+            assertThat(recall.getResult(), isEmpty());
+        }
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123),
+                mockSingleValue("some_other_single_metric_agg", 0.2377)
+            ));
+            Recall recall = new Recall();
+            recall.process(aggs);
+            assertThat(recall.getResult(), isEmpty());
+        }
+    }
+
+    public void testProcess_GivenAggOfWrongType() {
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME),
+                mockTerms(Recall.AVG_RECALL_AGG_NAME)
+            ));
+            Recall recall = new Recall();
+            recall.process(aggs);
+            assertThat(recall.getResult(), isEmpty());
+        }
+        {
+            Aggregations aggs = new Aggregations(Arrays.asList(
+                mockSingleValue(Recall.BY_ACTUAL_CLASS_AGG_NAME, 1.0),
+                mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)
+            ));
+            Recall recall = new Recall();
+            recall.process(aggs);
+            assertThat(recall.getResult(), isEmpty());
+        }
+    }
+
+    public void testProcess_GivenCardinalityTooHigh() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            mockTerms(Recall.BY_ACTUAL_CLASS_AGG_NAME, Collections.emptyList(), 1),
+            mockSingleValue(Recall.AVG_RECALL_AGG_NAME, 0.8123)));
+        Recall recall = new Recall();
+        recall.aggs("foo", "bar");
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> recall.process(aggs));
+        assertThat(e.getMessage(), containsString("Cardinality of field [foo] is too high"));
+    }
+}

+ 49 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/TupleMatchers.java

@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
+
+import org.elasticsearch.common.collect.Tuple;
+import org.hamcrest.Description;
+import org.hamcrest.Matcher;
+import org.hamcrest.TypeSafeMatcher;
+
+import java.util.Arrays;
+
+/**
+ * Hamcrest matchers for {@link Tuple} values.
+ */
+public class TupleMatchers {
+
+    /**
+     * Matcher that accepts a tuple only when both of its components satisfy their respective matchers.
+     *
+     * @param <V1> type bound for the tuple's first component
+     * @param <V2> type bound for the tuple's second component
+     */
+    private static class TupleMatcher<V1, V2> extends TypeSafeMatcher<Tuple<? extends V1, ? extends V2>> {
+
+        private final Matcher<? super V1> v1Matcher;
+        private final Matcher<? super V2> v2Matcher;
+
+        private TupleMatcher(Matcher<? super V1> v1Matcher, Matcher<? super V2> v2Matcher) {
+            this.v1Matcher = v1Matcher;
+            this.v2Matcher = v2Matcher;
+        }
+
+        @Override
+        protected boolean matchesSafely(final Tuple<? extends V1, ? extends V2> item) {
+            // Both component matchers must accept; the second is only consulted when the first one matches.
+            return item != null && v1Matcher.matches(item.v1()) && v2Matcher.matches(item.v2());
+        }
+
+        @Override
+        public void describeTo(final Description description) {
+            description.appendText("expected tuple matching ").appendList("[", ", ", "]", Arrays.asList(v1Matcher, v2Matcher));
+        }
+    }
+
+    /**
+     * Creates a matcher that matches iff:
+     *  1. the examined tuple's <code>v1()</code> matches the specified <code>v1Matcher</code>
+     * and
+     *  2. the examined tuple's <code>v2()</code> matches the specified <code>v2Matcher</code>
+     * For example:
+     * <pre>assertThat(Tuple.tuple("myValue1", "myValue2"), isTuple(startsWith("my"), containsString("Val")))</pre>
+     *
+     * @param v1Matcher matcher applied to the tuple's first component
+     * @param v2Matcher matcher applied to the tuple's second component
+     * @return a matcher for tuples whose components satisfy both given matchers
+     */
+    public static <V1, V2> TupleMatcher<? extends V1, ? extends V2> isTuple(Matcher<? super V1> v1Matcher, Matcher<? super V2> v2Matcher) {
+        // Explicit type arguments instead of a raw type, so the compiler checks the matchers against V1/V2.
+        return new TupleMatcher<V1, V2>(v1Matcher, v2Matcher);
+    }
+}

+ 3 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 
 import java.io.IOException;
@@ -29,7 +30,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
 
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
     }
 
     @Override
@@ -38,7 +39,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
     }
 
     public static Regression createRandom() {
-        List<RegressionMetric> metrics = new ArrayList<>();
+        List<EvaluationMetric> metrics = new ArrayList<>();
         if (randomBoolean()) {
             metrics.add(MeanSquaredErrorTests.createRandom());
         }

+ 3 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 
 import java.io.IOException;
@@ -29,7 +30,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
 
     @Override
     protected NamedWriteableRegistry getNamedWriteableRegistry() {
-        return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(MlEvaluationNamedXContentProvider.getNamedWriteables());
     }
 
     @Override
@@ -38,7 +39,7 @@ public class BinarySoftClassificationTests extends AbstractSerializingTestCase<B
     }
 
     public static BinarySoftClassification createRandom() {
-        List<SoftClassificationMetric> metrics = new ArrayList<>();
+        List<EvaluationMetric> metrics = new ArrayList<>();
         if (randomBoolean()) {
             metrics.add(AucRocTests.createRandom());
         }

+ 111 - 38
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

@@ -5,21 +5,23 @@
  */
 package org.elasticsearch.xpack.ml.integration;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.ActualClass;
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix.PredictedClass;
 import org.junit.After;
 import org.junit.Before;
 
 import java.util.List;
 
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
@@ -116,6 +118,68 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
         assertThat(accuracyResult.getOverallAccuracy(), equalTo(45.0 / 75));
     }
 
+    /**
+     * Runs a classification evaluation with only the {@link Precision} metric against the index named by
+     * {@code ANIMALS_DATA_INDEX} and verifies the per-class and average precision values.
+     */
+    public void testEvaluate_Precision() {
+        Classification evaluation = new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision()));
+        EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);
+
+        assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
+        assertThat(response.getMetrics(), hasSize(1));
+
+        // The single metric in the response is expected to be the precision result.
+        Precision.Result result = (Precision.Result) response.getMetrics().get(0);
+        assertThat(result.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
+        assertThat(
+            result.getClasses(),
+            equalTo(
+                List.of(
+                    new Precision.PerClassResult("ant", 1.0 / 15),
+                    new Precision.PerClassResult("cat", 1.0 / 15),
+                    new Precision.PerClassResult("dog", 1.0 / 15),
+                    new Precision.PerClassResult("fox", 1.0 / 15),
+                    new Precision.PerClassResult("mouse", 1.0 / 15))));
+        assertThat(result.getAvgPrecision(), equalTo(5.0 / 75));
+    }
+
+    /**
+     * Evaluating with {@code new Precision(4)} is expected to fail with an error that reports the
+     * cardinality of the actual field as too high.
+     */
+    public void testEvaluate_Precision_CardinalityTooHigh() {
+        ElasticsearchStatusException failure = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> evaluateDataFrame(
+                ANIMALS_DATA_INDEX,
+                new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Precision(4)))));
+
+        assertThat(failure.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
+    }
+
+    /**
+     * Runs a classification evaluation with only the {@link Recall} metric against the index named by
+     * {@code ANIMALS_DATA_INDEX} and verifies the per-class and average recall values.
+     */
+    public void testEvaluate_Recall() {
+        Classification evaluation = new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall()));
+        EvaluateDataFrameAction.Response response = evaluateDataFrame(ANIMALS_DATA_INDEX, evaluation);
+
+        assertThat(response.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
+        assertThat(response.getMetrics(), hasSize(1));
+
+        // The single metric in the response is expected to be the recall result.
+        Recall.Result result = (Recall.Result) response.getMetrics().get(0);
+        assertThat(result.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
+        assertThat(
+            result.getClasses(),
+            equalTo(
+                List.of(
+                    new Recall.PerClassResult("ant", 1.0 / 15),
+                    new Recall.PerClassResult("cat", 1.0 / 15),
+                    new Recall.PerClassResult("dog", 1.0 / 15),
+                    new Recall.PerClassResult("fox", 1.0 / 15),
+                    new Recall.PerClassResult("mouse", 1.0 / 15))));
+        assertThat(result.getAvgRecall(), equalTo(5.0 / 75));
+    }
+
+    /**
+     * Evaluating with {@code new Recall(4)} is expected to fail with an error that reports the
+     * cardinality of the actual field as too high.
+     */
+    public void testEvaluate_Recall_CardinalityTooHigh() {
+        ElasticsearchStatusException failure = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> evaluateDataFrame(
+                ANIMALS_DATA_INDEX,
+                new Classification(ANIMAL_NAME_FIELD, ANIMAL_NAME_PREDICTION_FIELD, List.of(new Recall(4)))));
+
+        assertThat(failure.getMessage(), containsString("Cardinality of field [animal_name] is too high"));
+    }
+
     public void testEvaluate_ConfusionMatrixMetricWithDefaultSize() {
         EvaluateDataFrameAction.Response evaluateDataFrameResponse =
             evaluateDataFrame(
@@ -131,50 +195,50 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
         assertThat(
             confusionMatrixResult.getConfusionMatrix(),
             equalTo(List.of(
-                new ActualClass("ant",
+                new MulticlassConfusionMatrix.ActualClass("ant",
                     15,
                     List.of(
-                        new PredictedClass("ant", 1L),
-                        new PredictedClass("cat", 4L),
-                        new PredictedClass("dog", 3L),
-                        new PredictedClass("fox", 2L),
-                        new PredictedClass("mouse", 5L)),
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 1L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 4L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 3L),
+                        new MulticlassConfusionMatrix.PredictedClass("fox", 2L),
+                        new MulticlassConfusionMatrix.PredictedClass("mouse", 5L)),
                     0),
-                new ActualClass("cat",
+                new MulticlassConfusionMatrix.ActualClass("cat",
                     15,
                     List.of(
-                        new PredictedClass("ant", 3L),
-                        new PredictedClass("cat", 1L),
-                        new PredictedClass("dog", 5L),
-                        new PredictedClass("fox", 4L),
-                        new PredictedClass("mouse", 2L)),
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 3L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 1L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 5L),
+                        new MulticlassConfusionMatrix.PredictedClass("fox", 4L),
+                        new MulticlassConfusionMatrix.PredictedClass("mouse", 2L)),
                     0),
-                new ActualClass("dog",
+                new MulticlassConfusionMatrix.ActualClass("dog",
                     15,
                     List.of(
-                        new PredictedClass("ant", 4L),
-                        new PredictedClass("cat", 2L),
-                        new PredictedClass("dog", 1L),
-                        new PredictedClass("fox", 5L),
-                        new PredictedClass("mouse", 3L)),
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 4L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 2L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 1L),
+                        new MulticlassConfusionMatrix.PredictedClass("fox", 5L),
+                        new MulticlassConfusionMatrix.PredictedClass("mouse", 3L)),
                     0),
-                new ActualClass("fox",
+                new MulticlassConfusionMatrix.ActualClass("fox",
                     15,
                     List.of(
-                        new PredictedClass("ant", 5L),
-                        new PredictedClass("cat", 3L),
-                        new PredictedClass("dog", 2L),
-                        new PredictedClass("fox", 1L),
-                        new PredictedClass("mouse", 4L)),
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 5L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 3L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 2L),
+                        new MulticlassConfusionMatrix.PredictedClass("fox", 1L),
+                        new MulticlassConfusionMatrix.PredictedClass("mouse", 4L)),
                     0),
-                new ActualClass("mouse",
+                new MulticlassConfusionMatrix.ActualClass("mouse",
                     15,
                     List.of(
-                        new PredictedClass("ant", 2L),
-                        new PredictedClass("cat", 5L),
-                        new PredictedClass("dog", 4L),
-                        new PredictedClass("fox", 3L),
-                        new PredictedClass("mouse", 1L)),
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 2L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 5L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 4L),
+                        new MulticlassConfusionMatrix.PredictedClass("fox", 3L),
+                        new MulticlassConfusionMatrix.PredictedClass("mouse", 1L)),
                     0))));
         assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
     }
@@ -193,17 +257,26 @@ public class ClassificationEvaluationIT extends MlNativeDataFrameAnalyticsIntegT
         assertThat(
             confusionMatrixResult.getConfusionMatrix(),
             equalTo(List.of(
-                new ActualClass("ant",
+                new MulticlassConfusionMatrix.ActualClass("ant",
                     15,
-                    List.of(new PredictedClass("ant", 1L), new PredictedClass("cat", 4L), new PredictedClass("dog", 3L)),
+                    List.of(
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 1L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 4L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 3L)),
                     7),
-                new ActualClass("cat",
+                new MulticlassConfusionMatrix.ActualClass("cat",
                     15,
-                    List.of(new PredictedClass("ant", 3L), new PredictedClass("cat", 1L), new PredictedClass("dog", 5L)),
+                    List.of(
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 3L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 1L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 5L)),
                     6),
-                new ActualClass("dog",
+                new MulticlassConfusionMatrix.ActualClass("dog",
                     15,
-                    List.of(new PredictedClass("ant", 4L), new PredictedClass("cat", 2L), new PredictedClass("dog", 1L)),
+                    List.of(
+                        new MulticlassConfusionMatrix.PredictedClass("ant", 4L),
+                        new MulticlassConfusionMatrix.PredictedClass("cat", 2L),
+                        new MulticlassConfusionMatrix.PredictedClass("dog", 1L)),
                     8))));
         assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(2L));
     }

+ 24 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -27,6 +27,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall;
 import org.junit.After;
 
 import java.util.ArrayList;
@@ -450,9 +452,11 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             evaluateDataFrame(
                 destIndex,
                 new org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification(
-                    dependentVariable, predictedClassField, Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix())));
+                    dependentVariable,
+                    predictedClassField,
+                    Arrays.asList(new Accuracy(), new MulticlassConfusionMatrix(), new Precision(), new Recall())));
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME.getPreferredName()));
-        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
+        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
 
         {   // Accuracy
             Accuracy.Result accuracyResult = (Accuracy.Result) evaluateDataFrameResponse.getMetrics().get(0);
@@ -483,6 +487,24 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             }
             assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
         }
+
+        {   // Precision
+            Precision.Result precisionResult = (Precision.Result) evaluateDataFrameResponse.getMetrics().get(2);
+            assertThat(precisionResult.getMetricName(), equalTo(Precision.NAME.getPreferredName()));
+            for (Precision.PerClassResult klass : precisionResult.getClasses()) {
+                assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
+                assertThat(klass.getPrecision(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
+            }
+        }
+
+        {   // Recall
+            Recall.Result recallResult = (Recall.Result) evaluateDataFrameResponse.getMetrics().get(3);
+            assertThat(recallResult.getMetricName(), equalTo(Recall.NAME.getPreferredName()));
+            for (Recall.PerClassResult klass : recallResult.getClasses()) {
+                assertThat(klass.getClassName(), is(in(dependentVariableValuesAsStrings)));
+                assertThat(klass.getRecall(), allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
+            }
+        }
     }
 
     protected String stateDocId() {

+ 52 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -632,6 +632,58 @@ setup:
             accuracy: 0.5  # 1 out of 2
         overall_accuracy: 0.625  # 5 out of 8
 ---
+"Test classification precision":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "classification_field_pred.keyword",
+                "metrics": { "precision": {} }
+              }
+            }
+          }
+
+  - match:
+      classification.precision:
+        classes:
+          - class_name: "cat"
+            precision: 0.5  # 2 out of 4
+          - class_name: "dog"
+            precision: 0.6666666666666666  # 2 out of 3
+          - class_name: "mouse"
+            precision: 1.0  # 1 out of 1
+        avg_precision: 0.7222222222222222
+---
+"Test classification recall":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "classification": {
+                "actual_field": "classification_field_act.keyword",
+                "predicted_field": "classification_field_pred.keyword",
+                "metrics": { "recall": {} }
+              }
+            }
+          }
+
+  - match:
+      classification.recall:
+        classes:
+          - class_name: "cat"
+            recall: 0.6666666666666666  # 2 out of 3
+          - class_name: "dog"
+            recall: 0.6666666666666666  # 2 out of 3
+          - class_name: "mouse"
+            recall: 0.5  # 1 out of 2
+        avg_recall: 0.611111111111111
+---
 "Test classification multiclass_confusion_matrix":
   - do:
       ml.evaluate_data_frame: