Browse Source

[ML] better handle empty results when evaluating regression (#45745)

* [ML] better handle empty results when evaluating regression

* adding new failure test to ml_security black list

* fixing equality check for regression results
Benjamin Trent 6 years ago
parent
commit
2202d00aac

+ 14 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java

@@ -69,7 +69,7 @@ public class MeanSquaredError implements RegressionMetric {
     @Override
     public EvaluationMetricResult evaluate(Aggregations aggs) {
         NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
-        return value == null ? null : new Result(value.value());
+        return value == null ? new Result(0.0) : new Result(value.value());
     }
 
     @Override
@@ -137,5 +137,18 @@ public class MeanSquaredError implements RegressionMetric {
             builder.endObject();
             return builder;
         }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result other = (Result)o;
+            return error == other.error;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hashCode(error);
+        }
     }
 }

+ 14 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

@@ -79,7 +79,7 @@ public class RSquared implements RegressionMetric {
         ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
         // extendedStats.getVariance() is the statistical sumOfSquares divided by count
         return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
-            null :
+            new Result(0.0) :
             new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
     }
 
@@ -148,5 +148,18 @@ public class RSquared implements RegressionMetric {
             builder.endObject();
             return builder;
         }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result other = (Result)o;
+            return value == other.value;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hashCode(value);
+        }
     }
 }

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java

@@ -121,6 +121,12 @@ public class Regression implements Evaluation {
     @Override
     public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
         List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
+        if (searchResponse.getHits().getTotalHits().value == 0) {
+            listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
+                actualField,
+                predictedField));
+            return;
+        }
         for (RegressionMetric metric : metrics) {
             results.add(metric.evaluate(searchResponse.getAggregations()));
         }

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java

@@ -81,7 +81,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
         for (int i = 0; i < recalls.length; i++) {
             double threshold = thresholds[i];
             Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
-            Filter fnAgg =aggs.get(aggName(classInfo, threshold, Condition.FN));
+            Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
             long tp = tpAgg.getDocCount();
             long fn = fnAgg.getDocCount();
             recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);

+ 1 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java

@@ -17,9 +17,7 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 
-import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -64,7 +62,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
 
         MeanSquaredError mse = new MeanSquaredError();
         EvaluationMetricResult result = mse.evaluate(aggs);
-        assertThat(result, is(nullValue()));
+        assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
     }
 
     private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {

+ 5 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java

@@ -18,9 +18,7 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 
-import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -70,17 +68,18 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
 
         RSquared rSquared = new RSquared();
         EvaluationMetricResult result = rSquared.evaluate(aggs);
-        assertThat(result, is(nullValue()));
+        assertThat(result, equalTo(new RSquared.Result(0.0)));
     }
 
     public void testEvaluate_GivenMissingAggs() {
+        EvaluationMetricResult zeroResult = new RSquared.Result(0.0);
         Aggregations aggs = new Aggregations(Collections.singletonList(
             createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
         ));
 
         RSquared rSquared = new RSquared();
         EvaluationMetricResult result = rSquared.evaluate(aggs);
-        assertThat(result, is(nullValue()));
+        assertThat(result, equalTo(zeroResult));
 
         aggs = new Aggregations(Arrays.asList(
             createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
@@ -88,7 +87,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
         ));
 
         result = rSquared.evaluate(aggs);
-        assertThat(result, is(nullValue()));
+        assertThat(result, equalTo(zeroResult));
 
         aggs = new Aggregations(Arrays.asList(
             createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
@@ -96,7 +95,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
         ));
 
         result = rSquared.evaluate(aggs);
-        assertThat(result, is(nullValue()));
+        assertThat(result, equalTo(zeroResult));
     }
 
     private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {

+ 2 - 0
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -89,6 +89,8 @@ integTest.runner  {
     'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
     'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
     'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
+    'ml/evaluate_data_frame/Test regression given missing actual_field',
+    'ml/evaluate_data_frame/Test regression given missing predicted_field',
     'ml/delete_job_force/Test cannot force delete a non-existent job',
     'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
     'ml/delete_model_snapshot/Test delete snapshot missing job_id',

+ 31 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

@@ -602,3 +602,34 @@ setup:
 
   - match: { regression.mean_squared_error.error: 28.67749840974834 }
   - match: { regression.r_squared.value: 0.8551031778603486 }
+---
+"Test regression given missing actual_field":
+  - do:
+      catch: /No documents found containing both \[missing, regression_field_pred\] fields/
+      ml.evaluate_data_frame:
+        body:  >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "regression": {
+                "actual_field": "missing",
+                "predicted_field": "regression_field_pred"
+              }
+            }
+          }
+
+---
+"Test regression given missing predicted_field":
+  - do:
+      catch: /No documents found containing both \[regression_field_act, missing\] fields/
+      ml.evaluate_data_frame:
+        body:  >
+          {
+            "index": "utopia",
+            "evaluation": {
+              "regression": {
+                "actual_field": "regression_field_act",
+                "predicted_field": "missing"
+              }
+            }
+          }