Browse Source

[ML] Add early stopping DFA configuration parameter (#68099)

This PR adds the optional `early_stopping_enabled` configuration parameter to the data frame analytics (DFA) classification and regression analyses. The enhancement itself is already described in elastic/ml-cpp#1676, so it is marked here as a non-issue.
Valeriy Khakhutskyy 4 năm trước cách đây
mục cha
commit
78368428b3
22 tập tin đã thay đổi với 257 bổ sung và 104 xóa
  1. 25 5
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java
  2. 26 5
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java
  3. 3 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  4. 2 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  5. 1 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java
  6. 1 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java
  7. 2 0
      docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc
  8. 8 0
      docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc
  9. 8 0
      docs/reference/ml/ml-shared.asciidoc
  10. 27 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java
  11. 27 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java
  12. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java
  13. 6 0
      x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json
  14. 8 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java
  15. 35 26
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java
  16. 36 38
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java
  17. 8 4
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java
  18. 2 1
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
  19. 3 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java
  20. 13 5
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java
  21. 2 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java
  22. 12 6
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

+ 25 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

@@ -63,6 +63,7 @@ public class Classification implements DataFrameAnalysis {
     static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
     static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
     static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
+    static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<Classification, Void> PARSER =
@@ -88,7 +89,8 @@ public class Classification implements DataFrameAnalysis {
                 (Double) a[15],
                 (Double) a[16],
                 (Double) a[17],
-                (Integer) a[18]
+                (Integer) a[18],
+                (Boolean) a[19]
             ));
 
     static {
@@ -115,6 +117,7 @@ public class Classification implements DataFrameAnalysis {
         PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
         PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
         PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
+        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), EARLY_STOPPING_ENABLED);
     }
 
     private final String dependentVariable;
@@ -136,6 +139,7 @@ public class Classification implements DataFrameAnalysis {
     private final Double softTreeDepthTolerance;
     private final Double downsampleFactor;
     private final Integer maxOptimizationRoundsPerHyperparameter;
+    private final Boolean earlyStoppingEnabled;
 
     private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
                            @Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@@ -144,7 +148,7 @@ public class Classification implements DataFrameAnalysis {
                            @Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable List<PreProcessor> featureProcessors,
                            @Nullable Double alpha, @Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit,
                            @Nullable Double softTreeDepthTolerance, @Nullable Double downsampleFactor,
-                           @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
+                           @Nullable Integer maxOptimizationRoundsPerHyperparameter, @Nullable Boolean earlyStoppingEnabled) {
         this.dependentVariable = Objects.requireNonNull(dependentVariable);
         this.lambda = lambda;
         this.gamma = gamma;
@@ -164,6 +168,7 @@ public class Classification implements DataFrameAnalysis {
         this.softTreeDepthTolerance = softTreeDepthTolerance;
         this.downsampleFactor = downsampleFactor;
         this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
+        this.earlyStoppingEnabled = earlyStoppingEnabled;
     }
 
     @Override
@@ -247,6 +252,10 @@ public class Classification implements DataFrameAnalysis {
         return maxOptimizationRoundsPerHyperparameter;
     }
 
+    public Boolean getEarlyStoppingEnable() {
+        return earlyStoppingEnabled;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -305,6 +314,9 @@ public class Classification implements DataFrameAnalysis {
         if (maxOptimizationRoundsPerHyperparameter != null) {
             builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
         }
+        if (earlyStoppingEnabled != null) {
+            builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
+        }
         builder.endObject();
         return builder;
     }
@@ -313,7 +325,8 @@ public class Classification implements DataFrameAnalysis {
     public int hashCode() {
         return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
             predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective, featureProcessors, alpha,
-            etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
+            etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter,
+            earlyStoppingEnabled);
     }
 
     @Override
@@ -339,7 +352,8 @@ public class Classification implements DataFrameAnalysis {
             && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
             && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
             && Objects.equals(downsampleFactor, that.downsampleFactor)
-            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
+            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
+            && Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
     }
 
     @Override
@@ -380,6 +394,7 @@ public class Classification implements DataFrameAnalysis {
         private Double softTreeDepthTolerance;
         private Double downsampleFactor;
         private Integer maxOptimizationRoundsPerHyperparameter;
+        private Boolean earlyStoppingEnabled;
 
         private Builder(String dependentVariable) {
             this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -475,11 +490,16 @@ public class Classification implements DataFrameAnalysis {
             return this;
         }
 
+        public Builder setEarlyStoppingEnabled(Boolean earlyStoppingEnabled) {
+            this.earlyStoppingEnabled = earlyStoppingEnabled;
+            return this;
+        }
+
         public Classification build() {
             return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
                 numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
                 classAssignmentObjective, featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance,
-                downsampleFactor, maxOptimizationRoundsPerHyperparameter);
+                downsampleFactor, maxOptimizationRoundsPerHyperparameter, earlyStoppingEnabled);
         }
     }
 }

+ 26 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java

@@ -65,6 +65,7 @@ public class Regression implements DataFrameAnalysis {
     static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
     static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
     static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
+    static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<Regression, Void> PARSER =
@@ -90,7 +91,8 @@ public class Regression implements DataFrameAnalysis {
                 (Double) a[15],
                 (Double) a[16],
                 (Double) a[17],
-                (Integer) a[18]
+                (Integer) a[18],
+                (Boolean) a[19]
             ));
 
     static {
@@ -116,6 +118,7 @@ public class Regression implements DataFrameAnalysis {
         PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
         PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
         PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
+        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), EARLY_STOPPING_ENABLED);
     }
 
     private final String dependentVariable;
@@ -137,6 +140,7 @@ public class Regression implements DataFrameAnalysis {
     private final Double softTreeDepthTolerance;
     private final Double downsampleFactor;
     private final Integer maxOptimizationRoundsPerHyperparameter;
+    private final Boolean earlyStoppingEnabled;
 
     private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
                        @Nullable Integer maxTrees, @Nullable Double featureBagFraction,
@@ -144,7 +148,8 @@ public class Regression implements DataFrameAnalysis {
                        @Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
                        @Nullable Double lossFunctionParameter, @Nullable List<PreProcessor> featureProcessors, @Nullable Double alpha,
                        @Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit, @Nullable Double softTreeDepthTolerance,
-                       @Nullable Double downsampleFactor, @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
+                       @Nullable Double downsampleFactor, @Nullable Integer maxOptimizationRoundsPerHyperparameter,
+                       @Nullable Boolean earlyStoppingEnabled) {
         this.dependentVariable = Objects.requireNonNull(dependentVariable);
         this.lambda = lambda;
         this.gamma = gamma;
@@ -164,6 +169,7 @@ public class Regression implements DataFrameAnalysis {
         this.softTreeDepthTolerance = softTreeDepthTolerance;
         this.downsampleFactor = downsampleFactor;
         this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
+        this.earlyStoppingEnabled = earlyStoppingEnabled;
     }
 
     @Override
@@ -247,6 +253,10 @@ public class Regression implements DataFrameAnalysis {
         return maxOptimizationRoundsPerHyperparameter;
     }
 
+    public Boolean getEarlyStoppingEnabled() {
+        return earlyStoppingEnabled;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -305,6 +315,9 @@ public class Regression implements DataFrameAnalysis {
         if (maxOptimizationRoundsPerHyperparameter != null) {
             builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
         }
+        if (earlyStoppingEnabled != null) {
+            builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
+        }
         builder.endObject();
         return builder;
     }
@@ -313,7 +326,8 @@ public class Regression implements DataFrameAnalysis {
     public int hashCode() {
         return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
             predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter, featureProcessors, alpha,
-            etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
+            etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter,
+            earlyStoppingEnabled);
     }
 
     @Override
@@ -339,7 +353,8 @@ public class Regression implements DataFrameAnalysis {
             && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
             && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
             && Objects.equals(downsampleFactor, that.downsampleFactor)
-            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
+            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
+            && Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
     }
 
     @Override
@@ -367,6 +382,7 @@ public class Regression implements DataFrameAnalysis {
         private Double softTreeDepthTolerance;
         private Double downsampleFactor;
         private Integer maxOptimizationRoundsPerHyperparameter;
+        private Boolean earlyStoppingEnabled;
 
         private Builder(String dependentVariable) {
             this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -462,11 +478,16 @@ public class Regression implements DataFrameAnalysis {
             return this;
         }
 
+        public Builder setEarlyStoppingEnabled(Boolean earlyStoppingEnabled) {
+            this.earlyStoppingEnabled = earlyStoppingEnabled;
+            return this;
+        }
+
         public Regression build() {
             return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
                 numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter,
                 featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor,
-                maxOptimizationRoundsPerHyperparameter);
+                maxOptimizationRoundsPerHyperparameter, earlyStoppingEnabled);
         }
     }
 

+ 3 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -1366,6 +1366,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 .setSoftTreeDepthTolerance(0.1)
                 .setDownsampleFactor(0.5)
                 .setMaxOptimizationRoundsPerHyperparameter(3)
+                .setMaxOptimizationRoundsPerHyperparameter(3)
+                .setEarlyStoppingEnabled(false)
                 .build())
             .setDescription("this is a regression")
             .build();
@@ -1417,6 +1419,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 .setSoftTreeDepthTolerance(0.1)
                 .setDownsampleFactor(0.5)
                 .setMaxOptimizationRoundsPerHyperparameter(3)
+                .setEarlyStoppingEnabled(false)
                 .build())
             .setDescription("this is a classification")
             .build();

+ 2 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -3059,6 +3059,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .setSoftTreeDepthTolerance(1.0) // <17>
                 .setDownsampleFactor(0.5) // <18>
                 .setMaxOptimizationRoundsPerHyperparameter(3) // <19>
+                .setEarlyStoppingEnabled(true) // <20>
                 .build();
             // end::put-data-frame-analytics-classification
 
@@ -3084,6 +3085,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .setSoftTreeDepthTolerance(1.0) // <17>
                 .setDownsampleFactor(0.5) // <18>
                 .setMaxOptimizationRoundsPerHyperparameter(3) // <19>
+                .setEarlyStoppingEnabled(true) // <20>
                 .build();
             // end::put-data-frame-analytics-regression
 

+ 1 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

@@ -60,6 +60,7 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
             .setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
             .setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
             .setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
+            .setEarlyStoppingEnabled(randomBoolean() ? null : randomBoolean())
             .build();
     }
 

+ 1 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java

@@ -59,6 +59,7 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
             .setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
             .setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
             .setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
+            .setEarlyStoppingEnabled(randomBoolean() ? null : randomBoolean())
             .build();
     }
 

+ 2 - 0
docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

@@ -134,6 +134,7 @@ include-tagged::{doc-tests-file}[{api}-classification]
 <17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
 <18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
 <19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
+<20> Whether to enable early stopping, which finishes the training process if it is not finding better models.
 
 ===== Regression
 
@@ -164,6 +165,7 @@ fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature e
 <17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
 <18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
 <19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
+<20> Whether to enable early stopping, which finishes the training process if it is not finding better models.
 
 ==== Analyzed fields
 

+ 8 - 0
docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc

@@ -117,6 +117,10 @@ different values in this field.
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]
 
+`early_stopping_enabled`::::
+(Optional, Boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-early-stopping-enabled]
+
 `eta`::::
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
@@ -359,6 +363,10 @@ The data type of the field must be numeric.
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]
 
+`early_stopping_enabled`::::
+(Optional, Boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-early-stopping-enabled]
+
 `eta`::::
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]

+ 8 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -557,6 +557,14 @@ Values must be greater than zero and less than or equal to 1.
 By default, this value is calculated during hyperparameter optimization.
 end::dfas-downsample-factor[]
 
+tag::dfas-early-stopping-enabled[]
+Advanced configuration option.
+Specifies whether the training process should finish if it is not finding any
+better performing models. If disabled, the training process can take significantly
+longer and the chance of finding a better performing model is unremarkable.
+By default, early stopping is enabled.
+end::dfas-early-stopping-enabled[]
+
 tag::dfas-eta-growth[]
 Advanced configuration option.
 Specifies the rate at which `eta` increases for each new tree that is added

+ 27 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -55,6 +55,7 @@ public class Classification implements DataFrameAnalysis {
     public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
+    public static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
 
     private static final String STATE_DOC_ID_INFIX = "_classification_state#";
 
@@ -82,7 +83,8 @@ public class Classification implements DataFrameAnalysis {
                 (Integer) a[15],
                 (Double) a[16],
                 (Long) a[17],
-                (List<PreProcessor>) a[18]));
+                (List<PreProcessor>) a[18],
+                (Boolean) a[19]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@@ -96,6 +98,7 @@ public class Classification implements DataFrameAnalysis {
                 p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
             (classification) -> {/*TODO should we throw if this is not set?*/},
             FEATURE_PROCESSORS);
+        parser.declareBoolean(optionalConstructorArg(), EARLY_STOPPING_ENABLED);
         return parser;
     }
 
@@ -159,6 +162,7 @@ public class Classification implements DataFrameAnalysis {
     private final double trainingPercent;
     private final long randomizeSeed;
     private final List<PreProcessor> featureProcessors;
+    private final boolean earlyStoppingEnabled;
 
     public Classification(String dependentVariable,
                           BoostedTreeParams boostedTreeParams,
@@ -167,7 +171,8 @@ public class Classification implements DataFrameAnalysis {
                           @Nullable Integer numTopClasses,
                           @Nullable Double trainingPercent,
                           @Nullable Long randomizeSeed,
-                          @Nullable List<PreProcessor> featureProcessors) {
+                          @Nullable List<PreProcessor> featureProcessors,
+                          @Nullable Boolean earlyStoppingEnabled) {
         if (numTopClasses != null && (numTopClasses < -1 || numTopClasses > 1000)) {
             throw ExceptionsHelper.badRequestException(
                 "[{}] must be an integer in [0, 1000] or a special value -1", NUM_TOP_CLASSES.getPreferredName());
@@ -184,10 +189,12 @@ public class Classification implements DataFrameAnalysis {
         this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
         this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
         this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
+        // Early stopping is true by default
+        this.earlyStoppingEnabled = earlyStoppingEnabled == null ? true : earlyStoppingEnabled;
     }
 
     public Classification(String dependentVariable) {
-        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
+        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null, null);
     }
 
     public Classification(StreamInput in) throws IOException {
@@ -211,6 +218,11 @@ public class Classification implements DataFrameAnalysis {
         } else {
             featureProcessors = Collections.emptyList();
         }
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            earlyStoppingEnabled = in.readBoolean();
+        } else {
+            earlyStoppingEnabled = true;
+        }
     }
 
     public String getDependentVariable() {
@@ -246,6 +258,10 @@ public class Classification implements DataFrameAnalysis {
         return featureProcessors;
     }
 
+    public Boolean getEarlyStoppingEnabled() {
+        return earlyStoppingEnabled;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -267,6 +283,9 @@ public class Classification implements DataFrameAnalysis {
         if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
             out.writeNamedWriteableList(featureProcessors);
         }
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeBoolean(earlyStoppingEnabled);;
+        }
     }
 
     @Override
@@ -288,6 +307,7 @@ public class Classification implements DataFrameAnalysis {
         if (featureProcessors.isEmpty() == false) {
             NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
         }
+        builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
         builder.endObject();
         return builder;
     }
@@ -312,6 +332,7 @@ public class Classification implements DataFrameAnalysis {
             params.put(FEATURE_PROCESSORS.getPreferredName(),
                 featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
         }
+        params.put(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
         return params;
     }
 
@@ -457,6 +478,7 @@ public class Classification implements DataFrameAnalysis {
             && Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
             && Objects.equals(numTopClasses, that.numTopClasses)
             && Objects.equals(featureProcessors, that.featureProcessors)
+            && Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled)
             && trainingPercent == that.trainingPercent
             && randomizeSeed == that.randomizeSeed;
     }
@@ -464,7 +486,8 @@ public class Classification implements DataFrameAnalysis {
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
-                            numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
+                            numTopClasses, trainingPercent, randomizeSeed, featureProcessors, 
+                            earlyStoppingEnabled);
     }
 
     public enum ClassAssignmentObjective {

+ 27 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -52,6 +52,7 @@ public class Regression implements DataFrameAnalysis {
     public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
     public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
+    public static final ParseField EARLY_STOPPING_ENABLED = new ParseField("early_stopping_enabled");
 
     private static final String STATE_DOC_ID_INFIX = "_regression_state#";
 
@@ -72,7 +73,8 @@ public class Regression implements DataFrameAnalysis {
                 (Long) a[15],
                 (LossFunction) a[16],
                 (Double) a[17],
-                (List<PreProcessor>) a[18]));
+                (List<PreProcessor>) a[18],
+                (Boolean) a[19]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
@@ -86,6 +88,7 @@ public class Regression implements DataFrameAnalysis {
                 p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
             (regression) -> {/*TODO should we throw if this is not set?*/},
             FEATURE_PROCESSORS);
+        parser.declareBoolean(optionalConstructorArg(), EARLY_STOPPING_ENABLED);
         return parser;
     }
 
@@ -124,6 +127,7 @@ public class Regression implements DataFrameAnalysis {
     private final LossFunction lossFunction;
     private final Double lossFunctionParameter;
     private final List<PreProcessor> featureProcessors;
+    private final boolean earlyStoppingEnabled;
 
     public Regression(String dependentVariable,
                       BoostedTreeParams boostedTreeParams,
@@ -132,7 +136,8 @@ public class Regression implements DataFrameAnalysis {
                       @Nullable Long randomizeSeed,
                       @Nullable LossFunction lossFunction,
                       @Nullable Double lossFunctionParameter,
-                      @Nullable List<PreProcessor> featureProcessors) {
+                      @Nullable List<PreProcessor> featureProcessors,
+                      @Nullable Boolean earlyStoppingEnabled) {
         if (trainingPercent != null && (trainingPercent <= 0.0 || trainingPercent > 100.0)) {
             throw ExceptionsHelper.badRequestException("[{}] must be a positive double in (0, 100]", TRAINING_PERCENT.getPreferredName());
         }
@@ -148,10 +153,12 @@ public class Regression implements DataFrameAnalysis {
         }
         this.lossFunctionParameter = lossFunctionParameter;
         this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
+        // Early stopping is true by default
+        this.earlyStoppingEnabled = earlyStoppingEnabled == null ? true : earlyStoppingEnabled;
     }
 
     public Regression(String dependentVariable) {
-        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
+        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null, null);
     }
 
     public Regression(StreamInput in) throws IOException {
@@ -167,6 +174,11 @@ public class Regression implements DataFrameAnalysis {
         } else {
             featureProcessors = Collections.emptyList();
         }
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            earlyStoppingEnabled = in.readBoolean();
+        } else {
+            earlyStoppingEnabled = true;
+        }
     }
 
     public String getDependentVariable() {
@@ -202,6 +214,10 @@ public class Regression implements DataFrameAnalysis {
         return featureProcessors;
     }
 
+    public Boolean getEarlyStoppingEnabled() { // never null: the backing field is a primitive that defaults to true
+        return earlyStoppingEnabled;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -219,6 +235,9 @@ public class Regression implements DataFrameAnalysis {
         if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
             out.writeNamedWriteableList(featureProcessors);
         }
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeBoolean(earlyStoppingEnabled);
+        }
     }
 
     @Override
@@ -242,6 +261,7 @@ public class Regression implements DataFrameAnalysis {
         if (featureProcessors.isEmpty() == false) {
             NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
         }
+        builder.field(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
         builder.endObject();
         return builder;
     }
@@ -263,6 +283,7 @@ public class Regression implements DataFrameAnalysis {
             params.put(FEATURE_PROCESSORS.getPreferredName(),
                 featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
         }
+        params.put(EARLY_STOPPING_ENABLED.getPreferredName(), earlyStoppingEnabled);
         return params;
     }
 
@@ -348,13 +369,14 @@ public class Regression implements DataFrameAnalysis {
             && randomizeSeed == that.randomizeSeed
             && lossFunction == that.lossFunction
             && Objects.equals(featureProcessors, that.featureProcessors)
-            && Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
+            && Objects.equals(lossFunctionParameter, that.lossFunctionParameter)
+            && Objects.equals(earlyStoppingEnabled, that.earlyStoppingEnabled);
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
-            lossFunctionParameter, featureProcessors);
+            lossFunctionParameter, featureProcessors, earlyStoppingEnabled);
     }
 
     public enum LossFunction {

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java

@@ -329,6 +329,7 @@ public final class ReservedFieldNames {
             Regression.PREDICTION_FIELD_NAME.getPreferredName(),
             Regression.TRAINING_PERCENT.getPreferredName(),
             Regression.FEATURE_PROCESSORS.getPreferredName(),
+            Regression.EARLY_STOPPING_ENABLED.getPreferredName(),
             Classification.NAME.getPreferredName(),
             Classification.DEPENDENT_VARIABLE.getPreferredName(),
             Classification.PREDICTION_FIELD_NAME.getPreferredName(),
@@ -336,6 +337,7 @@ public final class ReservedFieldNames {
             Classification.NUM_TOP_CLASSES.getPreferredName(),
             Classification.TRAINING_PERCENT.getPreferredName(),
             Classification.FEATURE_PROCESSORS.getPreferredName(),
+            Classification.EARLY_STOPPING_ENABLED.getPreferredName(),
             BoostedTreeParams.ALPHA.getPreferredName(),
             BoostedTreeParams.DOWNSAMPLE_FACTOR.getPreferredName(),
             BoostedTreeParams.LAMBDA.getPreferredName(),

+ 6 - 0
x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json

@@ -78,6 +78,9 @@
               },
               "training_percent" : {
                 "type" : "double"
+              },
+              "early_stopping_enabled" : {
+                "type" : "boolean"
               }
             }
           },
@@ -149,6 +152,9 @@
               },
               "training_percent" : {
                 "type" : "double"
+              },
+              "early_stopping_enabled" : {
+                "type" : "boolean"
               }
             }
           }

+ 8 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java

@@ -151,7 +151,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 42L,
                 bwcRegression.getLossFunction(),
                 bwcRegression.getLossFunctionParameter(),
-                bwcRegression.getFeatureProcessors());
+                bwcRegression.getFeatureProcessors(),
+                bwcRegression.getEarlyStoppingEnabled());
             testAnalysis = new Regression(testRegression.getDependentVariable(),
                 testRegression.getBoostedTreeParams(),
                 testRegression.getPredictionFieldName(),
@@ -159,7 +160,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 42L,
                 testRegression.getLossFunction(),
                 testRegression.getLossFunctionParameter(),
-                bwcRegression.getFeatureProcessors());
+                testRegression.getFeatureProcessors(),
+                testRegression.getEarlyStoppingEnabled());
         } else {
             Classification testClassification = (Classification)testInstance.getAnalysis();
             Classification bwcClassification = (Classification)bwcSerializedObject.getAnalysis();
@@ -170,7 +172,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 bwcClassification.getNumTopClasses(),
                 bwcClassification.getTrainingPercent(),
                 42L,
-                bwcClassification.getFeatureProcessors());
+                bwcClassification.getFeatureProcessors(),
+                bwcClassification.getEarlyStoppingEnabled());
             testAnalysis = new Classification(testClassification.getDependentVariable(),
                 testClassification.getBoostedTreeParams(),
                 testClassification.getPredictionFieldName(),
@@ -178,7 +181,8 @@ public class DataFrameAnalyticsConfigTests extends AbstractBWCSerializationTestC
                 testClassification.getNumTopClasses(),
                 testClassification.getTrainingPercent(),
                 42L,
-                testClassification.getFeatureProcessors());
+                testClassification.getFeatureProcessors(),
+                testClassification.getEarlyStoppingEnabled());
         }
         super.assertOnBWCObject(new DataFrameAnalyticsConfig.Builder(bwcSerializedObject)
             .setAnalysis(bwcAnalysis)

+ 35 - 26
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -97,6 +97,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
         Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
         Long randomizeSeed = randomBoolean() ? null : randomLong();
+        Boolean earlyStoppingEnabled = randomBoolean() ? null : randomBoolean();
         return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
             numTopClasses, trainingPercent, randomizeSeed,
             randomBoolean() ?
@@ -105,7 +106,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                     OneHotEncodingTests.createRandom(true),
                     TargetMeanEncodingTests.createRandom(true)))
                     .limit(randomIntBetween(0, 5))
-                    .collect(Collectors.toList()));
+                    .collect(Collectors.toList()),
+            earlyStoppingEnabled);
     }
 
     public static Classification mutateForVersion(Classification instance, Version version) {
@@ -116,7 +118,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             instance.getNumTopClasses(),
             instance.getTrainingPercent(),
             instance.getRandomizeSeed(),
-            version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
+            version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList(),
+            version.onOrAfter(Version.V_8_0_0) ? instance.getEarlyStoppingEnabled() : null);
     }
 
     @Override
@@ -133,7 +136,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             bwcSerializedObject.getNumTopClasses(),
             bwcSerializedObject.getTrainingPercent(),
             42L,
-            bwcSerializedObject.getFeatureProcessors());
+            bwcSerializedObject.getFeatureProcessors(),
+            bwcSerializedObject.getEarlyStoppingEnabled());
         Classification newInstance = new Classification(testInstance.getDependentVariable(),
             testInstance.getBoostedTreeParams(),
             testInstance.getPredictionFieldName(),
@@ -141,7 +145,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             testInstance.getNumTopClasses(),
             testInstance.getTrainingPercent(),
             42L,
-            testInstance.getFeatureProcessors());
+            testInstance.getFeatureProcessors(),
+            testInstance.getEarlyStoppingEnabled());
         super.assertOnBWCObject(newBwc, newInstance, version);
     }
 
@@ -202,96 +207,96 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
 
     public void testConstructor_GivenTrainingPercentIsZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null, null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
 
     public void testConstructor_GivenTrainingPercentIsLessThanZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null, null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
 
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null, null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
 
     public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null, null));
 
         assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
     }
 
     public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
+            () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null, null));
 
         assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
     }
 
     public void testGetPredictionFieldName() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
         assertThat(classification.getPredictionFieldName(), equalTo("result"));
 
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null, null);
         assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
     }
 
     public void testClassAssignmentObjective() {
         Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
-            Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
+            Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null, null);
         assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
 
         classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
-        Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
+        Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null, null);
         assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
 
         // class_assignment_objective == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
         assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
     }
 
     public void testGetNumTopClasses() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
         assertThat(classification.getNumTopClasses(), equalTo(7));
 
         // Special value: num_top_classes == -1
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null, null);
         assertThat(classification.getNumTopClasses(), equalTo(-1));
 
         // Boundary condition: num_top_classes == 0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null, null);
         assertThat(classification.getNumTopClasses(), equalTo(0));
 
         // Boundary condition: num_top_classes == 1000
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null, null);
         assertThat(classification.getNumTopClasses(), equalTo(1000));
 
         // num_top_classes == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null, null);
         assertThat(classification.getNumTopClasses(), equalTo(2));
     }
 
     public void testGetTrainingPercent() {
-        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
+        Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
         assertThat(classification.getTrainingPercent(), equalTo(50.0));
 
         // Boundary condition: training_percent == 1.0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null, null);
         assertThat(classification.getTrainingPercent(), equalTo(1.0));
 
         // Boundary condition: training_percent == 100.0
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null, null);
         assertThat(classification.getTrainingPercent(), equalTo(100.0));
 
         // training_percent == null, default applied
-        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
+        classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null, null);
         assertThat(classification.getTrainingPercent(), equalTo(100.0));
     }
 
@@ -316,7 +321,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                     "prediction_field_name", "foo_prediction",
                     "prediction_field_type", "bool",
                     "num_classes", 10L,
-                    "training_percent", 100.0)));
+                    "training_percent", 100.0,
+                    "early_stopping_enabled", true)));
         assertThat(
             new Classification("bar").getParams(fieldInfo),
             equalTo(
@@ -327,7 +333,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                     "prediction_field_name", "bar_prediction",
                     "prediction_field_type", "int",
                     "num_classes", 20L,
-                    "training_percent", 100.0)));
+                    "training_percent", 100.0,
+                    "early_stopping_enabled", true)));
         assertThat(
             new Classification("baz",
                 BoostedTreeParams.builder().build() ,
@@ -336,6 +343,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                 null,
                 50.0,
                 null,
+                null,
                 null).getParams(fieldInfo),
             equalTo(
                 Map.of(
@@ -345,7 +353,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                     "prediction_field_name", "baz_prediction",
                     "prediction_field_type", "string",
                     "num_classes", 30L,
-                    "training_percent", 50.0)));
+                    "training_percent", 50.0,
+                    "early_stopping_enabled", true)));
     }
 
     public void testRequiredFieldsIsNonEmpty() {

+ 36 - 38
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

@@ -88,6 +88,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         Long randomizeSeed = randomBoolean() ? null : randomLong();
         Regression.LossFunction lossFunction = randomBoolean() ? null : randomFrom(Regression.LossFunction.values());
         Double lossFunctionParameter = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, false);
+        Boolean earlyStoppingEnabled = randomBoolean() ? null : randomBoolean();
         return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
             lossFunctionParameter,
             randomBoolean() ?
@@ -96,7 +97,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
                     OneHotEncodingTests.createRandom(true),
                     TargetMeanEncodingTests.createRandom(true)))
                     .limit(randomIntBetween(0, 5))
-                    .collect(Collectors.toList()));
+                    .collect(Collectors.toList()),
+            earlyStoppingEnabled);
     }
 
     public static Regression mutateForVersion(Regression instance, Version version) {
@@ -107,7 +109,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
             instance.getRandomizeSeed(),
             instance.getLossFunction(),
             instance.getLossFunctionParameter(),
-            version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
+            version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList(),
+            version.onOrAfter(Version.V_8_0_0) ? instance.getEarlyStoppingEnabled() : null);
     }
 
     @Override
@@ -124,7 +127,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
             42L,
             bwcSerializedObject.getLossFunction(),
             bwcSerializedObject.getLossFunctionParameter(),
-            bwcSerializedObject.getFeatureProcessors());
+            bwcSerializedObject.getFeatureProcessors(),
+            bwcSerializedObject.getEarlyStoppingEnabled());
         Regression newInstance = new Regression(testInstance.getDependentVariable(),
             testInstance.getBoostedTreeParams(),
             testInstance.getPredictionFieldName(),
@@ -132,7 +136,8 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
             42L,
             testInstance.getLossFunction(),
             testInstance.getLossFunctionParameter(),
-            testInstance.getFeatureProcessors());
+            testInstance.getFeatureProcessors(),
+            testInstance.getEarlyStoppingEnabled());
         super.assertOnBWCObject(newBwc, newInstance, version);
     }
 
@@ -198,21 +203,24 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
 
     public void testConstructor_GivenTrainingPercentIsZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), Regression.LossFunction.MSE, null, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.0, randomLong(), 
+                                Regression.LossFunction.MSE, null, null, null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
 
     public void testConstructor_GivenTrainingPercentIsLessThanZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), Regression.LossFunction.MSE, null, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", -0.01, randomLong(), 
+                                Regression.LossFunction.MSE, null, null, null));
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
     }
 
     public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), Regression.LossFunction.MSE, null, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong(), 
+                                Regression.LossFunction.MSE, null, null, null));
 
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
@@ -220,55 +228,48 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
 
     public void testConstructor_GivenLossFunctionParameterIsZero() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, 0.0, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), 
+                                Regression.LossFunction.MSE, 0.0, null, null));
 
         assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
     }
 
     public void testConstructor_GivenLossFunctionParameterIsNegative() {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, -1.0, null));
+            () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), 
+                                Regression.LossFunction.MSE, -1.0, null, null));
 
         assertThat(e.getMessage(), equalTo("[loss_function_parameter] must be a positive double"));
     }
 
     public void testGetPredictionFieldName() {
-        Regression regression = new Regression(
-            "foo",
-            BOOSTED_TREE_PARAMS,
-            "result",
-            50.0,
-            randomLong(),
-            Regression.LossFunction.MSE,
-            1.0,
-            null);
+        Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(),
+                                                Regression.LossFunction.MSE, 1.0, null, null);
         assertThat(regression.getPredictionFieldName(), equalTo("result"));
 
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), Regression.LossFunction.MSE, null, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong(), 
+                                    Regression.LossFunction.MSE, null, null, null);
         assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction"));
     }
 
     public void testGetTrainingPercent() {
-        Regression regression = new Regression("foo",
-            BOOSTED_TREE_PARAMS,
-            "result",
-            50.0,
-            randomLong(),
-            Regression.LossFunction.MSE,
-            1.0,
-            null);
+        Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong(),
+                                                Regression.LossFunction.MSE, 1.0, null, null);
         assertThat(regression.getTrainingPercent(), equalTo(50.0));
 
         // Boundary condition: training_percent == 1.0
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), Regression.LossFunction.MSE, null, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong(), 
+                                    Regression.LossFunction.MSE, null, null, null);
         assertThat(regression.getTrainingPercent(), equalTo(1.0));
 
         // Boundary condition: training_percent == 100.0
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), Regression.LossFunction.MSE, null, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong(), 
+                                    Regression.LossFunction.MSE, null, null, null);
         assertThat(regression.getTrainingPercent(), equalTo(100.0));
 
         // training_percent == null, default applied
-        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), Regression.LossFunction.MSE, null, null);
+        regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong(), 
+                                    Regression.LossFunction.MSE, null, null, null);
         assertThat(regression.getTrainingPercent(), equalTo(100.0));
     }
 
@@ -276,21 +277,17 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         int maxTrees = randomIntBetween(1, 100);
         Regression regression = new Regression("foo",
             BoostedTreeParams.builder().setMaxTrees(maxTrees).build(),
-            null,
-            100.0,
-            0L,
-            Regression.LossFunction.MSE,
-            null,
-            null);
+            null, 100.0, 0L, Regression.LossFunction.MSE, null, null, null);
 
         Map<String, Object> params = regression.getParams(null);
 
-        assertThat(params.size(), equalTo(5));
+        assertThat(params.size(), equalTo(6));
         assertThat(params.get("dependent_variable"), equalTo("foo"));
         assertThat(params.get("prediction_field_name"), equalTo("foo_prediction"));
         assertThat(params.get("max_trees"), equalTo(maxTrees));
         assertThat(params.get("training_percent"), equalTo(100.0));
         assertThat(params.get("loss_function"), equalTo("mse"));
+        assertThat(params.get("early_stopping_enabled"), equalTo(true));
     }
 
     public void testGetParams_GivenRandomWithoutBoostedTreeParams() {
@@ -298,7 +295,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
 
         Map<String, Object> params = regression.getParams(null);
 
-        int expectedParamsCount = 4
+        int expectedParamsCount = 5
             + (regression.getLossFunctionParameter() == null ? 0 : 1)
             + (regression.getFeatureProcessors().isEmpty() ? 0 : 1);
         assertThat(params.size(), equalTo(expectedParamsCount));
@@ -311,6 +308,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         } else {
             assertThat(params.get("loss_function_parameter"), equalTo(regression.getLossFunctionParameter()));
         }
+        assertThat(params.get("early_stopping_enabled"), equalTo(regression.getEarlyStoppingEnabled()));
     }
 
     public void testRequiredFieldsIsNonEmpty() {

+ 8 - 4
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -141,6 +141,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 null,
                 null,
                 null,
+                null,
                 null));
         putAnalytics(config);
 
@@ -197,6 +198,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 null,
                 null,
                 null,
+                null,
                 null));
         putAnalytics(config);
 
@@ -317,7 +319,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                     new OneHotEncoding(TEXT_FIELD, MapBuilder.<String, String>newMapBuilder()
                         .put(KEYWORD_FIELD_VALUES.get(0), "cat_column_custom_3")
                         .put(KEYWORD_FIELD_VALUES.get(1), "dog_column_custom_3").map(), true)
-                )));
+                ),
+                null));
         putAnalytics(config);
 
         assertIsStopped(jobId);
@@ -386,7 +389,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null, null));
+                new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, 
+                numTopClasses, 50.0, null, null, null));
         putAnalytics(config);
 
         assertIsStopped(jobId);
@@ -650,7 +654,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             .build();
 
         DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
-            new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null));
+            new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null, null, null));
         putAnalytics(firstJob);
         startAnalytics(firstJobId);
         waitUntilAnalyticsIsStopped(firstJobId);
@@ -660,7 +664,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
         long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed();
         DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
-            new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null));
+            new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed, null, null));
 
         putAnalytics(secondJob);
         startAnalytics(secondJobId);

+ 2 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java

@@ -128,7 +128,8 @@ public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalytics
                         new OneHotEncoding("ngram.21", MapBuilder.<String, String>newMapBuilder().put("at", "is_cat").map(), true)
                     },
                         true)
-                    )))
+                    ),
+                    null))
             .setAnalyzedFields(new FetchSourceContext(true, new String[]{TEXT_FIELD, NUMERICAL_FIELD}, new String[]{}))
             .build();
         putAnalytics(config);

+ 3 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java

@@ -105,6 +105,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
                 null,
                 null,
                 null,
+                null,
                 null))
             .buildForExplain();
 
@@ -124,6 +125,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
                 null,
                 null,
                 null,
+                null,
                 null))
             .buildForExplain();
 
@@ -152,6 +154,7 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg
                 null,
                 null,
                 null,
+                null,
                 null))
             .buildForExplain();
 

+ 13 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -115,6 +115,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 null,
                 null,
                 null,
+                null,
                 null)
         );
         putAnalytics(config);
@@ -251,7 +252,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null, null, null, null));
+                new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), 
+                                null, 50.0, null, null, null, null, null));
         putAnalytics(config);
 
         assertIsStopped(jobId);
@@ -371,7 +373,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             .build();
 
         DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null,
-            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null, null, null, null));
+            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, 
+                            null, null, null, null, null));
         putAnalytics(firstJob);
         startAnalytics(firstJobId);
         waitUntilAnalyticsIsStopped(firstJobId);
@@ -381,7 +384,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
         long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed();
         DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null,
-            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed, null, null, null));
+            new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, 
+                            randomizeSeed, null, null, null, null));
 
         putAnalytics(secondJob);
         startAnalytics(secondJobId);
@@ -438,7 +442,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null, null, null, null));
+                new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(),
+                                 null, null, null, null, null, null, null));
         putAnalytics(config);
 
         assertIsStopped(jobId);
@@ -465,6 +470,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 null,
                 null,
                 null,
+                null,
                 null)
         );
         putAnalytics(config);
@@ -562,6 +568,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             null,
             null,
             null,
+            null,
             null);
         DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
             .setId(jobId)
@@ -635,7 +642,8 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 Arrays.asList(
                     new OneHotEncoding(DISCRETE_NUMERICAL_FEATURE_FIELD,
                         Collections.singletonMap(DISCRETE_NUMERICAL_FEATURE_VALUES.get(0).toString(), "tenner"), true)
-                ))
+                ),
+                null)
         );
         putAnalytics(config);
 

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java

@@ -1169,7 +1169,8 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
                 null,
                 null,
                 null,
-                featureprocessors))
+                featureprocessors,
+                null))
             .build();
     }
 

+ 12 - 6
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

@@ -1513,7 +1513,8 @@ setup:
                 "soft_tree_depth_limit": 2.0,
                 "soft_tree_depth_tolerance": 3.0,
                 "downsample_factor": 0.5,
-                "max_optimization_rounds_per_hyperparameter": 3
+                "max_optimization_rounds_per_hyperparameter": 3,
+                "early_stopping_enabled": true
               }
             }
           }
@@ -1538,7 +1539,8 @@ setup:
       "soft_tree_depth_limit": 2.0,
       "soft_tree_depth_tolerance": 3.0,
       "downsample_factor": 0.5,
-      "max_optimization_rounds_per_hyperparameter": 3
+      "max_optimization_rounds_per_hyperparameter": 3,
+      "early_stopping_enabled": true
     }
   }}
   - is_true: create_time
@@ -1870,7 +1872,8 @@ setup:
                 "soft_tree_depth_limit": 2.0,
                 "soft_tree_depth_tolerance": 3.0,
                 "downsample_factor": 0.5,
-                "max_optimization_rounds_per_hyperparameter": 3
+                "max_optimization_rounds_per_hyperparameter": 3,
+                "early_stopping_enabled": true
               }
             }
           }
@@ -1895,7 +1898,8 @@ setup:
       "soft_tree_depth_limit": 2.0,
       "soft_tree_depth_tolerance": 3.0,
       "downsample_factor": 0.5,
-      "max_optimization_rounds_per_hyperparameter": 3
+      "max_optimization_rounds_per_hyperparameter": 3,
+      "early_stopping_enabled": true
     }
   }}
   - is_true: create_time
@@ -1939,7 +1943,8 @@ setup:
       "training_percent": 100.0,
       "randomize_seed": 24,
       "class_assignment_objective": "maximize_minimum_recall",
-      "num_top_classes": 2
+      "num_top_classes": 2,
+      "early_stopping_enabled": true
     }
   }}
   - is_true: create_time
@@ -1977,7 +1982,8 @@ setup:
       "prediction_field_name": "foo_prediction",
       "training_percent": 100.0,
       "randomize_seed": 42,
-      "loss_function": "mse"
+      "loss_function": "mse",
+      "early_stopping_enabled": true
     }
   }}
   - is_true: create_time