Browse Source

[ML] Adds support for regression.mean_squared_error to eval API (#44140)

* [ML] Adds support for regression.mean_squared_error to eval API

* addressing PR comments

* fixing tests
Benjamin Trent 6 years ago
parent
commit
873e9f93cf
17 changed files with 1069 additions and 20 deletions
  1. 7 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  2. 129 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java
  3. 129 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java
  4. 51 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  5. 11 7
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  6. 4 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java
  7. 53 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricResultTests.java
  8. 49 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java
  9. 59 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java
  10. 14 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  11. 141 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java
  12. 171 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
  13. 37 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java
  14. 76 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java
  15. 59 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java
  16. 4 3
      x-pack/plugin/ml/qa/ml-with-security/build.gradle
  17. 75 10
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+ 7 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -18,6 +18,8 @@
  */
 package org.elasticsearch.client.ml.dataframe.evaluation;
 
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -38,12 +40,15 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             // Evaluations
             new NamedXContentRegistry.Entry(
                 Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
+            new NamedXContentRegistry.Entry(Evaluation.class, new ParseField(Regression.NAME), Regression::fromXContent),
             // Evaluation metrics
             new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
             new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
             new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent),
             // Evaluation metrics results
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
@@ -51,6 +56,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
                 EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
     }

+ 129 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java

@@ -0,0 +1,129 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * Calculates the mean squared error between two known numerical fields.
+ *
+ * equation: mse = 1/n * Σ(y - y')^2, where y is the actual value, y' is the predicted value and n is the number of documents
+ */
+public class MeanSquaredErrorMetric implements EvaluationMetric {
+
+    public static final String NAME = "mean_squared_error";
+
+    private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER =
+        new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new);
+
+    public static MeanSquaredErrorMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public MeanSquaredErrorMetric() {
+
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        // create static hash code from name as there are currently no unique fields per class instance
+        return Objects.hashCode(NAME);
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    public static class Result implements EvaluationMetric.Result {
+
+        public static final ParseField ERROR = new ParseField("error");
+        private final double error;
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
+
+        static {
+            PARSER.declareDouble(constructorArg(), ERROR);
+        }
+
+        public Result(double error) {
+            this.error = error;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+            builder.startObject();
+            builder.field(ERROR.getPreferredName(), error);
+            builder.endObject();
+            return builder;
+        }
+
+        public double getError() {
+            return error;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(that.error, this.error);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(error);
+        }
+    }
+}

+ 129 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java

@@ -0,0 +1,129 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Evaluation of regression results.
+ */
+public class Regression implements Evaluation {
+
+    public static final String NAME = "regression";
+
+    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
+    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField METRICS = new ParseField("metrics");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
+        NAME, true, a -> new Regression((String) a[0], (String) a[1], (List<EvaluationMetric>) a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
+        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(EvaluationMetric.class, n, c), METRICS);
+    }
+
+    public static Regression fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * The field containing the actual value.
+     * The value of this field is assumed to be numeric.
+     */
+    private final String actualField;
+
+    /**
+     * The field containing the predicted value.
+     * The value of this field is assumed to be numeric.
+     */
+    private final String predictedField;
+
+    /**
+     * The list of metrics to calculate; may be null, in which case no "metrics" object is written by toXContent.
+     */
+    private final List<EvaluationMetric> metrics;
+
+    public Regression(String actualField, String predictedField) {
+        this(actualField, predictedField, (List<EvaluationMetric>)null);
+    }
+
+    public Regression(String actualField, String predictedField, EvaluationMetric... metrics) {
+        this(actualField, predictedField, Arrays.asList(metrics));
+    }
+
+    public Regression(String actualField, String predictedField, @Nullable List<EvaluationMetric> metrics) {
+        this.actualField = actualField;
+        this.predictedField = predictedField;
+        this.metrics = metrics;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
+        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+
+        if (metrics != null) {
+           builder.startObject(METRICS.getPreferredName());
+           for (EvaluationMetric metric : metrics) {
+               builder.field(metric.getName(), metric);
+           }
+           builder.endObject();
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Regression that = (Regression) o;
+        return Objects.equals(that.actualField, this.actualField)
+            && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.metrics, this.metrics);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(actualField, predictedField, metrics);
+    }
+}

+ 51 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -123,6 +123,8 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
 import org.elasticsearch.client.ml.dataframe.OutlierDetection;
 import org.elasticsearch.client.ml.dataframe.QueryConfig;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
@@ -1578,6 +1580,33 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
         assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
         assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
+
+        String regressionIndex = "evaluate-regression-test-index";
+        createIndex(regressionIndex, mappingForRegression());
+        BulkRequest regressionBulk = new BulkRequest()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .add(docForRegression(regressionIndex, 0.3, 0.1))  // #0
+            .add(docForRegression(regressionIndex, 0.3, 0.2))  // #1
+            .add(docForRegression(regressionIndex, 0.3, 0.3))  // #2
+            .add(docForRegression(regressionIndex, 0.3, 0.4))  // #3
+            .add(docForRegression(regressionIndex, 0.3, 0.7))  // #4
+            .add(docForRegression(regressionIndex, 0.5, 0.2))  // #5
+            .add(docForRegression(regressionIndex, 0.5, 0.3))  // #6
+            .add(docForRegression(regressionIndex, 0.5, 0.4))  // #7
+            .add(docForRegression(regressionIndex, 0.5, 0.8))  // #8
+            .add(docForRegression(regressionIndex, 0.5, 0.9));  // #9
+        highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
+
+        evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression));
+
+        evaluateDataFrameResponse =
+            execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
+        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
+
+        MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
+        assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
+        assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
     }
 
     private static XContentBuilder defaultMappingForTest() throws IOException {
@@ -1615,6 +1644,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
     }
 
+    private static final String actualRegression = "regression_actual";
+    private static final String probabilityRegression = "regression_prob";
+
+    private static XContentBuilder mappingForRegression() throws IOException {
+        return XContentFactory.jsonBuilder().startObject()
+            .startObject("properties")
+            .startObject(actualRegression)
+            .field("type", "double")
+            .endObject()
+            .startObject(probabilityRegression)
+            .field("type", "double")
+            .endObject()
+            .endObject()
+            .endObject();
+    }
+
+    private static IndexRequest docForRegression(String indexName, double act, double p) {
+        return new IndexRequest()
+            .index(indexName)
+            .source(XContentType.JSON, actualRegression, act, probabilityRegression, p);
+    }
+
     private void createIndex(String indexName, XContentBuilder mapping) throws IOException {
         highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT);
     }

+ 11 - 7
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -60,6 +60,8 @@ import org.elasticsearch.client.indexlifecycle.ShrinkAction;
 import org.elasticsearch.client.indexlifecycle.UnfollowAction;
 import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
 import org.elasticsearch.client.ml.dataframe.OutlierDetection;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
@@ -674,7 +676,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(31, namedXContents.size());
+        assertEquals(34, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -712,12 +714,14 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(OutlierDetection.NAME.getPreferredName()));
         assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class));
         assertTrue(names.contains(TimeSyncConfig.NAME));
-        assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
-        assertThat(names, hasItems(BinarySoftClassification.NAME));
-        assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
-        assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME));
-        assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
-        assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME));
+        assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
+        assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME));
+        assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
+        assertThat(names,
+            hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
+        assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
+        assertThat(names,
+            hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME));
     }
 
     public void testApiNamingConventions() throws Exception {

+ 4 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java

@@ -20,6 +20,7 @@ package org.elasticsearch.client.ml;
 
 import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetricResultTests;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractXContentTestCase;
@@ -45,6 +46,9 @@ public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<Eva
         if (randomBoolean()) {
             metrics.add(ConfusionMatrixMetricResultTests.randomResult());
         }
+        if (randomBoolean()) {
+            metrics.add(MeanSquaredErrorMetricResultTests.randomResult());
+        }
         return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics);
     }
 

+ 53 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricResultTests.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class MeanSquaredErrorMetricResultTests extends AbstractXContentTestCase<MeanSquaredErrorMetric.Result> {
+
+    public static MeanSquaredErrorMetric.Result randomResult() {
+        return new MeanSquaredErrorMetric.Result(randomDouble());
+    }
+
+    @Override
+    protected MeanSquaredErrorMetric.Result createTestInstance() {
+        return randomResult();
+    }
+
+    @Override
+    protected MeanSquaredErrorMetric.Result doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredErrorMetric.Result.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+}

+ 49 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetricTests.java

@@ -0,0 +1,49 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class MeanSquaredErrorMetricTests extends AbstractXContentTestCase<MeanSquaredErrorMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected MeanSquaredErrorMetric createTestInstance() {
+        return new MeanSquaredErrorMetric();
+    }
+
+    @Override
+    protected MeanSquaredErrorMetric doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredErrorMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 59 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java

@@ -0,0 +1,59 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.function.Predicate;
+
+public class RegressionTests extends AbstractXContentTestCase<Regression> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected Regression createTestInstance() {
+        return randomBoolean() ?
+            new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10)) :
+            new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredErrorMetric()));
+    }
+
+    @Override
+    protected Regression doParseInstance(XContentParser parser) throws IOException {
+        return Regression.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        // allow unknown fields in the root of the object only
+        return field -> !field.isEmpty();
+    }
+}

+ 14 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -8,6 +8,9 @@ package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.plugins.spi.NamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
@@ -28,6 +31,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         // Evaluations
         namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
             BinarySoftClassification::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
 
         // Soft classification metrics
         namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent));
@@ -36,6 +40,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
             ConfusionMatrix::fromXContent));
 
+        // Regression metrics
+        namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
+
         return namedXContent;
     }
 
@@ -45,6 +52,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
         // Evaluations
         namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
             BinarySoftClassification::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
 
         // Evaluation Metrics
         namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(),
@@ -55,6 +63,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             Recall::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
             ConfusionMatrix::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
+            MeanSquaredError.NAME.getPreferredName(),
+            MeanSquaredError::new));
 
         // Evaluation Metrics Results
         namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
@@ -63,6 +74,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             ScoreByThresholdResult::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
             ConfusionMatrix.Result::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+            MeanSquaredError.NAME.getPreferredName(),
+            MeanSquaredError.Result::new));
 
         return namedWriteables;
     }

+ 141 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java

@@ -0,0 +1,141 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.text.MessageFormat;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+
+/**
+ * Calculates the mean squared error between two known numerical fields.
+ *
+ * equation: mse = 1/n * Σ(y - y´)^2
+ */
+public class MeanSquaredError implements RegressionMetric {
+
+    public static final ParseField NAME = new ParseField("mean_squared_error");
+
+    private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
+    private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
+
+    private static String buildScript(Object...args) {
+        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
+    }
+
+    private static final ObjectParser<MeanSquaredError, Void> PARSER =
+        new ObjectParser<>("mean_squared_error", true, MeanSquaredError::new);
+
+    public static MeanSquaredError fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public MeanSquaredError(StreamInput in) {
+
+    }
+
+    public MeanSquaredError() {
+
+    }
+
+    @Override
+    public String getMetricName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public List<AggregationBuilder> aggs(String actualField, String predictedField) {
+        return Collections.singletonList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField))));
+    }
+
+    @Override
+    public EvaluationMetricResult evaluate(Aggregations aggs) {
+        NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
+        return value == null ? null : new Result(value.value());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        // create static hash code from name as there are currently no unique fields per class instance
+        return Objects.hashCode(NAME.getPreferredName());
+    }
+
+    public static class Result implements EvaluationMetricResult {
+
+        private static final String ERROR = "error";
+        private final double error;
+
+        public Result(double error) {
+            this.error = error;
+        }
+
+        public Result(StreamInput in) throws IOException {
+            this.error = in.readDouble();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public String getName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeDouble(error);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ERROR, error);
+            builder.endObject();
+            return builder;
+        }
+    }
+}

+ 171 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java

@@ -0,0 +1,171 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Evaluation of regression results.
+ *
+ * Runs the configured {@link RegressionMetric}s over every document that has both
+ * the actual and the predicted field.
+ */
+public class Regression implements Evaluation {
+
+    public static final ParseField NAME = new ParseField("regression");
+
+    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
+    private static final ParseField PREDICTED_FIELD = new ParseField("predicted_field");
+    private static final ParseField METRICS = new ParseField("metrics");
+
+    // The cast of a[2] is unchecked; the value comes from the METRICS named-object declaration below
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Regression, Void> PARSER = new ConstructingObjectParser<>(
+        NAME.getPreferredName(), a -> new Regression((String) a[0], (String) a[1], (List<RegressionMetric>) a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD);
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_FIELD);
+        // Metrics are optional: when absent the constructor receives null and falls back to the defaults
+        PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(),
+            (p, c, n) -> p.namedObject(RegressionMetric.class, n, c), METRICS);
+    }
+
+    public static Regression fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * The field containing the actual value
+     * The value of this field is assumed to be numeric
+     */
+    private final String actualField;
+
+    /**
+     * The field containing the predicted value
+     * The value of this field is assumed to be numeric
+     */
+    private final String predictedField;
+
+    /**
+     * The list of metrics to calculate
+     */
+    private final List<RegressionMetric> metrics;
+
+    /**
+     * @param actualField the field containing the actual value, must not be {@code null}
+     * @param predictedField the field containing the predicted value, must not be {@code null}
+     * @param metrics the metrics to calculate; {@code null} selects the default metrics,
+     *                an empty list is rejected with a bad request exception.
+     *                The given list is copied and never modified.
+     */
+    public Regression(String actualField, String predictedField, @Nullable List<RegressionMetric> metrics) {
+        this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD);
+        this.predictedField = ExceptionsHelper.requireNonNull(predictedField, PREDICTED_FIELD);
+        this.metrics = initMetrics(metrics);
+    }
+
+    /**
+     * Stream constructor, reads the fields in the order {@link #writeTo(StreamOutput)} writes them.
+     */
+    public Regression(StreamInput in) throws IOException {
+        this.actualField = in.readString();
+        this.predictedField = in.readString();
+        this.metrics = in.readNamedWriteableList(RegressionMetric.class);
+    }
+
+    /**
+     * Resolves the metrics to use: the defaults when none were given, otherwise a copy of the given ones,
+     * in both cases sorted by metric name.
+     */
+    private static List<RegressionMetric> initMetrics(@Nullable List<RegressionMetric> parsedMetrics) {
+        // Sort a private copy: sorting the argument in place would reorder the caller's list
+        // and fail with UnsupportedOperationException if that list is unmodifiable
+        List<RegressionMetric> metrics = parsedMetrics == null ? defaultMetrics() : new ArrayList<>(parsedMetrics);
+        if (metrics.isEmpty()) {
+            throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
+        }
+        Collections.sort(metrics, Comparator.comparing(RegressionMetric::getMetricName));
+        return metrics;
+    }
+
+    /**
+     * The metrics used when the request does not specify any: currently only {@link MeanSquaredError}.
+     */
+    private static List<RegressionMetric> defaultMetrics() {
+        List<RegressionMetric> defaultMetrics = new ArrayList<>(1);
+        defaultMetrics.add(new MeanSquaredError());
+        return defaultMetrics;
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    /**
+     * Builds a search that returns no hits ({@code size(0)}), is restricted to documents in which both
+     * fields exist, and carries the aggregations of every configured metric.
+     */
+    @Override
+    public SearchSourceBuilder buildSearch() {
+        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
+            .filter(QueryBuilders.existsQuery(actualField))
+            .filter(QueryBuilders.existsQuery(predictedField));
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
+        for (RegressionMetric metric : metrics) {
+            List<AggregationBuilder> aggs = metric.aggs(actualField, predictedField);
+            aggs.forEach(searchSourceBuilder::aggregation);
+        }
+        return searchSourceBuilder;
+    }
+
+    /**
+     * Lets every metric compute its result from the aggregations of the search response
+     * and hands the results, in metric order, to the listener.
+     */
+    @Override
+    public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
+        List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
+        for (RegressionMetric metric : metrics) {
+            // NOTE(review): MeanSquaredError#evaluate returns null when its aggregation is absent,
+            // so this list may contain null entries - confirm that consumers of the results tolerate that
+            results.add(metric.evaluate(searchResponse.getAggregations()));
+        }
+        listener.onResponse(results);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(actualField);
+        out.writeString(predictedField);
+        out.writeNamedWriteableList(metrics);
+    }
+
+    /**
+     * Renders the evaluation; the metrics are written as an object keyed by each metric's writeable name.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
+        builder.field(PREDICTED_FIELD.getPreferredName(), predictedField);
+
+        builder.startObject(METRICS.getPreferredName());
+        for (RegressionMetric metric : metrics) {
+            builder.field(metric.getWriteableName(), metric);
+        }
+        builder.endObject();
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Regression that = (Regression) o;
+        return Objects.equals(that.actualField, this.actualField)
+            && Objects.equals(that.predictedField, this.predictedField)
+            && Objects.equals(that.metrics, this.metrics);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(actualField, predictedField, metrics);
+    }
+}

+ 37 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionMetric.java

@@ -0,0 +1,37 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.util.List;
+
+/**
+ * A metric that can be calculated as part of a regression evaluation.
+ * A metric contributes aggregations to the evaluation search and afterwards
+ * derives its result from the aggregations of the search response.
+ */
+public interface RegressionMetric extends ToXContentObject, NamedWriteable {
+
+    /**
+     * Returns the name of the metric (which may differ from the writeable name)
+     */
+    String getMetricName();
+
+    /**
+     * Builds the aggregations that collect the data required to compute the metric
+     * @param actualField the field that stores the actual value
+     * @param predictedField the field that stores the predicted value
+     * @return the aggregations required to compute the metric
+     */
+    List<AggregationBuilder> aggs(String actualField, String predictedField);
+
+    /**
+     * Calculates the metric result
+     * @param aggs the aggregations
+     * @return the metric result
+     */
+    EvaluationMetricResult evaluate(Aggregations aggs);
+}

+ 76 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquaredError> {
+
+    /**
+     * The metric has no per-instance state, so a "random" instance is simply a new one.
+     */
+    public static MeanSquaredError createRandom() {
+        return new MeanSquaredError();
+    }
+
+    /**
+     * Creates a mocked single-value metric aggregation that reports the given name and value.
+     */
+    private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
+        NumericMetricsAggregation.SingleValue mockedAgg = mock(NumericMetricsAggregation.SingleValue.class);
+        when(mockedAgg.getName()).thenReturn(name);
+        when(mockedAgg.value()).thenReturn(value);
+        return mockedAgg;
+    }
+
+    @Override
+    protected MeanSquaredError createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected MeanSquaredError doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredError.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<MeanSquaredError> instanceReader() {
+        return MeanSquaredError::new;
+    }
+
+    // The metric must pick its own aggregation out of the response and ignore unrelated ones
+    public void testEvaluate() {
+        NumericMetricsAggregation.SingleValue mseAgg = createSingleMetricAgg("regression_mean_squared_error", 0.8123);
+        NumericMetricsAggregation.SingleValue unrelatedAgg = createSingleMetricAgg("some_other_single_metric_agg", 0.2377);
+        Aggregations aggregations = new Aggregations(Arrays.asList(mseAgg, unrelatedAgg));
+
+        EvaluationMetricResult result = new MeanSquaredError().evaluate(aggregations);
+
+        assertThat(Strings.toString(result), equalTo("{\"error\":0.8123}"));
+    }
+
+    // Without the metric's own aggregation in the response there is no result at all
+    public void testEvaluate_GivenMissingAggs() {
+        NumericMetricsAggregation.SingleValue unrelatedAgg = createSingleMetricAgg("some_other_single_metric_agg", 0.2377);
+        Aggregations aggregations = new Aggregations(Collections.singletonList(unrelatedAgg));
+
+        EvaluationMetricResult result = new MeanSquaredError().evaluate(aggregations);
+
+        assertThat(result, is(nullValue()));
+    }
+}

+ 59 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java

@@ -0,0 +1,59 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class RegressionTests extends AbstractSerializingTestCase<Regression> {
+
+    /**
+     * Creates an evaluation over two random field names which, at random, either lists
+     * the mean squared error metric explicitly or passes no metrics at all.
+     */
+    public static Regression createRandom() {
+        String actualField = randomAlphaOfLength(10);
+        String predictedField = randomAlphaOfLength(10);
+        List<RegressionMetric> metrics = randomBoolean() ? null : Collections.singletonList(MeanSquaredErrorTests.createRandom());
+        return new Regression(actualField, predictedField, metrics);
+    }
+
+    // Both registries come from the ML provider so that the named metrics can be resolved
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
+        return new NamedXContentRegistry(provider.getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        MlEvaluationNamedXContentProvider provider = new MlEvaluationNamedXContentProvider();
+        return new NamedWriteableRegistry(provider.getNamedWriteables());
+    }
+
+    @Override
+    protected Regression createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Regression doParseInstance(XContentParser parser) throws IOException {
+        return Regression.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<Regression> instanceReader() {
+        return Regression::new;
+    }
+
+    // An explicitly empty metrics list is an error, it does not fall back to the defaults
+    public void testConstructor_GivenEmptyMetrics() {
+        List<RegressionMetric> noMetrics = Collections.emptyList();
+        ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", "bar", noMetrics));
+        assertThat(exception.getMessage(), equalTo("[regression] must have one or more metrics"));
+    }
+}

+ 4 - 3
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -75,9 +75,9 @@ integTest.runner  {
     'ml/evaluate_data_frame/Test given missing index',
     'ml/evaluate_data_frame/Test given index does not exist',
     'ml/evaluate_data_frame/Test given missing evaluation',
-    'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always true',
-    'ml/evaluate_data_frame/Test binary_soft_classifition auc_roc given actual_field is always false',
-    'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with emtpy metrics',
+    'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always true',
+    'ml/evaluate_data_frame/Test binary_soft_classification auc_roc given actual_field is always false',
+    'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with empty metrics',
     'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field',
     'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field',
     'ml/evaluate_data_frame/Test binary_soft_classification given precision with threshold less than zero',
@@ -86,6 +86,7 @@ integTest.runner  {
     'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds',
     'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
     'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
+    'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
     'ml/delete_job_force/Test cannot force delete a non-existent job',
     'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
     'ml/delete_model_snapshot/Test delete snapshot missing job_id',

+ 75 - 10
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -8,6 +8,8 @@ setup:
             "is_outlier": false,
             "is_outlier_int": 0,
             "outlier_score": 0.0,
+            "regression_field_act": 10.9,
+            "regression_field_pred": 10.9,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -20,6 +22,8 @@ setup:
             "is_outlier": false,
             "is_outlier_int": 0,
             "outlier_score": 0.2,
+            "regression_field_act": 12.0,
+            "regression_field_pred": 9.9,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -32,6 +36,8 @@ setup:
             "is_outlier": false,
             "is_outlier_int": 0,
             "outlier_score": 0.3,
+            "regression_field_act": 20.9,
+            "regression_field_pred": 5.9,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -44,6 +50,8 @@ setup:
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.3,
+            "regression_field_act": 11.9,
+            "regression_field_pred": 11.9,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -56,6 +64,8 @@ setup:
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.4,
+            "regression_field_act": 42.9,
+            "regression_field_pred": 42.9,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -68,6 +78,8 @@ setup:
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.5,
+            "regression_field_act": 0.42,
+            "regression_field_pred": 0.42,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -80,6 +92,8 @@ setup:
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.9,
+            "regression_field_act": 1.1235813,
+            "regression_field_pred": 1.12358,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -92,6 +106,8 @@ setup:
             "is_outlier": true,
             "is_outlier_int": 1,
             "outlier_score": 0.95,
+            "regression_field_act": -5.20,
+            "regression_field_pred": -5.1,
             "all_true_field": true,
             "all_false_field": false
           }
@@ -109,7 +125,7 @@ setup:
       indices.refresh: {}
 
 ---
-"Test binary_soft_classifition auc_roc":
+"Test binary_soft_classification auc_roc":
   - do:
       ml.evaluate_data_frame:
         body:  >
@@ -129,7 +145,7 @@ setup:
   - is_false: binary_soft_classification.auc_roc.curve
 
 ---
-"Test binary_soft_classifition auc_roc given actual_field is int":
+"Test binary_soft_classification auc_roc given actual_field is int":
   - do:
       ml.evaluate_data_frame:
         body:  >
@@ -149,7 +165,7 @@ setup:
   - is_false: binary_soft_classification.auc_roc.curve
 
 ---
-"Test binary_soft_classifition auc_roc include curve":
+"Test binary_soft_classification auc_roc include curve":
   - do:
       ml.evaluate_data_frame:
         body:  >
@@ -169,7 +185,7 @@ setup:
   - is_true: binary_soft_classification.auc_roc.curve
 
 ---
-"Test binary_soft_classifition auc_roc given actual_field is always true":
+"Test binary_soft_classification auc_roc given actual_field is always true":
   - do:
       catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/
       ml.evaluate_data_frame:
@@ -188,7 +204,7 @@ setup:
           }
 
 ---
-"Test binary_soft_classifition auc_roc given actual_field is always false":
+"Test binary_soft_classification auc_roc given actual_field is always false":
   - do:
       catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/
       ml.evaluate_data_frame:
@@ -207,7 +223,7 @@ setup:
           }
 
 ---
-"Test binary_soft_classifition precision":
+"Test binary_soft_classification precision":
   - do:
       ml.evaluate_data_frame:
         body:  >
@@ -230,7 +246,7 @@ setup:
           '0.5': 1.0
 
 ---
-"Test binary_soft_classifition recall":
+"Test binary_soft_classification recall":
   - do:
       ml.evaluate_data_frame:
         body:  >
@@ -254,7 +270,7 @@ setup:
           '0.5': 0.6
 
 ---
-"Test binary_soft_classifition confusion_matrix":
+"Test binary_soft_classification confusion_matrix":
   - do:
       ml.evaluate_data_frame:
         body:  >
@@ -290,7 +306,7 @@ setup:
             fn: 2
 
 ---
-"Test binary_soft_classifition default metrics":
+"Test binary_soft_classification default metrics":
   - do:
       ml.evaluate_data_frame:
         body:  >
@@ -356,7 +372,7 @@ setup:
           }
 
 ---
-"Test binary_soft_classification given evaluation with emtpy metrics":
+"Test binary_soft_classification given evaluation with empty metrics":
   - do:
       catch: /\[binary_soft_classification\] must have one or more metrics/
       ml.evaluate_data_frame:
@@ -518,3 +534,52 @@ setup:
               }
             }
           }
+---
+# An explicitly empty "metrics" object must be rejected with the regression evaluation's
+# "must have one or more metrics" error.
+"Test regression given evaluation with empty metrics":
+  - do:
+      catch: /\[regression\] must have one or more metrics/
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "regression": {
+                "actual_field": "regression_field_act",
+                "predicted_field": "regression_field_pred",
+                "metrics": { }
+              }
+            }
+          }
+---
+# Requests mean_squared_error explicitly. The expected value is the average of
+# (regression_field_act - regression_field_pred)^2 over the documents indexed in setup.
+"Test regression mean_squared_error":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "regression": {
+                "actual_field": "regression_field_act",
+                "predicted_field": "regression_field_pred",
+                "metrics": { "mean_squared_error": {} }
+              }
+            }
+          }
+
+  - match: { regression.mean_squared_error.error: 28.67749840974834 }
+---
+# "metrics" is omitted entirely, so the evaluation falls back to its default metric,
+# mean_squared_error, and must report the same value as an explicit request for it.
+"Test regression with null metrics":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "regression": {
+                "actual_field": "regression_field_act",
+                "predicted_field": "regression_field_pred"
+              }
+            }
+          }
+
+  - match: { regression.mean_squared_error.error: 28.67749840974834 }