浏览代码

[ML] Expand regression/classification hyperparameters (#67950)

Expands data frame analytics regression and classification
analyses with the following hyperparameters:

- alpha
- downsample_factor
- eta_growth_rate_per_tree
- max_optimization_rounds_per_hyperparameter
- soft_tree_depth_limit
- soft_tree_depth_tolerance
Dimitris Athanasiou 4 年之前
父节点
当前提交
5c961c1c81

+ 125 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

@@ -57,6 +57,12 @@ public class Classification implements DataFrameAnalysis {
     static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
     static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
+    static final ParseField ALPHA = new ParseField("alpha");
+    static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
+    static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
+    static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
+    static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
+    static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<Classification, Void> PARSER =
@@ -76,7 +82,14 @@ public class Classification implements DataFrameAnalysis {
                 (Integer) a[9],
                 (Long) a[10],
                 (ClassAssignmentObjective) a[11],
-                (List<PreProcessor>) a[12]));
+                (List<PreProcessor>) a[12],
+                (Double) a[13],
+                (Double) a[14],
+                (Double) a[15],
+                (Double) a[16],
+                (Double) a[17],
+                (Integer) a[18]
+            ));
 
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -96,6 +109,12 @@ public class Classification implements DataFrameAnalysis {
             (p, c, n) -> p.namedObject(PreProcessor.class, n, c),
             (classification) -> {},
             FEATURE_PROCESSORS);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ALPHA);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_LIMIT);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
+        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
     }
 
     private final String dependentVariable;
@@ -111,12 +130,21 @@ public class Classification implements DataFrameAnalysis {
     private final Integer numTopClasses;
     private final Long randomizeSeed;
     private final List<PreProcessor> featureProcessors;
+    private final Double alpha;
+    private final Double etaGrowthRatePerTree;
+    private final Double softTreeDepthLimit;
+    private final Double softTreeDepthTolerance;
+    private final Double downsampleFactor;
+    private final Integer maxOptimizationRoundsPerHyperparameter;
 
     private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
                            @Nullable Integer maxTrees, @Nullable Double featureBagFraction,
                            @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
                            @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
-                           @Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable List<PreProcessor> featureProcessors) {
+                           @Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable List<PreProcessor> featureProcessors,
+                           @Nullable Double alpha, @Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit,
+                           @Nullable Double softTreeDepthTolerance, @Nullable Double downsampleFactor,
+                           @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
         this.dependentVariable = Objects.requireNonNull(dependentVariable);
         this.lambda = lambda;
         this.gamma = gamma;
@@ -130,6 +158,12 @@ public class Classification implements DataFrameAnalysis {
         this.numTopClasses = numTopClasses;
         this.randomizeSeed = randomizeSeed;
         this.featureProcessors = featureProcessors;
+        this.alpha = alpha;
+        this.etaGrowthRatePerTree = etaGrowthRatePerTree;
+        this.softTreeDepthLimit = softTreeDepthLimit;
+        this.softTreeDepthTolerance = softTreeDepthTolerance;
+        this.downsampleFactor = downsampleFactor;
+        this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
     }
 
     @Override
@@ -189,6 +223,30 @@ public class Classification implements DataFrameAnalysis {
         return featureProcessors;
     }
 
+    public Double getAlpha() {
+        return alpha;
+    }
+
+    public Double getEtaGrowthRatePerTree() {
+        return etaGrowthRatePerTree;
+    }
+
+    public Double getSoftTreeDepthLimit() {
+        return softTreeDepthLimit;
+    }
+
+    public Double getSoftTreeDepthTolerance() {
+        return softTreeDepthTolerance;
+    }
+
+    public Double getDownsampleFactor() {
+        return downsampleFactor;
+    }
+
+    public Integer getMaxOptimizationRoundsPerHyperparameter() {
+        return maxOptimizationRoundsPerHyperparameter;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -229,6 +287,24 @@ public class Classification implements DataFrameAnalysis {
         if (featureProcessors != null) {
             NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
         }
+        if (alpha != null) {
+            builder.field(ALPHA.getPreferredName(), alpha);
+        }
+        if (etaGrowthRatePerTree != null) {
+            builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
+        }
+        if (softTreeDepthLimit != null) {
+            builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
+        }
+        if (softTreeDepthTolerance != null) {
+            builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
+        }
+        if (downsampleFactor != null) {
+            builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
+        }
+        if (maxOptimizationRoundsPerHyperparameter != null) {
+            builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
+        }
         builder.endObject();
         return builder;
     }
@@ -236,7 +312,8 @@ public class Classification implements DataFrameAnalysis {
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
-            predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective, featureProcessors);
+            predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective, featureProcessors, alpha,
+            etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
     }
 
     @Override
@@ -256,7 +333,13 @@ public class Classification implements DataFrameAnalysis {
             && Objects.equals(randomizeSeed, that.randomizeSeed)
             && Objects.equals(numTopClasses, that.numTopClasses)
             && Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
-            && Objects.equals(featureProcessors, that.featureProcessors);
+            && Objects.equals(featureProcessors, that.featureProcessors)
+            && Objects.equals(alpha, that.alpha)
+            && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
+            && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
+            && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
+            && Objects.equals(downsampleFactor, that.downsampleFactor)
+            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
     }
 
     @Override
@@ -291,6 +374,12 @@ public class Classification implements DataFrameAnalysis {
         private Long randomizeSeed;
         private ClassAssignmentObjective classAssignmentObjective;
         private List<PreProcessor> featureProcessors;
+        private Double alpha;
+        private Double etaGrowthRatePerTree;
+        private Double softTreeDepthLimit;
+        private Double softTreeDepthTolerance;
+        private Double downsampleFactor;
+        private Integer maxOptimizationRoundsPerHyperparameter;
 
         private Builder(String dependentVariable) {
             this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -356,10 +445,41 @@ public class Classification implements DataFrameAnalysis {
             return this;
         }
 
+        public Builder setAlpha(Double alpha) {
+            this.alpha = alpha;
+            return this;
+        }
+
+        public Builder setEtaGrowthRatePerTree(Double etaGrowthRatePerTree) {
+            this.etaGrowthRatePerTree = etaGrowthRatePerTree;
+            return this;
+        }
+
+        public Builder setSoftTreeDepthLimit(Double softTreeDepthLimit) {
+            this.softTreeDepthLimit = softTreeDepthLimit;
+            return this;
+        }
+
+        public Builder setSoftTreeDepthTolerance(Double softTreeDepthTolerance) {
+            this.softTreeDepthTolerance = softTreeDepthTolerance;
+            return this;
+        }
+
+        public Builder setDownsampleFactor(Double downsampleFactor) {
+            this.downsampleFactor = downsampleFactor;
+            return this;
+        }
+
+        public Builder setMaxOptimizationRoundsPerHyperparameter(Integer maxOptimizationRoundsPerHyperparameter) {
+            this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
+            return this;
+        }
+
         public Classification build() {
             return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
                 numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
-                classAssignmentObjective, featureProcessors);
+                classAssignmentObjective, featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance,
+                downsampleFactor, maxOptimizationRoundsPerHyperparameter);
         }
     }
 }

+ 123 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java

@@ -59,6 +59,12 @@ public class Regression implements DataFrameAnalysis {
     static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
     static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
     static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
+    static final ParseField ALPHA = new ParseField("alpha");
+    static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
+    static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
+    static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
+    static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
+    static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField("max_optimization_rounds_per_hyperparameter");
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<Regression, Void> PARSER =
@@ -78,7 +84,13 @@ public class Regression implements DataFrameAnalysis {
                 (Long) a[9],
                 (LossFunction) a[10],
                 (Double) a[11],
-                (List<PreProcessor>) a[12]
+                (List<PreProcessor>) a[12],
+                (Double) a[13],
+                (Double) a[14],
+                (Double) a[15],
+                (Double) a[16],
+                (Double) a[17],
+                (Integer) a[18]
             ));
 
     static {
@@ -98,6 +110,12 @@ public class Regression implements DataFrameAnalysis {
             (p, c, n) -> p.namedObject(PreProcessor.class, n, c),
             (regression) -> {},
             FEATURE_PROCESSORS);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ALPHA);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_LIMIT);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
+        PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), DOWNSAMPLE_FACTOR);
+        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
     }
 
     private final String dependentVariable;
@@ -113,12 +131,20 @@ public class Regression implements DataFrameAnalysis {
     private final LossFunction lossFunction;
     private final Double lossFunctionParameter;
     private final List<PreProcessor> featureProcessors;
+    private final Double alpha;
+    private final Double etaGrowthRatePerTree;
+    private final Double softTreeDepthLimit;
+    private final Double softTreeDepthTolerance;
+    private final Double downsampleFactor;
+    private final Integer maxOptimizationRoundsPerHyperparameter;
 
     private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
                        @Nullable Integer maxTrees, @Nullable Double featureBagFraction,
                        @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
                        @Nullable Double trainingPercent, @Nullable Long randomizeSeed, @Nullable LossFunction lossFunction,
-                       @Nullable Double lossFunctionParameter, @Nullable List<PreProcessor> featureProcessors) {
+                       @Nullable Double lossFunctionParameter, @Nullable List<PreProcessor> featureProcessors, @Nullable Double alpha,
+                       @Nullable Double etaGrowthRatePerTree, @Nullable Double softTreeDepthLimit, @Nullable Double softTreeDepthTolerance,
+                       @Nullable Double downsampleFactor, @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
         this.dependentVariable = Objects.requireNonNull(dependentVariable);
         this.lambda = lambda;
         this.gamma = gamma;
@@ -132,6 +158,12 @@ public class Regression implements DataFrameAnalysis {
         this.lossFunction = lossFunction;
         this.lossFunctionParameter = lossFunctionParameter;
         this.featureProcessors = featureProcessors;
+        this.alpha = alpha;
+        this.etaGrowthRatePerTree = etaGrowthRatePerTree;
+        this.softTreeDepthLimit = softTreeDepthLimit;
+        this.softTreeDepthTolerance = softTreeDepthTolerance;
+        this.downsampleFactor = downsampleFactor;
+        this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
     }
 
     @Override
@@ -191,6 +223,30 @@ public class Regression implements DataFrameAnalysis {
         return featureProcessors;
     }
 
+    public Double getAlpha() {
+        return alpha;
+    }
+
+    public Double getEtaGrowthRatePerTree() {
+        return etaGrowthRatePerTree;
+    }
+
+    public Double getSoftTreeDepthLimit() {
+        return softTreeDepthLimit;
+    }
+
+    public Double getSoftTreeDepthTolerance() {
+        return softTreeDepthTolerance;
+    }
+
+    public Double getDownsampleFactor() {
+        return downsampleFactor;
+    }
+
+    public Integer getMaxOptimizationRoundsPerHyperparameter() {
+        return maxOptimizationRoundsPerHyperparameter;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -231,6 +287,24 @@ public class Regression implements DataFrameAnalysis {
         if (featureProcessors != null) {
             NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
         }
+        if (alpha != null) {
+            builder.field(ALPHA.getPreferredName(), alpha);
+        }
+        if (etaGrowthRatePerTree != null) {
+            builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
+        }
+        if (softTreeDepthLimit != null) {
+            builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
+        }
+        if (softTreeDepthTolerance != null) {
+            builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
+        }
+        if (downsampleFactor != null) {
+            builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
+        }
+        if (maxOptimizationRoundsPerHyperparameter != null) {
+            builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
+        }
         builder.endObject();
         return builder;
     }
@@ -238,7 +312,8 @@ public class Regression implements DataFrameAnalysis {
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
-            predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter, featureProcessors);
+            predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter, featureProcessors, alpha,
+            etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
     }
 
     @Override
@@ -258,7 +333,13 @@ public class Regression implements DataFrameAnalysis {
             && Objects.equals(randomizeSeed, that.randomizeSeed)
             && Objects.equals(lossFunction, that.lossFunction)
             && Objects.equals(lossFunctionParameter, that.lossFunctionParameter)
-            && Objects.equals(featureProcessors, that.featureProcessors);
+            && Objects.equals(featureProcessors, that.featureProcessors)
+            && Objects.equals(alpha, that.alpha)
+            && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
+            && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
+            && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
+            && Objects.equals(downsampleFactor, that.downsampleFactor)
+            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
     }
 
     @Override
@@ -280,6 +361,12 @@ public class Regression implements DataFrameAnalysis {
         private LossFunction lossFunction;
         private Double lossFunctionParameter;
         private List<PreProcessor> featureProcessors;
+        private Double alpha;
+        private Double etaGrowthRatePerTree;
+        private Double softTreeDepthLimit;
+        private Double softTreeDepthTolerance;
+        private Double downsampleFactor;
+        private Integer maxOptimizationRoundsPerHyperparameter;
 
         private Builder(String dependentVariable) {
             this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -345,10 +432,41 @@ public class Regression implements DataFrameAnalysis {
             return this;
         }
 
+        public Builder setAlpha(Double alpha) {
+            this.alpha = alpha;
+            return this;
+        }
+
+        public Builder setEtaGrowthRatePerTree(Double etaGrowthRatePerTree) {
+            this.etaGrowthRatePerTree = etaGrowthRatePerTree;
+            return this;
+        }
+
+        public Builder setSoftTreeDepthLimit(Double softTreeDepthLimit) {
+            this.softTreeDepthLimit = softTreeDepthLimit;
+            return this;
+        }
+
+        public Builder setSoftTreeDepthTolerance(Double softTreeDepthTolerance) {
+            this.softTreeDepthTolerance = softTreeDepthTolerance;
+            return this;
+        }
+
+        public Builder setDownsampleFactor(Double downsampleFactor) {
+            this.downsampleFactor = downsampleFactor;
+            return this;
+        }
+
+        public Builder setMaxOptimizationRoundsPerHyperparameter(Integer maxOptimizationRoundsPerHyperparameter) {
+            this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
+            return this;
+        }
+
         public Regression build() {
             return new Regression(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
                 numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, lossFunction, lossFunctionParameter,
-                featureProcessors);
+                featureProcessors, alpha, etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor,
+                maxOptimizationRoundsPerHyperparameter);
         }
     }
 

+ 12 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -1360,6 +1360,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 .setNumTopFeatureImportanceValues(3)
                 .setLossFunction(org.elasticsearch.client.ml.dataframe.Regression.LossFunction.MSLE)
                 .setLossFunctionParameter(1.0)
+                .setAlpha(0.5)
+                .setEtaGrowthRatePerTree(1.0)
+                .setSoftTreeDepthLimit(1.0)
+                .setSoftTreeDepthTolerance(0.1)
+                .setDownsampleFactor(0.5)
+                .setMaxOptimizationRoundsPerHyperparameter(3)
                 .build())
             .setDescription("this is a regression")
             .build();
@@ -1405,6 +1411,12 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
                 .setMaxTrees(10)
                 .setFeatureBagFraction(0.5)
                 .setNumTopFeatureImportanceValues(3)
+                .setAlpha(0.5)
+                .setEtaGrowthRatePerTree(1.0)
+                .setSoftTreeDepthLimit(1.0)
+                .setSoftTreeDepthTolerance(0.1)
+                .setDownsampleFactor(0.5)
+                .setMaxOptimizationRoundsPerHyperparameter(3)
                 .build())
             .setDescription("this is a classification")
             .build();

+ 12 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -3053,6 +3053,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .setFeatureProcessors(Arrays.asList(OneHotEncoding.builder("categorical_feature") // <13>
                     .addOneHot("cat", "cat_column")
                     .build()))
+                .setAlpha(1.0) // <14>
+                .setEtaGrowthRatePerTree(1.0) // <15>
+                .setSoftTreeDepthLimit(1.0) // <16>
+                .setSoftTreeDepthTolerance(1.0) // <17>
+                .setDownsampleFactor(0.5) // <18>
+                .setMaxOptimizationRoundsPerHyperparameter(3) // <19>
                 .build();
             // end::put-data-frame-analytics-classification
 
@@ -3072,6 +3078,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .setFeatureProcessors(Arrays.asList(OneHotEncoding.builder("categorical_feature") // <13>
                     .addOneHot("cat", "cat_column")
                     .build()))
+                .setAlpha(1.0) // <14>
+                .setEtaGrowthRatePerTree(1.0) // <15>
+                .setSoftTreeDepthLimit(1.0) // <16>
+                .setSoftTreeDepthTolerance(1.0) // <17>
+                .setDownsampleFactor(0.5) // <18>
+                .setMaxOptimizationRoundsPerHyperparameter(3) // <19>
                 .build();
             // end::put-data-frame-analytics-regression
 

+ 6 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

@@ -54,6 +54,12 @@ public class ClassificationTests extends AbstractXContentTestCase<Classification
                     TargetMeanEncodingTests.createRandom()))
                     .limit(randomIntBetween(1, 10))
                     .collect(Collectors.toList()))
+            .setAlpha(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setEtaGrowthRatePerTree(randomBoolean() ? null : randomDoubleBetween(0.5, 2.0, true))
+            .setSoftTreeDepthLimit(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
+            .setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
+            .setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
             .build();
     }
 

+ 6 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java

@@ -53,6 +53,12 @@ public class RegressionTests extends AbstractXContentTestCase<Regression> {
                     TargetMeanEncodingTests.createRandom()))
                     .limit(randomIntBetween(1, 10))
                     .collect(Collectors.toList()))
+            .setAlpha(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setEtaGrowthRatePerTree(randomBoolean() ? null : randomDoubleBetween(0.5, 2.0, true))
+            .setSoftTreeDepthLimit(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
+            .setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
+            .setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
             .build();
     }
 

+ 12 - 0
docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

@@ -128,6 +128,12 @@ include-tagged::{doc-tests-file}[{api}-classification]
 <12> The number of top classes (or -1 which denotes all classes) to be reported in the results. Defaults to 2.
 <13> Custom feature processors that will create new features for analysis from the included document
      fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
+<14> The alpha regularization parameter. A non-negative double.
+<15> The growth rate of the shrinkage parameter. A double in [0.5, 2.0].
+<16> The soft tree depth limit. A non-negative double.
+<17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
+<18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
+<19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
 
 ===== Regression
 
@@ -152,6 +158,12 @@ include-tagged::{doc-tests-file}[{api}-regression]
 <12> An optional parameter to the loss function.
 <13> Custom feature processors that will create new features for analysis from the included document
 fields. Note, automatic categorical {ml-docs}/ml-feature-encoding.html[feature encoding] still occurs for all features.
+<14> The alpha regularization parameter. A non-negative double.
+<15> The growth rate of the shrinkage parameter. A double in [0.5, 2.0].
+<16> The soft tree depth limit. A non-negative double.
+<17> The soft tree depth tolerance. Controls how much the soft tree depth limit is respected. A double greater than or equal to 0.01.
+<18> The amount by which to downsample the data for stochastic gradient estimates. A double in (0, 1.0].
+<19> The maximum number of optimisation rounds we use for hyperparameter optimisation per parameter. An integer in [0, 20].
 
 ==== Analyzed fields
 

+ 48 - 0
docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc

@@ -96,6 +96,10 @@ understand the function of these parameters.
 .Properties of `classification`
 [%collapsible%open]
 =====
+`alpha`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-alpha]
+
 `class_assignment_objective`::::
 (Optional, string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=class-assignment-objective]
@@ -109,10 +113,18 @@ The data type of the field must be numeric (`integer`, `short`, `long`, `byte`),
 categorical (`ip` or `keyword`), or boolean. There must be no more than 30
 different values in this field.
 
+`downsample_factor`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]
+
 `eta`::::
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
 
+`eta_growth_rate_per_tree`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-eta-growth]
+
 `feature_bag_fraction`::::
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
@@ -234,6 +246,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=gamma]
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=lambda]
 
+`max_optimization_rounds_per_hyperparameter`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-max-optimization-rounds]
+
 `max_trees`::::
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees]
@@ -267,6 +283,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=prediction-field-name]
 (Optional, long)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=randomize-seed]
 
+`soft_tree_depth_limit`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-soft-limit]
+
+`soft_tree_depth_tolerance`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-soft-tolerance]
+
 `training_percent`::::
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=training-percent]
@@ -320,6 +344,10 @@ understand the function of these parameters.
 .Properties of `regression`
 [%collapsible%open]
 =====
+`alpha`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-alpha]
+
 `dependent_variable`::::
 (Required, string)
 +
@@ -327,10 +355,18 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dependent-variable]
 +
 The data type of the field must be numeric.
 
+`downsample_factor`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-downsample-factor]
+
 `eta`::::
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=eta]
 
+`eta_growth_rate_per_tree`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-eta-growth]
+
 `feature_bag_fraction`::::
 (Optional, double)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=feature-bag-fraction]
@@ -359,6 +395,10 @@ to learn more.
 (Optional, double)
 A positive number that is used as a parameter to the `loss_function`.
 
+`max_optimization_rounds_per_hyperparameter`::::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-max-optimization-rounds]
+
 `max_trees`::::
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=max-trees]
@@ -377,6 +417,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=prediction-field-name]
 (Optional, long)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=randomize-seed]
 
+`soft_tree_depth_limit`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-soft-limit]
+
+`soft_tree_depth_tolerance`::::
+(Optional, double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=dfas-soft-tolerance]
+
 `training_percent`::::
 (Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=training-percent]

+ 34 - 9
docs/reference/ml/ml-shared.asciidoc

@@ -537,16 +537,32 @@ the detectors in the `analysis_config`, starting at zero.
 end::detector-index[]
 
 tag::dfas-alpha[]
-Regularization factor to penalize deeper trees when training decision trees.
+Advanced configuration option. {ml} uses loss-guided tree growing.
+This means that trees will grow where the regularized loss reduces
+the most. This parameter multiplies a term based on tree depth in
+the regularized loss. Higher values result in shallower trees
+and faster training times. Values must be greater than or equal
+to zero. By default, this value is calculated during hyperparameter optimization.
 end::dfas-alpha[]
 
 tag::dfas-downsample-factor[]
-The value of the downsample factor.
+Advanced configuration option. This controls the fraction of data
+that is used to compute the derivatives of the loss function for tree training.
+The lower the value the smaller the fraction of data that is used.
+Typically accuracy improves if this is set to be less than 1. However, too small
+a value may result in poor convergence for the ensemble and so require more trees.
+For more information about stochastic gradient boosting, refer to
+{wikipedia}/Gradient_boosting#Stochastic_gradient_boosting[this wiki article].
+Values must be greater than zero and less than or equal to 1.
+By default, this value is calculated during hyperparameter optimization.
 end::dfas-downsample-factor[]
 
 tag::dfas-eta-growth[]
-Specifies the rate at which the `eta` increases for each new tree that is added
-to the forest. For example, a rate of `1.05` increases `eta` by 5%.
+Advanced configuration option.
+Specifies the rate at which `eta` increases for each new tree that is added
+to the forest. For example, a rate of `1.05` increases `eta` by 5% for each
+extra tree. Values must be in the range of 0.5 to 2.
+By default, this value is calculated during hyperparameter optimization.
 end::dfas-eta-growth[]
 
 tag::dfas-feature-bag-fraction[]
@@ -653,10 +669,12 @@ training stops.
 end::dfas-max-attempts[]
 
 tag::dfas-max-optimization-rounds[]
+Advanced configuration option.
 A multiplier responsible for determining the maximum number of
 hyperparameter optimization steps in the Bayesian optimization procedure.
 The maximum number of steps is determined based on the number of undefined
 hyperparameters times the maximum optimization rounds per hyperparameter.
+By default, this value is calculated during hyperparameter optimization.
 end::dfas-max-optimization-rounds[]
 
 tag::dfas-num-folds[]
@@ -669,13 +687,20 @@ decision tree when the tree is trained.
 end::dfas-num-splits[]
 
 tag::dfas-soft-limit[]
-Tree depth limit is used for calculating the tree depth penalty. This is a soft
-limit, it can be exceeded.
+Advanced configuration option. {ml} uses loss-guided tree growing.
+This means that trees will grow where the regularized loss reduces
+the most. The regularized loss increases quickly where the tree depth
+exceeds this parameter. This is a soft limit; it can be exceeded.
+Values must be greater than or equal to 0.
+By default, this value is calculated during hyperparameter optimization.
 end::dfas-soft-limit[]
 
 tag::dfas-soft-tolerance[]
-Tree depth tolerance is used for calculating the tree depth penalty. This is a
-soft limit, it can be exceeded.
+Advanced configuration option.
+This controls how quickly the regularized loss increases when the tree
+depth exceeds `soft_tree_depth_limit`.
+Values must be greater than or equal to 0.01.
+By default, this value is calculated during hyperparameter optimization.
 end::dfas-soft-tolerance[]
 
 tag::dfas-timestamp[]
@@ -722,7 +747,7 @@ tag::eta[]
 Advanced configuration option. The shrinkage applied to the weights. Smaller
 values result in larger forests which have a better generalization error.
 However, the smaller the value the longer the training will take. For more
-information about shrinkage, see
+information about shrinkage, refer to
 {wikipedia}/Gradient_boosting#Shrinkage[this wiki article].
 By default, this value is calculated during hyperparameter optimization.
 end::eta[]

+ 189 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java

@@ -36,6 +36,13 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
     public static final ParseField MAX_TREES = new ParseField("max_trees", "maximum_number_trees");
     public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
     public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
+    public static final ParseField ALPHA = new ParseField("alpha");
+    public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
+    public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
+    public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
+    public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
+    public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER =
+        new ParseField("max_optimization_rounds_per_hyperparameter");
 
     static void declareFields(AbstractObjectParser<?, Void> parser) {
         parser.declareDouble(optionalConstructorArg(), LAMBDA);
@@ -44,6 +51,12 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         parser.declareInt(optionalConstructorArg(), MAX_TREES);
         parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
         parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
+        parser.declareDouble(optionalConstructorArg(), ALPHA);
+        parser.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
+        parser.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_LIMIT);
+        parser.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
+        parser.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
+        parser.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
     }
 
     private final Double lambda;
@@ -52,13 +65,25 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
     private final Integer maxTrees;
     private final Double featureBagFraction;
     private final Integer numTopFeatureImportanceValues;
+    private final Double alpha;
+    private final Double etaGrowthRatePerTree;
+    private final Double softTreeDepthLimit;
+    private final Double softTreeDepthTolerance;
+    private final Double downsampleFactor;
+    private final Integer maxOptimizationRoundsPerHyperparameter;
 
     public BoostedTreeParams(@Nullable Double lambda,
                              @Nullable Double gamma,
                              @Nullable Double eta,
                              @Nullable Integer maxTrees,
                              @Nullable Double featureBagFraction,
-                             @Nullable Integer numTopFeatureImportanceValues) {
+                             @Nullable Integer numTopFeatureImportanceValues,
+                             @Nullable Double alpha,
+                             @Nullable Double etaGrowthRatePerTree,
+                             @Nullable Double softTreeDepthLimit,
+                             @Nullable Double softTreeDepthTolerance,
+                             @Nullable Double downsampleFactor,
+                             @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
         if (lambda != null && lambda < 0) {
             throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
         }
@@ -78,12 +103,39 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
             throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer",
                 NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
         }
+        if (alpha != null && alpha < 0) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", ALPHA.getPreferredName());
+        }
+        if (etaGrowthRatePerTree != null && (etaGrowthRatePerTree < 0.5 || etaGrowthRatePerTree > 2.0)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.5, 2.0]", ETA_GROWTH_RATE_PER_TREE.getPreferredName());
+        }
+        if (softTreeDepthLimit != null && softTreeDepthLimit < 0) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", SOFT_TREE_DEPTH_LIMIT.getPreferredName());
+        }
+        if (softTreeDepthTolerance != null && softTreeDepthTolerance < 0.01) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a double greater than or equal to 0.01",
+                SOFT_TREE_DEPTH_TOLERANCE.getPreferredName());
+        }
+        if (downsampleFactor != null && (downsampleFactor <= 0 || downsampleFactor > 1.0)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", DOWNSAMPLE_FACTOR.getPreferredName());
+        }
+        if (maxOptimizationRoundsPerHyperparameter != null
+                && (maxOptimizationRoundsPerHyperparameter < 0 || maxOptimizationRoundsPerHyperparameter > 20)) {
+            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 20]",
+                MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName());
+        }
         this.lambda = lambda;
         this.gamma = gamma;
         this.eta = eta;
         this.maxTrees = maxTrees;
         this.featureBagFraction = featureBagFraction;
         this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
+        this.alpha = alpha;
+        this.etaGrowthRatePerTree = etaGrowthRatePerTree;
+        this.softTreeDepthLimit = softTreeDepthLimit;
+        this.softTreeDepthTolerance = softTreeDepthTolerance;
+        this.downsampleFactor = downsampleFactor;
+        this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
     }
 
     BoostedTreeParams(StreamInput in) throws IOException {
@@ -97,6 +149,21 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         } else {
             numTopFeatureImportanceValues = null;
         }
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            alpha = in.readOptionalDouble();
+            etaGrowthRatePerTree = in.readOptionalDouble();
+            softTreeDepthLimit = in.readOptionalDouble();
+            softTreeDepthTolerance = in.readOptionalDouble();
+            downsampleFactor = in.readOptionalDouble();
+            maxOptimizationRoundsPerHyperparameter = in.readOptionalVInt();
+        } else {
+            alpha = null;
+            etaGrowthRatePerTree = null;
+            softTreeDepthLimit = null;
+            softTreeDepthTolerance = null;
+            downsampleFactor = null;
+            maxOptimizationRoundsPerHyperparameter = null;
+        }
     }
 
     public Double getLambda() {
@@ -123,6 +190,30 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         return numTopFeatureImportanceValues;
     }
 
+    public Double getAlpha() {
+        return alpha;
+    }
+
+    public Double getEtaGrowthRatePerTree() {
+        return etaGrowthRatePerTree;
+    }
+
+    public Double getSoftTreeDepthLimit() {
+        return softTreeDepthLimit;
+    }
+
+    public Double getSoftTreeDepthTolerance() {
+        return softTreeDepthTolerance;
+    }
+
+    public Double getDownsampleFactor() {
+        return downsampleFactor;
+    }
+
+    public Integer getMaxOptimizationRoundsPerHyperparameter() {
+        return maxOptimizationRoundsPerHyperparameter;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeOptionalDouble(lambda);
@@ -133,10 +224,21 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
             out.writeOptionalInt(numTopFeatureImportanceValues);
         }
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeOptionalDouble(alpha);
+            out.writeOptionalDouble(etaGrowthRatePerTree);
+            out.writeOptionalDouble(softTreeDepthLimit);
+            out.writeOptionalDouble(softTreeDepthTolerance);
+            out.writeOptionalDouble(downsampleFactor);
+            out.writeOptionalVInt(maxOptimizationRoundsPerHyperparameter);
+        }
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        if (alpha != null) {
+            builder.field(ALPHA.getPreferredName(), alpha);
+        }
         if (lambda != null) {
             builder.field(LAMBDA.getPreferredName(), lambda);
         }
@@ -146,6 +248,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         if (eta != null) {
             builder.field(ETA.getPreferredName(), eta);
         }
+        if (etaGrowthRatePerTree != null) {
+            builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
+        }
         if (maxTrees != null) {
             builder.field(MAX_TREES.getPreferredName(), maxTrees);
         }
@@ -155,6 +260,18 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         if (numTopFeatureImportanceValues != null) {
             builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
         }
+        if (softTreeDepthLimit != null) {
+            builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
+        }
+        if (softTreeDepthTolerance != null) {
+            builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
+        }
+        if (downsampleFactor != null) {
+            builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
+        }
+        if (maxOptimizationRoundsPerHyperparameter != null) {
+            builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
+        }
         return builder;
     }
 
@@ -178,6 +295,24 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         if (numTopFeatureImportanceValues != null) {
             params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
         }
+        if (alpha != null) {
+            params.put(ALPHA.getPreferredName(), alpha);
+        }
+        if (etaGrowthRatePerTree != null) {
+            params.put(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
+        }
+        if (softTreeDepthLimit != null) {
+            params.put(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
+        }
+        if (softTreeDepthTolerance != null) {
+            params.put(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
+        }
+        if (downsampleFactor != null) {
+            params.put(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
+        }
+        if (maxOptimizationRoundsPerHyperparameter != null) {
+            params.put(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
+        }
         return params;
     }
 
@@ -191,12 +326,19 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
             && Objects.equals(eta, that.eta)
             && Objects.equals(maxTrees, that.maxTrees)
             && Objects.equals(featureBagFraction, that.featureBagFraction)
-            && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
+            && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
+            && Objects.equals(alpha, that.alpha)
+            && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
+            && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
+            && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
+            && Objects.equals(downsampleFactor, that.downsampleFactor)
+            && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues);
+        return Objects.hash(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues, alpha, etaGrowthRatePerTree,
+            softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
     }
 
     public static Builder builder() {
@@ -211,6 +353,12 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
         private Integer maxTrees;
         private Double featureBagFraction;
         private Integer numTopFeatureImportanceValues;
+        private Double alpha;
+        private Double etaGrowthRatePerTree;
+        private Double softTreeDepthLimit;
+        private Double softTreeDepthTolerance;
+        private Double downsampleFactor;
+        private Integer maxOptimizationRoundsPerHyperparameter;
 
         private Builder() {}
 
@@ -221,6 +369,12 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
             this.maxTrees = params.maxTrees;
             this.featureBagFraction = params.featureBagFraction;
             this.numTopFeatureImportanceValues = params.numTopFeatureImportanceValues;
+            this.alpha = params.alpha;
+            this.etaGrowthRatePerTree = params.etaGrowthRatePerTree;
+            this.softTreeDepthLimit = params.softTreeDepthLimit;
+            this.softTreeDepthTolerance = params.softTreeDepthTolerance;
+            this.downsampleFactor = params.downsampleFactor;
+            this.maxOptimizationRoundsPerHyperparameter = params.maxOptimizationRoundsPerHyperparameter;
         }
 
         public Builder setLambda(Double lambda) {
@@ -253,8 +407,39 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
             return this;
         }
 
+        public Builder setAlpha(Double alpha) {
+            this.alpha = alpha;
+            return this;
+        }
+
+        public Builder setEtaGrowthRatePerTree(Double etaGrowthRatePerTree) {
+            this.etaGrowthRatePerTree = etaGrowthRatePerTree;
+            return this;
+        }
+
+        public Builder setSoftTreeDepthLimit(Double softTreeDepthLimit) {
+            this.softTreeDepthLimit = softTreeDepthLimit;
+            return this;
+        }
+
+        public Builder setSoftTreeDepthTolerance(Double softTreeDepthTolerance) {
+            this.softTreeDepthTolerance = softTreeDepthTolerance;
+            return this;
+        }
+
+        public Builder setDownsampleFactor(Double downsampleFactor) {
+            this.downsampleFactor = downsampleFactor;
+            return this;
+        }
+
+        public Builder setMaxOptimizationRoundsPerHyperparameter(Integer maxOptimizationRoundsPerHyperparameter) {
+            this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
+            return this;
+        }
+
         public BoostedTreeParams build() {
-            return new BoostedTreeParams(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues);
+            return new BoostedTreeParams(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues, alpha,
+                etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
         }
     }
 }

+ 8 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -75,13 +75,14 @@ public class Classification implements DataFrameAnalysis {
             lenient,
             a -> new Classification(
                 (String) a[0],
-                new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
-                (String) a[7],
-                (ClassAssignmentObjective) a[8],
-                (Integer) a[9],
-                (Double) a[10],
-                (Long) a[11],
-                (List<PreProcessor>) a[12]));
+                new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6],
+                    (Double) a[7], (Double) a[8], (Double) a[9], (Double) a[10], (Double) a[11], (Integer) a[12]),
+                (String) a[13],
+                (ClassAssignmentObjective) a[14],
+                (Integer) a[15],
+                (Double) a[16],
+                (Long) a[17],
+                (List<PreProcessor>) a[18]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);

+ 8 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -65,13 +65,14 @@ public class Regression implements DataFrameAnalysis {
             lenient,
             a -> new Regression(
                 (String) a[0],
-                new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
-                (String) a[7],
-                (Double) a[8],
-                (Long) a[9],
-                (LossFunction) a[10],
-                (Double) a[11],
-                (List<PreProcessor>) a[12]));
+                new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6],
+                    (Double) a[7], (Double) a[8], (Double) a[9], (Double) a[10], (Double) a[11], (Integer) a[12]),
+                (String) a[13],
+                (Double) a[14],
+                (Long) a[15],
+                (LossFunction) a[16],
+                (Double) a[17],
+                (List<PreProcessor>) a[18]));
         parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
         BoostedTreeParams.declareFields(parser);
         parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);

+ 158 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java

@@ -13,8 +13,12 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
 
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 
+import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 
 public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase<BoostedTreeParams> {
 
@@ -24,7 +28,8 @@ public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase<Boo
             new ConstructingObjectParser<>(
                 BoostedTreeParams.NAME,
                 true,
-                a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5]));
+                a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5],
+                    (Double) a[6], (Double) a[7], (Double) a[8], (Double) a[9], (Double) a[10], (Integer) a[11]));
         BoostedTreeParams.declareFields(objParser);
         return objParser.apply(parser, null);
     }
@@ -42,6 +47,12 @@ public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase<Boo
             .setMaxTrees(randomBoolean() ? null : randomIntBetween(1, 2000))
             .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
             .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE))
+            .setAlpha(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setEtaGrowthRatePerTree(randomBoolean() ? null : randomDoubleBetween(0.5, 2.0, true))
+            .setSoftTreeDepthLimit(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setSoftTreeDepthTolerance(randomBoolean() ? null : randomDoubleBetween(0.01, Double.MAX_VALUE, true))
+            .setDownsampleFactor(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
+            .setMaxOptimizationRoundsPerHyperparameter(randomBoolean() ? null : randomIntBetween(0, 20))
             .build();
     }
 
@@ -50,6 +61,14 @@ public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase<Boo
         if (version.before(Version.V_7_6_0)) {
             builder.setNumTopFeatureImportanceValues(null);
         }
+        if (version.before(Version.V_8_0_0)) {
+            builder.setAlpha(null);
+            builder.setEtaGrowthRatePerTree(null);
+            builder.setSoftTreeDepthLimit(null);
+            builder.setSoftTreeDepthTolerance(null);
+            builder.setDownsampleFactor(null);
+            builder.setMaxOptimizationRoundsPerHyperparameter(null);
+        }
         return builder.build();
     }
 
@@ -121,6 +140,144 @@ public class BoostedTreeParamsTests extends AbstractBWCSerializationTestCase<Boo
         assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer"));
     }
 
+    public void testConstructor_GivenAlphaIsNegative() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setAlpha(-0.001).build());
+
+        assertThat(e.getMessage(), equalTo("[alpha] must be a non-negative double"));
+    }
+
+    public void testConstructor_GivenAlphaIsZero() {
+        assertThat(BoostedTreeParams.builder().setAlpha(0.0).build().getAlpha(), equalTo(0.0));
+    }
+
+    public void testConstructor_GivenEtaGrowthRatePerTreeIsOnRangeLimit() {
+        assertThat(BoostedTreeParams.builder().setEtaGrowthRatePerTree(0.5).build().getEtaGrowthRatePerTree(), equalTo(0.5));
+        assertThat(BoostedTreeParams.builder().setEtaGrowthRatePerTree(2.0).build().getEtaGrowthRatePerTree(), equalTo(2.0));
+    }
+
+    public void testConstructor_GivenEtaGrowthRatePerTreeIsLessThanMin() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setEtaGrowthRatePerTree(0.49999).build());
+
+        assertThat(e.getMessage(), equalTo("[eta_growth_rate_per_tree] must be a double in [0.5, 2.0]"));
+    }
+
+    public void testConstructor_GivenEtaGrowthRatePerTreeIsGreaterThanMax() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setEtaGrowthRatePerTree(2.00001).build());
+
+        assertThat(e.getMessage(), equalTo("[eta_growth_rate_per_tree] must be a double in [0.5, 2.0]"));
+    }
+
+    public void testConstructor_GivenSoftTreeDepthLimitIsNegative() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setSoftTreeDepthLimit(-0.001).build());
+
+        assertThat(e.getMessage(), equalTo("[soft_tree_depth_limit] must be a non-negative double"));
+    }
+
+    public void testConstructor_GivenSoftTreeDepthLimitIsZero() {
+        assertThat(BoostedTreeParams.builder().setSoftTreeDepthLimit(0.0).build().getSoftTreeDepthLimit(), equalTo(0.0));
+    }
+
+    public void testConstructor_GivenSoftTreeDepthToleranceIsLessThanMin() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setSoftTreeDepthTolerance(0.001).build());
+
+        assertThat(e.getMessage(), equalTo("[soft_tree_depth_tolerance] must be a double greater than or equal to 0.01"));
+    }
+
+    public void testConstructor_GivenSoftTreeDepthToleranceIsMin() {
+        assertThat(BoostedTreeParams.builder().setSoftTreeDepthTolerance(0.01).build().getSoftTreeDepthTolerance(), equalTo(0.01));
+    }
+
+    public void testConstructor_GivenDownsampleFactorIsZero() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setDownsampleFactor(0.0).build());
+
+        assertThat(e.getMessage(), equalTo("[downsample_factor] must be a double in (0, 1]"));
+    }
+
+    public void testConstructor_GivenDownsampleFactorIsNegative() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setDownsampleFactor(-42.0).build());
+
+        assertThat(e.getMessage(), equalTo("[downsample_factor] must be a double in (0, 1]"));
+    }
+
+    public void testConstructor_GivenDownsampleFactorIsOne() {
+        assertThat(BoostedTreeParams.builder().setDownsampleFactor(1.0).build().getDownsampleFactor(), equalTo(1.0));
+    }
+
+    public void testConstructor_GivenDownsampleFactorIsGreaterThanOne() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setDownsampleFactor(1.00001).build());
+
+        assertThat(e.getMessage(), equalTo("[downsample_factor] must be a double in (0, 1]"));
+    }
+
+    public void testConstructor_GivenMaxOptimizationRoundsPerHyperparameterIsZero() {
+        assertThat(BoostedTreeParams.builder().setMaxOptimizationRoundsPerHyperparameter(0).build()
+            .getMaxOptimizationRoundsPerHyperparameter(), equalTo(0));
+    }
+
+    public void testConstructor_GivenMaxOptimizationRoundsPerHyperparameterIsNegative() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setMaxOptimizationRoundsPerHyperparameter(-1).build());
+
+        assertThat(e.getMessage(), equalTo("[max_optimization_rounds_per_hyperparameter] must be an integer in [0, 20]"));
+    }
+
+    public void testConstructor_GivenMaxOptimizationRoundsPerHyperparameterIsMax() {
+        assertThat(BoostedTreeParams.builder().setMaxOptimizationRoundsPerHyperparameter(20).build()
+            .getMaxOptimizationRoundsPerHyperparameter(), equalTo(20));
+    }
+
+    public void testConstructor_GivenMaxOptimizationRoundsPerHyperparameterIsGreaterThanMax() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> BoostedTreeParams.builder().setMaxOptimizationRoundsPerHyperparameter(21).build());
+
+        assertThat(e.getMessage(), equalTo("[max_optimization_rounds_per_hyperparameter] must be an integer in [0, 20]"));
+    }
+
+    public void testGetParams_GivenEmpty() {
+        assertThat(BoostedTreeParams.builder().build().getParams(), is(anEmptyMap()));
+    }
+
+    public void testGetParams_GivenAllParams() {
+        BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder()
+            .setLambda(randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setGamma(randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setEta(randomDoubleBetween(0.001, 1.0, true))
+            .setMaxTrees(randomIntBetween(1, 2000))
+            .setFeatureBagFraction(randomDoubleBetween(0.0, 1.0, false))
+            .setNumTopFeatureImportanceValues(randomIntBetween(0, Integer.MAX_VALUE))
+            .setAlpha(randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setEtaGrowthRatePerTree(randomDoubleBetween(0.5, 2.0, true))
+            .setSoftTreeDepthLimit(randomDoubleBetween(0.0, Double.MAX_VALUE, true))
+            .setSoftTreeDepthTolerance(randomDoubleBetween(0.01, Double.MAX_VALUE, true))
+            .setDownsampleFactor(randomDoubleBetween(0.0, 1.0, false))
+            .setMaxOptimizationRoundsPerHyperparameter(randomIntBetween(0, 20))
+            .build();
+
+        Map<String, Object> expectedParams = new HashMap<>();
+        expectedParams.put("lambda", boostedTreeParams.getLambda());
+        expectedParams.put("gamma", boostedTreeParams.getGamma());
+        expectedParams.put("eta", boostedTreeParams.getEta());
+        expectedParams.put("max_trees", boostedTreeParams.getMaxTrees());
+        expectedParams.put("feature_bag_fraction", boostedTreeParams.getFeatureBagFraction());
+        expectedParams.put("num_top_feature_importance_values", boostedTreeParams.getNumTopFeatureImportanceValues());
+        expectedParams.put("alpha", boostedTreeParams.getAlpha());
+        expectedParams.put("eta_growth_rate_per_tree", boostedTreeParams.getEtaGrowthRatePerTree());
+        expectedParams.put("soft_tree_depth_limit", boostedTreeParams.getSoftTreeDepthLimit());
+        expectedParams.put("soft_tree_depth_tolerance", boostedTreeParams.getSoftTreeDepthTolerance());
+        expectedParams.put("downsample_factor", boostedTreeParams.getDownsampleFactor());
+        expectedParams.put("max_optimization_rounds_per_hyperparameter", boostedTreeParams.getMaxOptimizationRoundsPerHyperparameter());
+
+        assertThat(boostedTreeParams.getParams(), equalTo(expectedParams));
+    }
+
     @Override
     protected BoostedTreeParams mutateInstanceForVersion(BoostedTreeParams instance, Version version) {
         return mutateForVersion(instance, version);

+ 28 - 4
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

@@ -1507,7 +1507,13 @@ setup:
                 "training_percent": 60.3,
                 "randomize_seed": 42,
                 "loss_function": "msle",
-                "loss_function_parameter": 2.0
+                "loss_function_parameter": 2.0,
+                "alpha": 1.0,
+                "eta_growth_rate_per_tree": 1.0,
+                "soft_tree_depth_limit": 2.0,
+                "soft_tree_depth_tolerance": 3.0,
+                "downsample_factor": 0.5,
+                "max_optimization_rounds_per_hyperparameter": 3
               }
             }
           }
@@ -1526,7 +1532,13 @@ setup:
       "training_percent": 60.3,
       "randomize_seed": 42,
       "loss_function": "msle",
-      "loss_function_parameter": 2.0
+      "loss_function_parameter": 2.0,
+      "alpha": 1.0,
+      "eta_growth_rate_per_tree": 1.0,
+      "soft_tree_depth_limit": 2.0,
+      "soft_tree_depth_tolerance": 3.0,
+      "downsample_factor": 0.5,
+      "max_optimization_rounds_per_hyperparameter": 3
     }
   }}
   - is_true: create_time
@@ -1852,7 +1864,13 @@ setup:
                 "feature_bag_fraction": 0.3,
                 "class_assignment_objective": "maximize_accuracy",
                 "training_percent": 60.3,
-                "randomize_seed": 24
+                "randomize_seed": 24,
+                "alpha": 1.0,
+                "eta_growth_rate_per_tree": 1.0,
+                "soft_tree_depth_limit": 2.0,
+                "soft_tree_depth_tolerance": 3.0,
+                "downsample_factor": 0.5,
+                "max_optimization_rounds_per_hyperparameter": 3
               }
             }
           }
@@ -1871,7 +1889,13 @@ setup:
       "training_percent": 60.3,
       "randomize_seed": 24,
       "class_assignment_objective": "maximize_accuracy",
-      "num_top_classes": 2
+      "num_top_classes": 2,
+      "alpha": 1.0,
+      "eta_growth_rate_per_tree": 1.0,
+      "soft_tree_depth_limit": 2.0,
+      "soft_tree_depth_tolerance": 3.0,
+      "downsample_factor": 0.5,
+      "max_optimization_rounds_per_hyperparameter": 3
     }
   }}
   - is_true: create_time