Browse Source

[ML] Hyperparameter names should match config (#54401)

Java side of elastic/ml-cpp#1096
Dimitris Athanasiou 5 năm trước cách đây
mục cha
commit
c7da75a638

+ 70 - 72
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/classification/Hyperparameters.java

@@ -31,23 +31,21 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
 public class Hyperparameters implements ToXContentObject {
 
     public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
+    public static final ParseField ALPHA = new ParseField("alpha");
     public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
     public static final ParseField ETA = new ParseField("eta");
     public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
     public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
+    public static final ParseField GAMMA = new ParseField("gamma");
+    public static final ParseField LAMBDA = new ParseField("lambda");
     public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
     public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
         "max_optimization_rounds_per_hyperparameter");
     public static final ParseField MAX_TREES = new ParseField("max_trees");
     public static final ParseField NUM_FOLDS = new ParseField("num_folds");
     public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
-    public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
-    public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
-        = new ParseField("regularization_leaf_weight_penalty_multiplier");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
-    public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
-        new ParseField("regularization_tree_size_penalty_multiplier");
+    public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
+    public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
 
     public static ConstructingObjectParser<Hyperparameters, Void> PARSER = new ConstructingObjectParser<>("classification_hyperparameters",
         true,
@@ -57,88 +55,92 @@ public class Hyperparameters implements ToXContentObject {
             (Double) a[2],
             (Double) a[3],
             (Double) a[4],
-            (Integer) a[5],
-            (Integer) a[6],
-            (Integer) a[7],
+            (Double) a[5],
+            (Double) a[6],
+            (Double) a[7],
             (Integer) a[8],
             (Integer) a[9],
-            (Double) a[10],
-            (Double) a[11],
-            (Double) a[12],
+            (Integer) a[10],
+            (Integer) a[11],
+            (Integer) a[12],
             (Double) a[13],
             (Double) a[14]
         ));
 
     static {
         PARSER.declareString(optionalConstructorArg(), CLASS_ASSIGNMENT_OBJECTIVE);
+        PARSER.declareDouble(optionalConstructorArg(), ALPHA);
         PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
         PARSER.declareDouble(optionalConstructorArg(), ETA);
         PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
         PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
+        PARSER.declareDouble(optionalConstructorArg(), GAMMA);
+        PARSER.declareDouble(optionalConstructorArg(), LAMBDA);
         PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
         PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
         PARSER.declareInt(optionalConstructorArg(), MAX_TREES);
         PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS);
         PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
+        PARSER.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_LIMIT);
+        PARSER.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
     }
 
     private final String classAssignmentObjective;
+    private final Double alpha;
     private final Double downsampleFactor;
     private final Double eta;
     private final Double etaGrowthRatePerTree;
     private final Double featureBagFraction;
+    private final Double gamma;
+    private final Double lambda;
     private final Integer maxAttemptsToAddTree;
     private final Integer maxOptimizationRoundsPerHyperparameter;
     private final Integer maxTrees;
     private final Integer numFolds;
     private final Integer numSplitsPerFeature;
-    private final Double regularizationDepthPenaltyMultiplier;
-    private final Double regularizationLeafWeightPenaltyMultiplier;
-    private final Double regularizationSoftTreeDepthLimit;
-    private final Double regularizationSoftTreeDepthTolerance;
-    private final Double regularizationTreeSizePenaltyMultiplier;
+    private final Double softTreeDepthLimit;
+    private final Double softTreeDepthTolerance;
 
     public Hyperparameters(String classAssignmentObjective,
+                           Double alpha,
                            Double downsampleFactor,
                            Double eta,
                            Double etaGrowthRatePerTree,
                            Double featureBagFraction,
+                           Double gamma,
+                           Double lambda,
                            Integer maxAttemptsToAddTree,
                            Integer maxOptimizationRoundsPerHyperparameter,
                            Integer maxTrees,
                            Integer numFolds,
                            Integer numSplitsPerFeature,
-                           Double regularizationDepthPenaltyMultiplier,
-                           Double regularizationLeafWeightPenaltyMultiplier,
-                           Double regularizationSoftTreeDepthLimit,
-                           Double regularizationSoftTreeDepthTolerance,
-                           Double regularizationTreeSizePenaltyMultiplier) {
+                           Double softTreeDepthLimit,
+                           Double softTreeDepthTolerance) {
         this.classAssignmentObjective = classAssignmentObjective;
+        this.alpha = alpha;
         this.downsampleFactor = downsampleFactor;
         this.eta = eta;
         this.etaGrowthRatePerTree = etaGrowthRatePerTree;
         this.featureBagFraction = featureBagFraction;
+        this.gamma = gamma;
+        this.lambda = lambda;
         this.maxAttemptsToAddTree = maxAttemptsToAddTree;
         this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
         this.maxTrees = maxTrees;
         this.numFolds = numFolds;
         this.numSplitsPerFeature = numSplitsPerFeature;
-        this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
-        this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
-        this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
-        this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
-        this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
+        this.softTreeDepthLimit = softTreeDepthLimit;
+        this.softTreeDepthTolerance = softTreeDepthTolerance;
     }
 
     public String getClassAssignmentObjective() {
         return classAssignmentObjective;
     }
 
+    public Double getAlpha() {
+        return alpha;
+    }
+
     public Double getDownsampleFactor() {
         return downsampleFactor;
     }
@@ -155,6 +157,14 @@ public class Hyperparameters implements ToXContentObject {
         return featureBagFraction;
     }
 
+    public Double getGamma() {
+        return gamma;
+    }
+
+    public Double getLambda() {
+        return lambda;
+    }
+
     public Integer getMaxAttemptsToAddTree() {
         return maxAttemptsToAddTree;
     }
@@ -175,24 +185,12 @@ public class Hyperparameters implements ToXContentObject {
         return numSplitsPerFeature;
     }
 
-    public Double getRegularizationDepthPenaltyMultiplier() {
-        return regularizationDepthPenaltyMultiplier;
-    }
-
-    public Double getRegularizationLeafWeightPenaltyMultiplier() {
-        return regularizationLeafWeightPenaltyMultiplier;
-    }
-
-    public Double getRegularizationSoftTreeDepthLimit() {
-        return regularizationSoftTreeDepthLimit;
-    }
-
-    public Double getRegularizationSoftTreeDepthTolerance() {
-        return regularizationSoftTreeDepthTolerance;
+    public Double getSoftTreeDepthLimit() {
+        return softTreeDepthLimit;
     }
 
-    public Double getRegularizationTreeSizePenaltyMultiplier() {
-        return regularizationTreeSizePenaltyMultiplier;
+    public Double getSoftTreeDepthTolerance() {
+        return softTreeDepthTolerance;
     }
 
     @Override
@@ -201,6 +199,9 @@ public class Hyperparameters implements ToXContentObject {
         if (classAssignmentObjective != null) {
             builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
         }
+        if (alpha != null) {
+            builder.field(ALPHA.getPreferredName(), alpha);
+        }
         if (downsampleFactor != null) {
             builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
         }
@@ -213,6 +214,12 @@ public class Hyperparameters implements ToXContentObject {
         if (featureBagFraction != null) {
             builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
         }
+        if (gamma != null) {
+            builder.field(GAMMA.getPreferredName(), gamma);
+        }
+        if (lambda != null) {
+            builder.field(LAMBDA.getPreferredName(), lambda);
+        }
         if (maxAttemptsToAddTree != null) {
             builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
         }
@@ -228,20 +235,11 @@ public class Hyperparameters implements ToXContentObject {
         if (numSplitsPerFeature != null) {
             builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
         }
-        if (regularizationDepthPenaltyMultiplier != null) {
-            builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
-        }
-        if (regularizationLeafWeightPenaltyMultiplier != null) {
-            builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
-        }
-        if (regularizationSoftTreeDepthLimit != null) {
-            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
-        }
-        if (regularizationSoftTreeDepthTolerance != null) {
-            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
+        if (softTreeDepthLimit != null) {
+            builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
         }
-        if (regularizationTreeSizePenaltyMultiplier != null) {
-            builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
+        if (softTreeDepthTolerance != null) {
+            builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
         }
         builder.endObject();
         return builder;
@@ -254,40 +252,40 @@ public class Hyperparameters implements ToXContentObject {
 
         Hyperparameters that = (Hyperparameters) o;
         return Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
+            && Objects.equals(alpha, that.alpha)
             && Objects.equals(downsampleFactor, that.downsampleFactor)
             && Objects.equals(eta, that.eta)
             && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
             && Objects.equals(featureBagFraction, that.featureBagFraction)
+            && Objects.equals(gamma, that.gamma)
+            && Objects.equals(lambda, that.lambda)
             && Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree)
             && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
             && Objects.equals(maxTrees, that.maxTrees)
             && Objects.equals(numFolds, that.numFolds)
             && Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature)
-            && Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier)
-            && Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier)
-            && Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit)
-            && Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance)
-            && Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier);
+            && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
+            && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance);
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(
             classAssignmentObjective,
+            alpha,
             downsampleFactor,
             eta,
             etaGrowthRatePerTree,
             featureBagFraction,
+            gamma,
+            lambda,
             maxAttemptsToAddTree,
             maxOptimizationRoundsPerHyperparameter,
             maxTrees,
             numFolds,
             numSplitsPerFeature,
-            regularizationDepthPenaltyMultiplier,
-            regularizationLeafWeightPenaltyMultiplier,
-            regularizationSoftTreeDepthLimit,
-            regularizationSoftTreeDepthTolerance,
-            regularizationTreeSizePenaltyMultiplier
+            softTreeDepthLimit,
+            softTreeDepthTolerance
         );
     }
 }

+ 72 - 74
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/regression/Hyperparameters.java

@@ -30,23 +30,21 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
 
 public class Hyperparameters implements ToXContentObject {
 
+    public static final ParseField ALPHA = new ParseField("alpha");
     public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
     public static final ParseField ETA = new ParseField("eta");
     public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
     public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
+    public static final ParseField GAMMA = new ParseField("gamma");
+    public static final ParseField LAMBDA = new ParseField("lambda");
     public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
     public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
         "max_optimization_rounds_per_hyperparameter");
     public static final ParseField MAX_TREES = new ParseField("max_trees");
     public static final ParseField NUM_FOLDS = new ParseField("num_folds");
     public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
-    public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
-    public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
-        = new ParseField("regularization_leaf_weight_penalty_multiplier");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
-    public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
-        new ParseField("regularization_tree_size_penalty_multiplier");
+    public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
+    public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
 
     public static ConstructingObjectParser<Hyperparameters, Void> PARSER = new ConstructingObjectParser<>("regression_hyperparameters",
         true,
@@ -55,78 +53,82 @@ public class Hyperparameters implements ToXContentObject {
             (Double) a[1],
             (Double) a[2],
             (Double) a[3],
-            (Integer) a[4],
-            (Integer) a[5],
-            (Integer) a[6],
+            (Double) a[4],
+            (Double) a[5],
+            (Double) a[6],
             (Integer) a[7],
             (Integer) a[8],
-            (Double) a[9],
-            (Double) a[10],
-            (Double) a[11],
+            (Integer) a[9],
+            (Integer) a[10],
+            (Integer) a[11],
             (Double) a[12],
             (Double) a[13]
         ));
 
     static {
+        PARSER.declareDouble(optionalConstructorArg(), ALPHA);
         PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
         PARSER.declareDouble(optionalConstructorArg(), ETA);
         PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
         PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
+        PARSER.declareDouble(optionalConstructorArg(), GAMMA);
+        PARSER.declareDouble(optionalConstructorArg(), LAMBDA);
         PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
         PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
         PARSER.declareInt(optionalConstructorArg(), MAX_TREES);
         PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS);
         PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
-        PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
+        PARSER.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_LIMIT);
+        PARSER.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
     }
 
+    private final Double alpha;
     private final Double downsampleFactor;
     private final Double eta;
     private final Double etaGrowthRatePerTree;
     private final Double featureBagFraction;
+    private final Double gamma;
+    private final Double lambda;
     private final Integer maxAttemptsToAddTree;
     private final Integer maxOptimizationRoundsPerHyperparameter;
     private final Integer maxTrees;
     private final Integer numFolds;
     private final Integer numSplitsPerFeature;
-    private final Double regularizationDepthPenaltyMultiplier;
-    private final Double regularizationLeafWeightPenaltyMultiplier;
-    private final Double regularizationSoftTreeDepthLimit;
-    private final Double regularizationSoftTreeDepthTolerance;
-    private final Double regularizationTreeSizePenaltyMultiplier;
+    private final Double softTreeDepthLimit;
+    private final Double softTreeDepthTolerance;
 
-    public Hyperparameters(Double downsampleFactor,
+    public Hyperparameters(Double alpha,
+                           Double downsampleFactor,
                            Double eta,
                            Double etaGrowthRatePerTree,
                            Double featureBagFraction,
+                           Double gamma,
+                           Double lambda,
                            Integer maxAttemptsToAddTree,
                            Integer maxOptimizationRoundsPerHyperparameter,
                            Integer maxTrees,
                            Integer numFolds,
                            Integer numSplitsPerFeature,
-                           Double regularizationDepthPenaltyMultiplier,
-                           Double regularizationLeafWeightPenaltyMultiplier,
-                           Double regularizationSoftTreeDepthLimit,
-                           Double regularizationSoftTreeDepthTolerance,
-                           Double regularizationTreeSizePenaltyMultiplier) {
+                           Double softTreeDepthLimit,
+                           Double softTreeDepthTolerance) {
+        this.alpha = alpha;
         this.downsampleFactor = downsampleFactor;
         this.eta = eta;
         this.etaGrowthRatePerTree = etaGrowthRatePerTree;
         this.featureBagFraction = featureBagFraction;
+        this.gamma = gamma;
+        this.lambda = lambda;
         this.maxAttemptsToAddTree = maxAttemptsToAddTree;
         this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
         this.maxTrees = maxTrees;
         this.numFolds = numFolds;
         this.numSplitsPerFeature = numSplitsPerFeature;
-        this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
-        this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
-        this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
-        this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
-        this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
+        this.softTreeDepthLimit = softTreeDepthLimit;
+        this.softTreeDepthTolerance = softTreeDepthTolerance;
+    }
+
+    public Double getAlpha() {
+        return alpha;
     }
 
     public Double getDownsampleFactor() {
@@ -145,6 +147,14 @@ public class Hyperparameters implements ToXContentObject {
         return featureBagFraction;
     }
 
+    public Double getGamma() {
+        return gamma;
+    }
+
+    public Double getLambda() {
+        return lambda;
+    }
+
     public Integer getMaxAttemptsToAddTree() {
         return maxAttemptsToAddTree;
     }
@@ -165,29 +175,20 @@ public class Hyperparameters implements ToXContentObject {
         return numSplitsPerFeature;
     }
 
-    public Double getRegularizationDepthPenaltyMultiplier() {
-        return regularizationDepthPenaltyMultiplier;
-    }
-
-    public Double getRegularizationLeafWeightPenaltyMultiplier() {
-        return regularizationLeafWeightPenaltyMultiplier;
-    }
-
-    public Double getRegularizationSoftTreeDepthLimit() {
-        return regularizationSoftTreeDepthLimit;
-    }
-
-    public Double getRegularizationSoftTreeDepthTolerance() {
-        return regularizationSoftTreeDepthTolerance;
+    public Double getSoftTreeDepthLimit() {
+        return softTreeDepthLimit;
     }
 
-    public Double getRegularizationTreeSizePenaltyMultiplier() {
-        return regularizationTreeSizePenaltyMultiplier;
+    public Double getSoftTreeDepthTolerance() {
+        return softTreeDepthTolerance;
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
+        if (alpha != null) {
+            builder.field(ALPHA.getPreferredName(), alpha);
+        }
         if (downsampleFactor != null) {
             builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
         }
@@ -200,6 +201,12 @@ public class Hyperparameters implements ToXContentObject {
         if (featureBagFraction != null) {
             builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
         }
+        if (gamma != null) {
+            builder.field(GAMMA.getPreferredName(), gamma);
+        }
+        if (lambda != null) {
+            builder.field(LAMBDA.getPreferredName(), lambda);
+        }
         if (maxAttemptsToAddTree != null) {
             builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
         }
@@ -215,20 +222,11 @@ public class Hyperparameters implements ToXContentObject {
         if (numSplitsPerFeature != null) {
             builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
         }
-        if (regularizationDepthPenaltyMultiplier != null) {
-            builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
-        }
-        if (regularizationLeafWeightPenaltyMultiplier != null) {
-            builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
-        }
-        if (regularizationSoftTreeDepthLimit != null) {
-            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
-        }
-        if (regularizationSoftTreeDepthTolerance != null) {
-            builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
+        if (softTreeDepthLimit != null) {
+            builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
         }
-        if (regularizationTreeSizePenaltyMultiplier != null) {
-            builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
+        if (softTreeDepthTolerance != null) {
+            builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
         }
         builder.endObject();
         return builder;
@@ -240,39 +238,39 @@ public class Hyperparameters implements ToXContentObject {
         if (o == null || getClass() != o.getClass()) return false;
 
         Hyperparameters that = (Hyperparameters) o;
-        return Objects.equals(downsampleFactor, that.downsampleFactor)
+        return Objects.equals(alpha, that.alpha)
+            && Objects.equals(downsampleFactor, that.downsampleFactor)
             && Objects.equals(eta, that.eta)
             && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
             && Objects.equals(featureBagFraction, that.featureBagFraction)
+            && Objects.equals(gamma, that.gamma)
+            && Objects.equals(lambda, that.lambda)
             && Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree)
             && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
             && Objects.equals(maxTrees, that.maxTrees)
             && Objects.equals(numFolds, that.numFolds)
             && Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature)
-            && Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier)
-            && Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier)
-            && Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit)
-            && Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance)
-            && Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier);
+            && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
+            && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance);
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(
+            alpha,
             downsampleFactor,
             eta,
             etaGrowthRatePerTree,
             featureBagFraction,
+            gamma,
+            lambda,
             maxAttemptsToAddTree,
             maxOptimizationRoundsPerHyperparameter,
             maxTrees,
             numFolds,
             numSplitsPerFeature,
-            regularizationDepthPenaltyMultiplier,
-            regularizationLeafWeightPenaltyMultiplier,
-            regularizationSoftTreeDepthLimit,
-            regularizationSoftTreeDepthTolerance,
-            regularizationTreeSizePenaltyMultiplier
+            softTreeDepthLimit,
+            softTreeDepthTolerance
         );
     }
 }

+ 3 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/classification/HyperparametersTests.java

@@ -47,15 +47,15 @@ public class HyperparametersTests extends AbstractXContentTestCase<Hyperparamete
             randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomDouble(),
+            randomBoolean() ? null : randomDouble(),
+            randomBoolean() ? null : randomDouble(),
+            randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
             randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
             randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
             randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
             randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
             randomBoolean() ? null : randomDouble(),
-            randomBoolean() ? null : randomDouble(),
-            randomBoolean() ? null : randomDouble(),
-            randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomDouble()
         );
     }

+ 3 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/regression/HyperparametersTests.java

@@ -43,6 +43,9 @@ public class HyperparametersTests extends AbstractXContentTestCase<Hyperparamete
 
     public static Hyperparameters createRandom() {
         return new Hyperparameters(
+            randomBoolean() ? null : randomDouble(),
+            randomBoolean() ? null : randomDouble(),
+            randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomDouble(),
@@ -53,9 +56,6 @@ public class HyperparametersTests extends AbstractXContentTestCase<Hyperparamete
             randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
             randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE),
             randomBoolean() ? null : randomDouble(),
-            randomBoolean() ? null : randomDouble(),
-            randomBoolean() ? null : randomDouble(),
-            randomBoolean() ? null : randomDouble(),
             randomBoolean() ? null : randomDouble()
         );
     }

+ 56 - 58
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/Hyperparameters.java

@@ -22,23 +22,21 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
 public class Hyperparameters implements ToXContentObject, Writeable {
 
     public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
+    public static final ParseField ALPHA = new ParseField("alpha");
     public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
     public static final ParseField ETA = new ParseField("eta");
     public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
     public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
+    public static final ParseField GAMMA = new ParseField("gamma");
+    public static final ParseField LAMBDA = new ParseField("lambda");
     public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
     public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
         "max_optimization_rounds_per_hyperparameter");
     public static final ParseField MAX_TREES = new ParseField("max_trees");
     public static final ParseField NUM_FOLDS = new ParseField("num_folds");
     public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
-    public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
-    public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
-        = new ParseField("regularization_leaf_weight_penalty_multiplier");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
-    public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
-        new ParseField("regularization_tree_size_penalty_multiplier");
+    public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
+    public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
 
     public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
         return createParser(ignoreUnknownFields).apply(parser, null);
@@ -53,140 +51,140 @@ public class Hyperparameters implements ToXContentObject, Writeable {
                 (double) a[2],
                 (double) a[3],
                 (double) a[4],
-                (int) a[5],
-                (int) a[6],
-                (int) a[7],
+                (double) a[5],
+                (double) a[6],
+                (double) a[7],
                 (int) a[8],
                 (int) a[9],
-                (double) a[10],
-                (double) a[11],
-                (double) a[12],
+                (int) a[10],
+                (int) a[11],
+                (int) a[12],
                 (double) a[13],
                 (double) a[14]
             ));
 
         parser.declareString(constructorArg(), CLASS_ASSIGNMENT_OBJECTIVE);
+        parser.declareDouble(constructorArg(), ALPHA);
         parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR);
         parser.declareDouble(constructorArg(), ETA);
         parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE);
         parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION);
+        parser.declareDouble(constructorArg(), GAMMA);
+        parser.declareDouble(constructorArg(), LAMBDA);
         parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
         parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
         parser.declareInt(constructorArg(), MAX_TREES);
         parser.declareInt(constructorArg(), NUM_FOLDS);
         parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE);
-        parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
-        parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
-        parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
-        parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
-        parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
+        parser.declareDouble(constructorArg(), SOFT_TREE_DEPTH_LIMIT);
+        parser.declareDouble(constructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
 
         return parser;
     }
 
     private final String classAssignmentObjective;
+    private final double alpha;
     private final double downsampleFactor;
     private final double eta;
     private final double etaGrowthRatePerTree;
     private final double featureBagFraction;
+    private final double gamma;
+    private final double lambda;
     private final int maxAttemptsToAddTree;
     private final int maxOptimizationRoundsPerHyperparameter;
     private final int maxTrees;
     private final int numFolds;
     private final int numSplitsPerFeature;
-    private final double regularizationDepthPenaltyMultiplier;
-    private final double regularizationLeafWeightPenaltyMultiplier;
-    private final double regularizationSoftTreeDepthLimit;
-    private final double regularizationSoftTreeDepthTolerance;
-    private final double regularizationTreeSizePenaltyMultiplier;
+    private final double softTreeDepthLimit;
+    private final double softTreeDepthTolerance;
 
     public Hyperparameters(String classAssignmentObjective,
+                           double alpha,
                            double downsampleFactor,
                            double eta,
                            double etaGrowthRatePerTree,
                            double featureBagFraction,
+                           double gamma,
+                           double lambda,
                            int maxAttemptsToAddTree,
                            int maxOptimizationRoundsPerHyperparameter,
                            int maxTrees,
                            int numFolds,
                            int numSplitsPerFeature,
-                           double regularizationDepthPenaltyMultiplier,
-                           double regularizationLeafWeightPenaltyMultiplier,
-                           double regularizationSoftTreeDepthLimit,
-                           double regularizationSoftTreeDepthTolerance,
-                           double regularizationTreeSizePenaltyMultiplier) {
+                           double softTreeDepthLimit,
+                           double softTreeDepthTolerance) {
         this.classAssignmentObjective = Objects.requireNonNull(classAssignmentObjective);
+        this.alpha = alpha;
         this.downsampleFactor = downsampleFactor;
         this.eta = eta;
         this.etaGrowthRatePerTree = etaGrowthRatePerTree;
         this.featureBagFraction = featureBagFraction;
+        this.gamma = gamma;
+        this.lambda = lambda;
         this.maxAttemptsToAddTree = maxAttemptsToAddTree;
         this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
         this.maxTrees = maxTrees;
         this.numFolds = numFolds;
         this.numSplitsPerFeature = numSplitsPerFeature;
-        this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
-        this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
-        this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
-        this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
-        this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
+        this.softTreeDepthLimit = softTreeDepthLimit;
+        this.softTreeDepthTolerance = softTreeDepthTolerance;
     }
 
     public Hyperparameters(StreamInput in) throws IOException {
         this.classAssignmentObjective = in.readString();
+        this.alpha = in.readDouble();
         this.downsampleFactor = in.readDouble();
         this.eta = in.readDouble();
         this.etaGrowthRatePerTree = in.readDouble();
         this.featureBagFraction = in.readDouble();
+        this.gamma = in.readDouble();
+        this.lambda = in.readDouble();
         this.maxAttemptsToAddTree = in.readVInt();
         this.maxOptimizationRoundsPerHyperparameter = in.readVInt();
         this.maxTrees = in.readVInt();
         this.numFolds = in.readVInt();
         this.numSplitsPerFeature = in.readVInt();
-        this.regularizationDepthPenaltyMultiplier = in.readDouble();
-        this.regularizationLeafWeightPenaltyMultiplier = in.readDouble();
-        this.regularizationSoftTreeDepthLimit = in.readDouble();
-        this.regularizationSoftTreeDepthTolerance = in.readDouble();
-        this.regularizationTreeSizePenaltyMultiplier = in.readDouble();
+        this.softTreeDepthLimit = in.readDouble();
+        this.softTreeDepthTolerance = in.readDouble();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(classAssignmentObjective);
+        out.writeDouble(alpha);
         out.writeDouble(downsampleFactor);
         out.writeDouble(eta);
         out.writeDouble(etaGrowthRatePerTree);
         out.writeDouble(featureBagFraction);
+        out.writeDouble(gamma);
+        out.writeDouble(lambda);
         out.writeVInt(maxAttemptsToAddTree);
         out.writeVInt(maxOptimizationRoundsPerHyperparameter);
         out.writeVInt(maxTrees);
         out.writeVInt(numFolds);
         out.writeVInt(numSplitsPerFeature);
-        out.writeDouble(regularizationDepthPenaltyMultiplier);
-        out.writeDouble(regularizationLeafWeightPenaltyMultiplier);
-        out.writeDouble(regularizationSoftTreeDepthLimit);
-        out.writeDouble(regularizationSoftTreeDepthTolerance);
-        out.writeDouble(regularizationTreeSizePenaltyMultiplier);
+        out.writeDouble(softTreeDepthLimit);
+        out.writeDouble(softTreeDepthTolerance);
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
+        builder.field(ALPHA.getPreferredName(), alpha);
         builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
         builder.field(ETA.getPreferredName(), eta);
         builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
         builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
+        builder.field(GAMMA.getPreferredName(), gamma);
+        builder.field(LAMBDA.getPreferredName(), lambda);
         builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
         builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
         builder.field(MAX_TREES.getPreferredName(), maxTrees);
         builder.field(NUM_FOLDS.getPreferredName(), numFolds);
         builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
-        builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
-        builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
-        builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
-        builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
-        builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
+        builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
+        builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
         builder.endObject();
         return builder;
     }
@@ -198,40 +196,40 @@ public class Hyperparameters implements ToXContentObject, Writeable {
 
         Hyperparameters that = (Hyperparameters) o;
         return Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
+            && alpha == that.alpha
             && downsampleFactor == that.downsampleFactor
             && eta == that.eta
             && etaGrowthRatePerTree == that.etaGrowthRatePerTree
             && featureBagFraction == that.featureBagFraction
+            && gamma == that.gamma
+            && lambda == that.lambda
             && maxAttemptsToAddTree == that.maxAttemptsToAddTree
             && maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter
             && maxTrees == that.maxTrees
             && numFolds == that.numFolds
             && numSplitsPerFeature == that.numSplitsPerFeature
-            && regularizationDepthPenaltyMultiplier == that.regularizationDepthPenaltyMultiplier
-            && regularizationLeafWeightPenaltyMultiplier == that.regularizationLeafWeightPenaltyMultiplier
-            && regularizationSoftTreeDepthLimit == that.regularizationSoftTreeDepthLimit
-            && regularizationSoftTreeDepthTolerance == that.regularizationSoftTreeDepthTolerance
-            && regularizationTreeSizePenaltyMultiplier == that.regularizationTreeSizePenaltyMultiplier;
+            && softTreeDepthLimit == that.softTreeDepthLimit
+            && softTreeDepthTolerance == that.softTreeDepthTolerance;
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(
             classAssignmentObjective,
+            alpha,
             downsampleFactor,
             eta,
             etaGrowthRatePerTree,
             featureBagFraction,
+            gamma,
+            lambda,
             maxAttemptsToAddTree,
             maxOptimizationRoundsPerHyperparameter,
             maxTrees,
             numFolds,
             numSplitsPerFeature,
-            regularizationDepthPenaltyMultiplier,
-            regularizationLeafWeightPenaltyMultiplier,
-            regularizationSoftTreeDepthLimit,
-            regularizationSoftTreeDepthTolerance,
-            regularizationTreeSizePenaltyMultiplier
+            softTreeDepthLimit,
+            softTreeDepthTolerance
         );
     }
 }

+ 58 - 60
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/Hyperparameters.java

@@ -21,23 +21,21 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
 
 public class Hyperparameters implements ToXContentObject, Writeable {
 
+    public static final ParseField ALPHA = new ParseField("alpha");
     public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
     public static final ParseField ETA = new ParseField("eta");
     public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
     public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
+    public static final ParseField GAMMA = new ParseField("gamma");
+    public static final ParseField LAMBDA = new ParseField("lambda");
     public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
     public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
         "max_optimization_rounds_per_hyperparameter");
     public static final ParseField MAX_TREES = new ParseField("max_trees");
     public static final ParseField NUM_FOLDS = new ParseField("num_folds");
     public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
-    public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
-    public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
-        = new ParseField("regularization_leaf_weight_penalty_multiplier");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
-    public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
-    public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
-        new ParseField("regularization_tree_size_penalty_multiplier");
+    public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
+    public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
 
     public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
         return createParser(ignoreUnknownFields).apply(parser, null);
@@ -51,133 +49,133 @@ public class Hyperparameters implements ToXContentObject, Writeable {
                 (double) a[1],
                 (double) a[2],
                 (double) a[3],
-                (int) a[4],
-                (int) a[5],
-                (int) a[6],
+                (double) a[4],
+                (double) a[5],
+                (double) a[6],
                 (int) a[7],
                 (int) a[8],
-                (double) a[9],
-                (double) a[10],
-                (double) a[11],
+                (int) a[9],
+                (int) a[10],
+                (int) a[11],
                 (double) a[12],
                 (double) a[13]
             ));
 
+        parser.declareDouble(constructorArg(), ALPHA);
         parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR);
         parser.declareDouble(constructorArg(), ETA);
         parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE);
         parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION);
+        parser.declareDouble(constructorArg(), GAMMA);
+        parser.declareDouble(constructorArg(), LAMBDA);
         parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
         parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
         parser.declareInt(constructorArg(), MAX_TREES);
         parser.declareInt(constructorArg(), NUM_FOLDS);
         parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE);
-        parser.declareDouble(constructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
-        parser.declareDouble(constructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
-        parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
-        parser.declareDouble(constructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
-        parser.declareDouble(constructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
+        parser.declareDouble(constructorArg(), SOFT_TREE_DEPTH_LIMIT);
+        parser.declareDouble(constructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
 
         return parser;
     }
 
+    private final double alpha;
     private final double downsampleFactor;
     private final double eta;
     private final double etaGrowthRatePerTree;
     private final double featureBagFraction;
+    private final double gamma;
+    private final double lambda;
     private final int maxAttemptsToAddTree;
     private final int maxOptimizationRoundsPerHyperparameter;
     private final int maxTrees;
     private final int numFolds;
     private final int numSplitsPerFeature;
-    private final double regularizationDepthPenaltyMultiplier;
-    private final double regularizationLeafWeightPenaltyMultiplier;
-    private final double regularizationSoftTreeDepthLimit;
-    private final double regularizationSoftTreeDepthTolerance;
-    private final double regularizationTreeSizePenaltyMultiplier;
+    private final double softTreeDepthLimit;
+    private final double softTreeDepthTolerance;
 
-    public Hyperparameters(double downsampleFactor,
+    public Hyperparameters(double alpha,
+                           double downsampleFactor,
                            double eta,
                            double etaGrowthRatePerTree,
                            double featureBagFraction,
+                           double gamma,
+                           double lambda,
                            int maxAttemptsToAddTree,
                            int maxOptimizationRoundsPerHyperparameter,
                            int maxTrees,
                            int numFolds,
                            int numSplitsPerFeature,
-                           double regularizationDepthPenaltyMultiplier,
-                           double regularizationLeafWeightPenaltyMultiplier,
-                           double regularizationSoftTreeDepthLimit,
-                           double regularizationSoftTreeDepthTolerance,
-                           double regularizationTreeSizePenaltyMultiplier) {
+                           double softTreeDepthLimit,
+                           double softTreeDepthTolerance) {
+        this.alpha = alpha;
         this.downsampleFactor = downsampleFactor;
         this.eta = eta;
         this.etaGrowthRatePerTree = etaGrowthRatePerTree;
         this.featureBagFraction = featureBagFraction;
+        this.gamma = gamma;
+        this.lambda = lambda;
         this.maxAttemptsToAddTree = maxAttemptsToAddTree;
         this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
         this.maxTrees = maxTrees;
         this.numFolds = numFolds;
         this.numSplitsPerFeature = numSplitsPerFeature;
-        this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
-        this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
-        this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
-        this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
-        this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
+        this.softTreeDepthLimit = softTreeDepthLimit;
+        this.softTreeDepthTolerance = softTreeDepthTolerance;
     }
 
     public Hyperparameters(StreamInput in) throws IOException {
+        this.alpha = in.readDouble();
         this.downsampleFactor = in.readDouble();
         this.eta = in.readDouble();
         this.etaGrowthRatePerTree = in.readDouble();
         this.featureBagFraction = in.readDouble();
+        this.gamma = in.readDouble();
+        this.lambda = in.readDouble();
         this.maxAttemptsToAddTree = in.readVInt();
         this.maxOptimizationRoundsPerHyperparameter = in.readVInt();
         this.maxTrees = in.readVInt();
         this.numFolds = in.readVInt();
         this.numSplitsPerFeature = in.readVInt();
-        this.regularizationDepthPenaltyMultiplier = in.readDouble();
-        this.regularizationLeafWeightPenaltyMultiplier = in.readDouble();
-        this.regularizationSoftTreeDepthLimit = in.readDouble();
-        this.regularizationSoftTreeDepthTolerance = in.readDouble();
-        this.regularizationTreeSizePenaltyMultiplier = in.readDouble();
+        this.softTreeDepthLimit = in.readDouble();
+        this.softTreeDepthTolerance = in.readDouble();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        out.writeDouble(alpha);
         out.writeDouble(downsampleFactor);
         out.writeDouble(eta);
         out.writeDouble(etaGrowthRatePerTree);
         out.writeDouble(featureBagFraction);
+        out.writeDouble(gamma);
+        out.writeDouble(lambda);
         out.writeVInt(maxAttemptsToAddTree);
         out.writeVInt(maxOptimizationRoundsPerHyperparameter);
         out.writeVInt(maxTrees);
         out.writeVInt(numFolds);
         out.writeVInt(numSplitsPerFeature);
-        out.writeDouble(regularizationDepthPenaltyMultiplier);
-        out.writeDouble(regularizationLeafWeightPenaltyMultiplier);
-        out.writeDouble(regularizationSoftTreeDepthLimit);
-        out.writeDouble(regularizationSoftTreeDepthTolerance);
-        out.writeDouble(regularizationTreeSizePenaltyMultiplier);
+        out.writeDouble(softTreeDepthLimit);
+        out.writeDouble(softTreeDepthTolerance);
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
+        builder.field(ALPHA.getPreferredName(), alpha);
         builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
         builder.field(ETA.getPreferredName(), eta);
         builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
         builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
+        builder.field(GAMMA.getPreferredName(), gamma);
+        builder.field(LAMBDA.getPreferredName(), lambda);
         builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
         builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
         builder.field(MAX_TREES.getPreferredName(), maxTrees);
         builder.field(NUM_FOLDS.getPreferredName(), numFolds);
         builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
-        builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
-        builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
-        builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
-        builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
-        builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
+        builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
+        builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
         builder.endObject();
         return builder;
     }
@@ -188,39 +186,39 @@ public class Hyperparameters implements ToXContentObject, Writeable {
         if (o == null || getClass() != o.getClass()) return false;
 
         Hyperparameters that = (Hyperparameters) o;
-        return downsampleFactor == that.downsampleFactor
+        return alpha == that.alpha
+            && downsampleFactor == that.downsampleFactor
             && eta == that.eta
             && etaGrowthRatePerTree == that.etaGrowthRatePerTree
             && featureBagFraction == that.featureBagFraction
+            && gamma == that.gamma
+            && lambda == that.lambda
             && maxAttemptsToAddTree == that.maxAttemptsToAddTree
             && maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter
             && maxTrees == that.maxTrees
             && numFolds == that.numFolds
             && numSplitsPerFeature == that.numSplitsPerFeature
-            && regularizationDepthPenaltyMultiplier == that.regularizationDepthPenaltyMultiplier
-            && regularizationLeafWeightPenaltyMultiplier == that.regularizationLeafWeightPenaltyMultiplier
-            && regularizationSoftTreeDepthLimit == that.regularizationSoftTreeDepthLimit
-            && regularizationSoftTreeDepthTolerance == that.regularizationSoftTreeDepthTolerance
-            && regularizationTreeSizePenaltyMultiplier == that.regularizationTreeSizePenaltyMultiplier;
+            && softTreeDepthLimit == that.softTreeDepthLimit
+            && softTreeDepthTolerance == that.softTreeDepthTolerance;
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(
+            alpha,
             downsampleFactor,
             eta,
             etaGrowthRatePerTree,
             featureBagFraction,
+            gamma,
+            lambda,
             maxAttemptsToAddTree,
             maxOptimizationRoundsPerHyperparameter,
             maxTrees,
             numFolds,
             numSplitsPerFeature,
-            regularizationDepthPenaltyMultiplier,
-            regularizationLeafWeightPenaltyMultiplier,
-            regularizationSoftTreeDepthLimit,
-            regularizationSoftTreeDepthTolerance,
-            regularizationTreeSizePenaltyMultiplier
+            softTreeDepthLimit,
+            softTreeDepthTolerance
         );
     }
 }

+ 3 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/classification/HyperparametersTests.java

@@ -54,15 +54,15 @@ public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyper
             randomDouble(),
             randomDouble(),
             randomDouble(),
+            randomDouble(),
+            randomDouble(),
+            randomDouble(),
             randomIntBetween(0, Integer.MAX_VALUE),
             randomIntBetween(0, Integer.MAX_VALUE),
             randomIntBetween(0, Integer.MAX_VALUE),
             randomIntBetween(0, Integer.MAX_VALUE),
             randomIntBetween(0, Integer.MAX_VALUE),
             randomDouble(),
-            randomDouble(),
-            randomDouble(),
-            randomDouble(),
             randomDouble()
         );
     }

+ 3 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/regression/HyperparametersTests.java

@@ -49,6 +49,9 @@ public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyper
 
     public static Hyperparameters createRandom() {
         return new Hyperparameters(
+            randomDouble(),
+            randomDouble(),
+            randomDouble(),
             randomDouble(),
             randomDouble(),
             randomDouble(),
@@ -59,9 +62,6 @@ public class HyperparametersTests extends AbstractBWCSerializationTestCase<Hyper
             randomIntBetween(0, Integer.MAX_VALUE),
             randomIntBetween(0, Integer.MAX_VALUE),
             randomDouble(),
-            randomDouble(),
-            randomDouble(),
-            randomDouble(),
             randomDouble()
         );
     }

+ 2 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.integration;
 
 import com.google.common.collect.Ordering;
+import org.apache.lucene.util.LuceneTestCase;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
@@ -53,6 +54,7 @@ import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.startsWith;
 
+@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1096")
 public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     private static final String BOOLEAN_FIELD = "boolean-field";

+ 2 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.integration;
 
+import org.apache.lucene.util.LuceneTestCase;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.delete.DeleteResponse;
@@ -34,6 +35,7 @@ import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.lessThan;
 
+@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1096")
 public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     private static final String NUMERICAL_FEATURE_FIELD = "feature";