Преглед на файлове

Implement MSLE (MeanSquaredLogarithmicError) evaluation metric for regression analysis (#58684)

Przemysław Witek преди 5 години
родител
ревизия
dfa06240fc
променени са 20 файла, в които са добавени 763 реда и са изтрити 23 реда
  1. 9 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  2. 2 5
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java
  3. 142 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java
  4. 2 5
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java
  5. 11 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  6. 6 3
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  7. 10 3
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  8. 53 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricResultTests.java
  9. 49 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricTests.java
  10. 3 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java
  11. 6 3
      docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc
  12. 4 0
      docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc
  13. 10 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
  14. 7 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java
  15. 195 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java
  16. 7 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java
  17. 2 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java
  18. 68 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorTests.java
  19. 155 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java
  20. 22 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+ 9 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -22,6 +22,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyM
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -97,6 +98,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
                 EvaluationMetric.class,
                 new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
                 MeanSquaredErrorMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class,
+                new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
+                MeanSquaredLogarithmicErrorMetric::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.class,
                 new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),
@@ -140,6 +145,10 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
                 EvaluationMetric.Result.class,
                 new ParseField(registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME)),
                 MeanSquaredErrorMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class,
+                new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME)),
+                MeanSquaredLogarithmicErrorMetric.Result::fromXContent),
             new NamedXContentRegistry.Entry(
                 EvaluationMetric.Result.class,
                 new ParseField(registeredMetricName(Regression.NAME, RSquaredMetric.NAME)),

+ 2 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredErrorMetric.java

@@ -40,16 +40,13 @@ public class MeanSquaredErrorMetric implements EvaluationMetric {
 
     public static final String NAME = "mean_squared_error";
 
-    private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER =
-        new ObjectParser<>("mean_squared_error", true, MeanSquaredErrorMetric::new);
+    private static final ObjectParser<MeanSquaredErrorMetric, Void> PARSER = new ObjectParser<>(NAME, true, MeanSquaredErrorMetric::new);
 
     public static MeanSquaredErrorMetric fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    public MeanSquaredErrorMetric() {
-
-    }
+    public MeanSquaredErrorMetric() {}
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {

+ 142 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetric.java

@@ -0,0 +1,142 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Calculates the mean squared error between two known numerical fields.
+ *
+ * equation: msle = 1/n * Σ(log(y + offset) - log(y´ + offset))^2
+ * where offset is used to make sure the argument to log function is always positive
+ */
+public class MeanSquaredLogarithmicErrorMetric implements EvaluationMetric {
+
+    public static final String NAME = "mean_squared_logarithmic_error";
+
+    public static final ParseField OFFSET = new ParseField("offset");
+
+    private static final ConstructingObjectParser<MeanSquaredLogarithmicErrorMetric, Void> PARSER =
+        new ConstructingObjectParser<>(NAME, true, args -> new MeanSquaredLogarithmicErrorMetric((Double) args[0]));
+
+    static {
+        PARSER.declareDouble(optionalConstructorArg(), OFFSET);
+    }
+
+    public static MeanSquaredLogarithmicErrorMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final Double offset;
+
+    public MeanSquaredLogarithmicErrorMetric(@Nullable Double offset) {
+        this.offset = offset;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        if (offset != null) {
+            builder.field(OFFSET.getPreferredName(), offset);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MeanSquaredLogarithmicErrorMetric that = (MeanSquaredLogarithmicErrorMetric) o;
+        return Objects.equals(this.offset, that.offset);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(offset);
+    }
+
+    public static class Result implements EvaluationMetric.Result  {
+
+        public static final ParseField ERROR = new ParseField("error");
+        private final double error;
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("mean_squared_error_result", true, args -> new Result((double) args[0]));
+
+        static {
+            PARSER.declareDouble(constructorArg(), ERROR);
+        }
+
+        public Result(double error) {
+            this.error = error;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+            builder.startObject();
+            builder.field(ERROR.getPreferredName(), error);
+            builder.endObject();
+            return builder;
+        }
+
+        public double getError() {
+            return error;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(that.error, this.error);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(error);
+        }
+    }
+}

+ 2 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java

@@ -42,16 +42,13 @@ public class RSquaredMetric implements EvaluationMetric {
 
     public static final String NAME = "r_squared";
 
-    private static final ObjectParser<RSquaredMetric, Void> PARSER =
-        new ObjectParser<>("r_squared", true, RSquaredMetric::new);
+    private static final ObjectParser<RSquaredMetric, Void> PARSER = new ObjectParser<>(NAME, true, RSquaredMetric::new);
 
     public static RSquaredMetric fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    public RSquaredMetric() {
-
-    }
+    public RSquaredMetric() {}
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {

+ 11 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -142,6 +142,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyM
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -1852,17 +1853,25 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             new EvaluateDataFrameRequest(
                 regressionIndex,
                 null,
-                new Regression(actualRegression, predictedRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
+                new Regression(
+                    actualRegression,
+                    predictedRegression,
+                    new MeanSquaredErrorMetric(), new MeanSquaredLogarithmicErrorMetric(1.0), new RSquaredMetric()));
 
         EvaluateDataFrameResponse evaluateDataFrameResponse =
             execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
         assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
-        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
+        assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(3));
 
         MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME);
         assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME));
         assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9));
 
+        MeanSquaredLogarithmicErrorMetric.Result msleResult =
+            evaluateDataFrameResponse.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME);
+        assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicErrorMetric.NAME));
+        assertThat(msleResult.getError(), closeTo(0.02759231770210426, 1e-9));
+
         RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME);
         assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME));
         assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));

+ 6 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -61,6 +61,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyM
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@@ -701,7 +702,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(64, namedXContents.size());
+        assertEquals(66, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -748,7 +749,7 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(TimeSyncConfig.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class));
         assertThat(names, hasItems(BinarySoftClassification.NAME, Classification.NAME, Regression.NAME));
-        assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
+        assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class));
         assertThat(names,
             hasItems(
                 registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@@ -762,8 +763,9 @@ public class RestHighLevelClientTests extends ESTestCase {
                     Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
                 registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
                 registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
+                registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
                 registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
-        assertEquals(Integer.valueOf(10), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
+        assertEquals(Integer.valueOf(11), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class));
         assertThat(names,
             hasItems(
                 registeredMetricName(BinarySoftClassification.NAME, AucRocMetric.NAME),
@@ -777,6 +779,7 @@ public class RestHighLevelClientTests extends ESTestCase {
                     Classification.NAME, org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME),
                 registeredMetricName(Classification.NAME, MulticlassConfusionMatrixMetric.NAME),
                 registeredMetricName(Regression.NAME, MeanSquaredErrorMetric.NAME),
+                registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
                 registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
         assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
         assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));

+ 10 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -161,6 +161,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.classification.Multiclas
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.ActualClass;
 import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric.PredictedClass;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
@@ -3570,7 +3571,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                     "predicted_value", // <3>
                     // Evaluation metrics // <4>
                     new MeanSquaredErrorMetric(), // <5>
-                    new RSquaredMetric()); // <6>
+                    new MeanSquaredLogarithmicErrorMetric(1.0), // <6>
+                    new RSquaredMetric()); // <7>
             // end::evaluate-data-frame-evaluation-regression
 
             EvaluateDataFrameRequest request = new EvaluateDataFrameRequest(indexName, null, evaluation);
@@ -3580,11 +3582,16 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             MeanSquaredErrorMetric.Result meanSquaredErrorResult = response.getMetricByName(MeanSquaredErrorMetric.NAME); // <1>
             double meanSquaredError = meanSquaredErrorResult.getError(); // <2>
 
-            RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <3>
-            double rSquared = rSquaredResult.getValue(); // <4>
+            MeanSquaredLogarithmicErrorMetric.Result meanSquaredLogarithmicErrorResult =
+                response.getMetricByName(MeanSquaredLogarithmicErrorMetric.NAME); // <3>
+            double meanSquaredLogarithmicError = meanSquaredLogarithmicErrorResult.getError(); // <4>
+
+            RSquaredMetric.Result rSquaredResult = response.getMetricByName(RSquaredMetric.NAME); // <5>
+            double rSquared = rSquaredResult.getValue(); // <6>
             // end::evaluate-data-frame-results-regression
 
             assertThat(meanSquaredError, closeTo(0.021, 1e-3));
+            assertThat(meanSquaredLogarithmicError, closeTo(0.003, 1e-3));
             assertThat(rSquared, closeTo(0.941, 1e-3));
         }
     }

+ 53 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricResultTests.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class MeanSquaredLogarithmicErrorMetricResultTests extends AbstractXContentTestCase<MeanSquaredLogarithmicErrorMetric.Result> {
+
+    public static MeanSquaredLogarithmicErrorMetric.Result randomResult() {
+        return new MeanSquaredLogarithmicErrorMetric.Result(randomDouble());
+    }
+
+    @Override
+    protected MeanSquaredLogarithmicErrorMetric.Result createTestInstance() {
+        return randomResult();
+    }
+
+    @Override
+    protected MeanSquaredLogarithmicErrorMetric.Result doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredLogarithmicErrorMetric.Result.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+}

+ 49 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorMetricTests.java

@@ -0,0 +1,49 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class MeanSquaredLogarithmicErrorMetricTests extends AbstractXContentTestCase<MeanSquaredLogarithmicErrorMetric> {
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected MeanSquaredLogarithmicErrorMetric createTestInstance() {
+        return new MeanSquaredLogarithmicErrorMetric(randomBoolean() ? randomDouble() : null);
+    }
+
+    @Override
+    protected MeanSquaredLogarithmicErrorMetric doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredLogarithmicErrorMetric.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 3 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java

@@ -41,6 +41,9 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
         if (randomBoolean()) {
             metrics.add(new MeanSquaredErrorMetric());
         }
+        if (randomBoolean()) {
+            metrics.add(new MeanSquaredLogarithmicErrorMetricTests().createTestInstance());
+        }
         if (randomBoolean()) {
             metrics.add(new RSquaredMetric());
         }

+ 6 - 3
docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc

@@ -68,7 +68,8 @@ include-tagged::{doc-tests-file}[{api}-evaluation-regression]
 <3> Name of the field in the index. Its value denotes the predicted (as per some ML algorithm) value for the example.
 <4> The remaining parameters are the metrics to be calculated based on the two fields described above
 <5> https://en.wikipedia.org/wiki/Mean_squared_error[Mean squared error]
-<6> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
+<6> Mean squared logarithmic error
+<7> https://en.wikipedia.org/wiki/Coefficient_of_determination[R squared]
 
 include::../execution.asciidoc[]
 
@@ -123,5 +124,7 @@ include-tagged::{doc-tests-file}[{api}-results-regression]
 
 <1> Fetching mean squared error metric by name
 <2> Fetching the actual mean squared error value
-<3> Fetching R squared metric by name
-<4> Fetching the actual R squared value
+<3> Fetching mean squared logarithmic error metric by name
+<4> Fetching the actual mean squared logarithmic error value
+<5> Fetching R squared metric by name
+<6> Fetching the actual R squared value

+ 4 - 0
docs/reference/ml/df-analytics/apis/evaluate-dfanalytics.asciidoc

@@ -130,6 +130,10 @@ which outputs a prediction of values.
     (Optional, object) Average squared difference between the predicted values and the actual (`ground truth`) value.
     For more information, read https://en.wikipedia.org/wiki/Mean_squared_error[this wiki article].
 
+  `mean_squared_logarithmic_error`:::
+    (Optional, object) Average squared difference between the logarithm of the predicted values and the logarithm of the actual
+    (`ground truth`) value.
+
   `r_squared`:::
     (Optional, object) Proportion of the variance in the dependent variable that is predictable from the independent variables.
     For more information, read https://en.wikipedia.org/wiki/Coefficient_of_determination[this wiki article].

+ 10 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accur
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
@@ -95,6 +96,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             new NamedXContentRegistry.Entry(EvaluationMetric.class,
                 new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)),
                 MeanSquaredError::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class,
+                new ParseField(registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME)),
+                MeanSquaredLogarithmicError::fromXContent),
             new NamedXContentRegistry.Entry(EvaluationMetric.class,
                 new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
                 RSquared::fromXContent)
@@ -144,6 +148,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             new NamedWriteableRegistry.Entry(EvaluationMetric.class,
                 registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
                 MeanSquaredError::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetric.class,
+                registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
+                MeanSquaredLogarithmicError::new),
             new NamedWriteableRegistry.Entry(EvaluationMetric.class,
                 registeredMetricName(Regression.NAME, RSquared.NAME),
                 RSquared::new),
@@ -175,6 +182,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider
             new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
                 registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
                 MeanSquaredError.Result::new),
+            new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
+                registeredMetricName(Regression.NAME, MeanSquaredLogarithmicError.NAME),
+                MeanSquaredLogarithmicError.Result::new),
             new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
                 registeredMetricName(Regression.NAME, RSquared.NAME),
                 RSquared.Result::new)

+ 7 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java

@@ -40,7 +40,9 @@ public class MeanSquaredError implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("mean_squared_error");
 
-    private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
+    private static final String PAINLESS_TEMPLATE =
+        "def diff = doc[''{0}''].value - doc[''{1}''].value;" +
+        "return diff * diff;";
     private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
 
     private static String buildScript(Object...args) {
@@ -141,6 +143,10 @@ public class MeanSquaredError implements EvaluationMetric {
             return NAME.getPreferredName();
         }
 
+        public double getError() {
+            return error;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeDouble(error);

+ 195 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicError.java

@@ -0,0 +1,195 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
+
+import java.io.IOException;
+import java.text.MessageFormat;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
+
+/**
+ * Calculates the mean squared error between two known numerical fields.
+ *
+ * equation: msle = 1/n * Σ(log(y + offset) - log(y´ + offset))^2
+ * where offset is used to make sure the argument to log function is always positive
+ */
+public class MeanSquaredLogarithmicError implements EvaluationMetric {
+
+    public static final ParseField NAME = new ParseField("mean_squared_logarithmic_error");
+
+    public static final ParseField OFFSET = new ParseField("offset");
+    private static final double DEFAULT_OFFSET = 1.0;
+
+    private static final String PAINLESS_TEMPLATE =
+        "def offset = {2};" +
+        "def diff = Math.log(doc[''{0}''].value + offset) - Math.log(doc[''{1}''].value + offset);" +
+        "return diff * diff;";
+    private static final String AGG_NAME = "regression_" + NAME.getPreferredName();
+
+    private static String buildScript(Object...args) {
+        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
+    }
+
+    private static final ConstructingObjectParser<MeanSquaredLogarithmicError, Void> PARSER =
+        new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MeanSquaredLogarithmicError((Double) args[0]));
+
+    static {
+        PARSER.declareDouble(optionalConstructorArg(), OFFSET);
+    }
+
+    public static MeanSquaredLogarithmicError fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final double offset;
+    private EvaluationMetricResult result;
+
+    public MeanSquaredLogarithmicError(StreamInput in) throws IOException {
+        this.offset = in.readDouble();
+    }
+
+    public MeanSquaredLogarithmicError(@Nullable Double offset) {
+        this.offset = offset != null ? offset : DEFAULT_OFFSET;
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters,
+                                                                                  String actualField,
+                                                                                  String predictedField) {
+        if (result != null) {
+            return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
+        }
+        return Tuple.tuple(
+            Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(new Script(buildScript(actualField, predictedField, offset)))),
+            Collections.emptyList());
+    }
+
+    @Override
+    public void process(Aggregations aggs) {
+        NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
+        result = value == null ? new Result(0.0) : new Result(value.value());
+    }
+
+    @Override
+    public Optional<EvaluationMetricResult> getResult() {
+        return Optional.ofNullable(result);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return registeredMetricName(Regression.NAME, NAME);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeDouble(offset);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(OFFSET.getPreferredName(), offset);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        MeanSquaredLogarithmicError that = (MeanSquaredLogarithmicError) o;
+        return this.offset == that.offset;
+    }
+
+    @Override
+    public int hashCode() {
+        return Double.hashCode(offset);
+    }
+
+    public static class Result implements EvaluationMetricResult {
+
+        private static final String ERROR = "error";
+        private final double error;
+
+        public Result(double error) {
+            this.error = error;
+        }
+
+        public Result(StreamInput in) throws IOException {
+            this.error = in.readDouble();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return registeredMetricName(Regression.NAME, NAME);
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME.getPreferredName();
+        }
+
+        public double getError() {
+            return error;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeDouble(error);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ERROR, error);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result other = (Result)o;
+            return error == other.error;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hashCode(error);
+        }
+    }
+}

+ 7 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

@@ -45,7 +45,9 @@ public class RSquared implements EvaluationMetric {
 
     public static final ParseField NAME = new ParseField("r_squared");
 
-    private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
+    private static final String PAINLESS_TEMPLATE =
+        "def diff = doc[''{0}''].value - doc[''{1}''].value;" +
+        "return diff * diff;";
     private static final String SS_RES = "residual_sum_of_squares";
 
     private static String buildScript(Object... args) {
@@ -156,6 +158,10 @@ public class RSquared implements EvaluationMetric {
             return NAME.getPreferredName();
         }
 
+        public double getValue() {
+            return value;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeDouble(value);

+ 2 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionResponseTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Multi
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.PrecisionResultTests;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.RecallResultTests;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
 
 import java.util.List;
@@ -37,6 +38,7 @@ public class EvaluateDataFrameActionResponseTests extends AbstractWireSerializin
                 RecallResultTests.createRandom(),
                 MulticlassConfusionMatrixResultTests.createRandom(),
                 new MeanSquaredError.Result(randomDouble()),
+                new MeanSquaredLogarithmicError.Result(randomDouble()),
                 new RSquared.Result(randomDouble()));
         return new Response(evaluationName, randomSubsetOf(metrics));
     }

+ 68 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredLogarithmicErrorTests.java

@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MockAggregations.mockSingleValue;
+import static org.hamcrest.Matchers.equalTo;
+
+public class MeanSquaredLogarithmicErrorTests extends AbstractSerializingTestCase<MeanSquaredLogarithmicError> {
+
+    @Override
+    protected MeanSquaredLogarithmicError doParseInstance(XContentParser parser) throws IOException {
+        return MeanSquaredLogarithmicError.fromXContent(parser);
+    }
+
+    @Override
+    protected MeanSquaredLogarithmicError createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<MeanSquaredLogarithmicError> instanceReader() {
+        return MeanSquaredLogarithmicError::new;
+    }
+
+    public static MeanSquaredLogarithmicError createRandom() {
+        return new MeanSquaredLogarithmicError(randomBoolean() ? randomDoubleBetween(0.0, 1000.0, false) : null);
+    }
+
+    public void testEvaluate() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            mockSingleValue("regression_mean_squared_logarithmic_error", 0.8123),
+            mockSingleValue("some_other_single_metric_agg", 0.2377)
+        ));
+
+        MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError((Double) null);
+        msle.process(aggs);
+
+        EvaluationMetricResult result = msle.getResult().get();
+        String expected = "{\"error\":0.8123}";
+        assertThat(Strings.toString(result), equalTo(expected));
+    }
+
+    public void testEvaluate_GivenMissingAggs() {
+        Aggregations aggs = new Aggregations(Collections.singletonList(
+            mockSingleValue("some_other_single_metric_agg", 0.2377)
+        ));
+
+        MeanSquaredLogarithmicError msle = new MeanSquaredLogarithmicError((Double) null);
+        msle.process(aggs);
+
+        EvaluationMetricResult result = msle.getResult().get();
+        assertThat(result, equalTo(new MeanSquaredLogarithmicError.Result(0.0)));
+    }
+}

+ 155 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java

@@ -0,0 +1,155 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.action.bulk.BulkRequestBuilder;
+import org.elasticsearch.action.bulk.BulkResponse;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicError;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.List;
+
+import static java.util.stream.Collectors.toList;
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class RegressionEvaluationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
+
+    private static final String HOUSES_DATA_INDEX = "test-evaluate-houses-index";
+
+    private static final String PRICE_FIELD = "price";
+    private static final String PRICE_PREDICTION_FIELD = "price_prediction";
+
+    @Before
+    public void setup() {
+        createHousesIndex(HOUSES_DATA_INDEX);
+        indexHousesData(HOUSES_DATA_INDEX);
+    }
+
+    @After
+    public void cleanup() {
+        cleanUp();
+    }
+
+    public void testEvaluate_DefaultMetrics() {
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            evaluateDataFrame(HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, null));
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
+        assertThat(
+            evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
+            contains(MeanSquaredError.NAME.getPreferredName(), RSquared.NAME.getPreferredName()));
+    }
+
+    public void testEvaluate_AllMetrics() {
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            evaluateDataFrame(
+                HOUSES_DATA_INDEX,
+                new Regression(
+                    PRICE_FIELD,
+                    PRICE_PREDICTION_FIELD,
+                    List.of(new MeanSquaredError(), new MeanSquaredLogarithmicError((Double) null), new RSquared())));
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
+        assertThat(
+            evaluateDataFrameResponse.getMetrics().stream().map(EvaluationMetricResult::getMetricName).collect(toList()),
+            contains(
+                MeanSquaredError.NAME.getPreferredName(),
+                MeanSquaredLogarithmicError.NAME.getPreferredName(),
+                RSquared.NAME.getPreferredName()));
+    }
+
+    public void testEvaluate_MeanSquaredError() {
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            evaluateDataFrame(HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, List.of(new MeanSquaredError())));
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
+
+        MeanSquaredError.Result mseResult = (MeanSquaredError.Result) evaluateDataFrameResponse.getMetrics().get(0);
+        assertThat(mseResult.getMetricName(), equalTo(MeanSquaredError.NAME.getPreferredName()));
+        assertThat(mseResult.getError(), equalTo(1000000.0));
+    }
+
+    public void testEvaluate_MeanSquaredLogarithmicError() {
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            evaluateDataFrame(
+                HOUSES_DATA_INDEX,
+                new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, List.of(new MeanSquaredLogarithmicError((Double) null))));
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
+
+        MeanSquaredLogarithmicError.Result msleResult = (MeanSquaredLogarithmicError.Result) evaluateDataFrameResponse.getMetrics().get(0);
+        assertThat(msleResult.getMetricName(), equalTo(MeanSquaredLogarithmicError.NAME.getPreferredName()));
+        assertThat(msleResult.getError(), closeTo(Math.pow(Math.log(1001), 2), 10E-6));
+    }
+
+    public void testEvaluate_RSquared() {
+        EvaluateDataFrameAction.Response evaluateDataFrameResponse =
+            evaluateDataFrame(HOUSES_DATA_INDEX, new Regression(PRICE_FIELD, PRICE_PREDICTION_FIELD, List.of(new RSquared())));
+
+        assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME.getPreferredName()));
+        assertThat(evaluateDataFrameResponse.getMetrics(), hasSize(1));
+
+        RSquared.Result rSquaredResult = (RSquared.Result) evaluateDataFrameResponse.getMetrics().get(0);
+        assertThat(rSquaredResult.getMetricName(), equalTo(RSquared.NAME.getPreferredName()));
+        assertThat(rSquaredResult.getValue(), equalTo(0.0));
+    }
+
+    private static void createHousesIndex(String indexName) {
+        client().admin().indices().prepareCreate(indexName)
+            .setMapping(
+                PRICE_FIELD, "type=double",
+                PRICE_PREDICTION_FIELD, "type=double")
+            .get();
+    }
+
+    private static void indexHousesData(String indexName) {
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        for (int i = 0; i < 100; i++) {
+            bulkRequestBuilder.add(
+                new IndexRequest(indexName)
+                    .source(
+                        PRICE_FIELD, 1000,
+                        PRICE_PREDICTION_FIELD, 0));
+        }
+        BulkResponse bulkResponse = bulkRequestBuilder.get();
+        if (bulkResponse.hasFailures()) {
+            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
+        }
+    }
+}

+ 22 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -847,6 +847,26 @@ setup:
           }
 
   - match: { regression.mean_squared_error.error: 28.67749840974834 }
+  - is_false: regression.mean_squared_logarithmic_error.value
+  - is_false: regression.r_squared.value
+---
+"Test regression mean_squared_logarithmic_error":
+  - do:
+      ml.evaluate_data_frame:
+        body: >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "regression": {
+                "actual_field": "regression_field_act",
+                "predicted_field": "regression_field_pred",
+                "metrics": { "mean_squared_logarithmic_error": { "offset": 6.0 } }
+              }
+            }
+          }
+
+  - match: { regression.mean_squared_logarithmic_error.error: 0.08680568028334916 }
+  - is_false: regression.mean_squared_error.value
   - is_false: regression.r_squared.value
 ---
 "Test regression r_squared":
@@ -865,6 +885,7 @@ setup:
           }
   - match: { regression.r_squared.value: 0.8551031778603486 }
   - is_false: regression.mean_squared_error
+  - is_false: regression.mean_squared_logarithmic_error.value
 ---
 "Test regression with null metrics":
   - do:
@@ -882,6 +903,7 @@ setup:
 
   - match: { regression.mean_squared_error.error: 28.67749840974834 }
   - match: { regression.r_squared.value: 0.8551031778603486 }
+  - is_false: regression.mean_squared_logarithmic_error.value
 ---
 "Test regression given missing actual_field":
   - do: