浏览代码

[ML] Allow training_percent to be any positive double up to one hundred (#61977)

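This changes the valid range of `training_percent` for regression and
classification from the closed interval [1, 100] to the half-open interval
(0, 100]: any value strictly greater than zero and at most 100 is now accepted.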

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dimitris Athanasiou 5 年之前
父节点
当前提交
b4fcb77e20

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -170,8 +170,8 @@ public class Classification implements DataFrameAnalysis {
         if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
         if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
             throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
             throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
         }
         }
-        if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
-            throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
+        if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
         }
         }
         this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
         this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
         this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
         this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -130,8 +130,8 @@ public class Regression implements DataFrameAnalysis {
                       @Nullable LossFunction lossFunction,
                       @Nullable LossFunction lossFunction,
                       @Nullable Double lossFunctionParameter,
                       @Nullable Double lossFunctionParameter,
                       @Nullable List<PreProcessor> featureProcessors) {
                       @Nullable List<PreProcessor> featureProcessors) {
-        if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
-            throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
+        if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
         }
         }
         this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
         this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
         this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
         this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);

+ 11 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -93,7 +93,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
         Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
             null : randomFrom(Classification.ClassAssignmentObjective.values());
             null : randomFrom(Classification.ClassAssignmentObjective.values());
         Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
         Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
-        Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
+        Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
         Long randomizeSeed = randomBoolean() ? null : randomLong();
         Long randomizeSeed = randomBoolean() ? null : randomLong();
         return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
         return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
             numTopClasses, trainingPercent, randomizeSeed,
             numTopClasses, trainingPercent, randomizeSeed,
@@ -198,19 +198,25 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         }
         }
     }
     }
 
 
+    public void testConstructor_GivenTrainingPercentIsZero() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null));
+
+        assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
+    }
 
 
-    public void testConstructor_GivenTrainingPercentIsLessThanOne() {
+    public void testConstructor_GivenTrainingPercentIsLessThanZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null));
 
 
-        assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
+        assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
     }
 
 
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
             () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
             () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
 
 
-        assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
+        assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
     }
 
 
     public void testConstructor_GivenNumTopClassesIsLessThanZero() {
     public void testConstructor_GivenNumTopClassesIsLessThanZero() {

+ 12 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

@@ -84,7 +84,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
     private static Regression createRandom(BoostedTreeParams boostedTreeParams) {
     private static Regression createRandom(BoostedTreeParams boostedTreeParams) {
         String dependentVariableName = randomAlphaOfLength(10);
         String dependentVariableName = randomAlphaOfLength(10);
         String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
         String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
-        Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
+        Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
         Long randomizeSeed = randomBoolean() ? null : randomLong();
         Long randomizeSeed = randomBoolean() ? null : randomLong();
         Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
         Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
         Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
         Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
@@ -196,11 +196,18 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         }
         }
     }
     }
 
 
-    public void testConstructor_GivenTrainingPercentIsLessThanOne() {
+    public void testConstructor_GivenTrainingPercentIsZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong(), Regression.LossFunction.MSE, null, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), Regression.LossFunction.MSE, null, null));
 
 
-        assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
+        assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
+    }
+
+    public void testConstructor_GivenTrainingPercentIsLessThanZero() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), Regression.LossFunction.MSE, null, null));
+
+        assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
     }
 
 
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
@@ -208,7 +215,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
             () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
             () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
 
 
 
 
-        assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
+        assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
     }
 
 
     public void testConstructor_GivenLossFunctionParameterIsZero() {
     public void testConstructor_GivenLossFunctionParameterIsZero() {

+ 2 - 2
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -79,7 +79,7 @@ yamlRestTest {
     'ml/data_frame_analytics_crud/Test put regression given max_trees is greater than 2k',
     'ml/data_frame_analytics_crud/Test put regression given max_trees is greater than 2k',
     'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
     'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
     'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
     'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
-    'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
+    'ml/data_frame_analytics_crud/Test put regression given training_percent is less than zero',
     'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
     'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
     'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is zero',
     'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is zero',
     'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is negative',
     'ml/data_frame_analytics_crud/Test put regression given loss_function_parameter is negative',
@@ -94,7 +94,7 @@ yamlRestTest {
     'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
     'ml/data_frame_analytics_crud/Test put classification given feature_bag_fraction is greater than one',
     'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
     'ml/data_frame_analytics_crud/Test put classification given num_top_classes is less than zero',
     'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
     'ml/data_frame_analytics_crud/Test put classification given num_top_classes is greater than 1k',
-    'ml/data_frame_analytics_crud/Test put classification given training_percent is less than one',
+    'ml/data_frame_analytics_crud/Test put classification given training_percent is less than zero',
     'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
     'ml/data_frame_analytics_crud/Test put classification given training_percent is greater than hundred',
     'ml/estimate_model_memory/Test missing overall cardinality',
     'ml/estimate_model_memory/Test missing overall cardinality',
     'ml/estimate_model_memory/Test missing max bucket cardinality',
     'ml/estimate_model_memory/Test missing max bucket cardinality',

+ 8 - 8
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

@@ -1522,10 +1522,10 @@ setup:
           }
           }
 
 
 ---
 ---
-"Test put regression given training_percent is less than one":
+"Test put regression given training_percent is less than zero":
 
 
   - do:
   - do:
-      catch: /\[training_percent\] must be a double in \[1, 100\]/
+      catch: /\[training_percent\] must be a positive double in \(0, 100\]/
       ml.put_data_frame_analytics:
       ml.put_data_frame_analytics:
         id: "regression-training-percent-is-less-than-one"
         id: "regression-training-percent-is-less-than-one"
         body: >
         body: >
@@ -1539,7 +1539,7 @@ setup:
             "analysis": {
             "analysis": {
               "regression": {
               "regression": {
                 "dependent_variable": "foo",
                 "dependent_variable": "foo",
-                "training_percent": 0.999
+                "training_percent": -1.0
               }
               }
             }
             }
           }
           }
@@ -1548,7 +1548,7 @@ setup:
 "Test put regression given training_percent is greater than hundred":
 "Test put regression given training_percent is greater than hundred":
 
 
   - do:
   - do:
-      catch: /\[training_percent\] must be a double in \[1, 100\]/
+      catch: /\[training_percent\] must be a positive double in \(0, 100\]/
       ml.put_data_frame_analytics:
       ml.put_data_frame_analytics:
         id: "regression-training-percent-is-greater-than-hundred"
         id: "regression-training-percent-is-greater-than-hundred"
         body: >
         body: >
@@ -1914,10 +1914,10 @@ setup:
           }
           }
 
 
 ---
 ---
-"Test put classification given training_percent is less than one":
+"Test put classification given training_percent is less than zero":
 
 
   - do:
   - do:
-      catch: /\[training_percent\] must be a double in \[1, 100\]/
+      catch: /\[training_percent\] must be a positive double in \(0, 100\]/
       ml.put_data_frame_analytics:
       ml.put_data_frame_analytics:
         id: "classification-training-percent-is-less-than-one"
         id: "classification-training-percent-is-less-than-one"
         body: >
         body: >
@@ -1931,7 +1931,7 @@ setup:
             "analysis": {
             "analysis": {
               "classification": {
               "classification": {
                 "dependent_variable": "foo",
                 "dependent_variable": "foo",
-                "training_percent": 0.999
+                "training_percent": -1.0
               }
               }
             }
             }
           }
           }
@@ -1940,7 +1940,7 @@ setup:
 "Test put classification given training_percent is greater than hundred":
 "Test put classification given training_percent is greater than hundred":
 
 
   - do:
   - do:
-      catch: /\[training_percent\] must be a double in \[1, 100\]/
+      catch: /\[training_percent\] must be a positive double in \(0, 100\]/
       ml.put_data_frame_analytics:
       ml.put_data_frame_analytics:
         id: "classification-training-percent-is-greater-than-hundred"
         id: "classification-training-percent-is-greater-than-hundred"
         body: >
         body: >