|
@@ -36,6 +36,13 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
public static final ParseField MAX_TREES = new ParseField("max_trees", "maximum_number_trees");
|
|
|
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
|
|
public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
|
|
+ public static final ParseField ALPHA = new ParseField("alpha");
|
|
|
+ public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
|
|
|
+ public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
|
|
|
+ public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
|
|
|
+ public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
|
|
|
+ public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER =
|
|
|
+ new ParseField("max_optimization_rounds_per_hyperparameter");
|
|
|
|
|
|
static void declareFields(AbstractObjectParser<?, Void> parser) {
|
|
|
parser.declareDouble(optionalConstructorArg(), LAMBDA);
|
|
@@ -44,6 +51,12 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
parser.declareInt(optionalConstructorArg(), MAX_TREES);
|
|
|
parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
|
|
parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
|
|
+ parser.declareDouble(optionalConstructorArg(), ALPHA);
|
|
|
+ parser.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
|
|
|
+ parser.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_LIMIT);
|
|
|
+ parser.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
|
|
|
+ parser.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
|
|
|
+ parser.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
|
|
}
|
|
|
|
|
|
private final Double lambda;
|
|
@@ -52,13 +65,25 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
private final Integer maxTrees;
|
|
|
private final Double featureBagFraction;
|
|
|
private final Integer numTopFeatureImportanceValues;
|
|
|
+ private final Double alpha;
|
|
|
+ private final Double etaGrowthRatePerTree;
|
|
|
+ private final Double softTreeDepthLimit;
|
|
|
+ private final Double softTreeDepthTolerance;
|
|
|
+ private final Double downsampleFactor;
|
|
|
+ private final Integer maxOptimizationRoundsPerHyperparameter;
|
|
|
|
|
|
public BoostedTreeParams(@Nullable Double lambda,
|
|
|
@Nullable Double gamma,
|
|
|
@Nullable Double eta,
|
|
|
@Nullable Integer maxTrees,
|
|
|
@Nullable Double featureBagFraction,
|
|
|
- @Nullable Integer numTopFeatureImportanceValues) {
|
|
|
+ @Nullable Integer numTopFeatureImportanceValues,
|
|
|
+ @Nullable Double alpha,
|
|
|
+ @Nullable Double etaGrowthRatePerTree,
|
|
|
+ @Nullable Double softTreeDepthLimit,
|
|
|
+ @Nullable Double softTreeDepthTolerance,
|
|
|
+ @Nullable Double downsampleFactor,
|
|
|
+ @Nullable Integer maxOptimizationRoundsPerHyperparameter) {
|
|
|
if (lambda != null && lambda < 0) {
|
|
|
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName());
|
|
|
}
|
|
@@ -78,12 +103,39 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer",
|
|
|
NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
|
|
|
}
|
|
|
+ if (alpha != null && alpha < 0) {
|
|
|
+ throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", ALPHA.getPreferredName());
|
|
|
+ }
|
|
|
+ if (etaGrowthRatePerTree != null && (etaGrowthRatePerTree < 0.5 || etaGrowthRatePerTree > 2.0)) {
|
|
|
+ throw ExceptionsHelper.badRequestException("[{}] must be a double in [0.5, 2.0]", ETA_GROWTH_RATE_PER_TREE.getPreferredName());
|
|
|
+ }
|
|
|
+ if (softTreeDepthLimit != null && softTreeDepthLimit < 0) {
|
|
|
+ throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", SOFT_TREE_DEPTH_LIMIT.getPreferredName());
|
|
|
+ }
|
|
|
+ if (softTreeDepthTolerance != null && softTreeDepthTolerance < 0.01) {
|
|
|
+ throw ExceptionsHelper.badRequestException("[{}] must be a double greater than or equal to 0.01",
|
|
|
+ SOFT_TREE_DEPTH_TOLERANCE.getPreferredName());
|
|
|
+ }
|
|
|
+ if (downsampleFactor != null && (downsampleFactor <= 0 || downsampleFactor > 1.0)) {
|
|
|
+ throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", DOWNSAMPLE_FACTOR.getPreferredName());
|
|
|
+ }
|
|
|
+ if (maxOptimizationRoundsPerHyperparameter != null
|
|
|
+ && (maxOptimizationRoundsPerHyperparameter < 0 || maxOptimizationRoundsPerHyperparameter > 20)) {
|
|
|
+ throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 20]",
|
|
|
+ MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName());
|
|
|
+ }
|
|
|
this.lambda = lambda;
|
|
|
this.gamma = gamma;
|
|
|
this.eta = eta;
|
|
|
this.maxTrees = maxTrees;
|
|
|
this.featureBagFraction = featureBagFraction;
|
|
|
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
|
|
+ this.alpha = alpha;
|
|
|
+ this.etaGrowthRatePerTree = etaGrowthRatePerTree;
|
|
|
+ this.softTreeDepthLimit = softTreeDepthLimit;
|
|
|
+ this.softTreeDepthTolerance = softTreeDepthTolerance;
|
|
|
+ this.downsampleFactor = downsampleFactor;
|
|
|
+ this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
|
|
}
|
|
|
|
|
|
BoostedTreeParams(StreamInput in) throws IOException {
|
|
@@ -97,6 +149,21 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
} else {
|
|
|
numTopFeatureImportanceValues = null;
|
|
|
}
|
|
|
+ if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
|
|
|
+ alpha = in.readOptionalDouble();
|
|
|
+ etaGrowthRatePerTree = in.readOptionalDouble();
|
|
|
+ softTreeDepthLimit = in.readOptionalDouble();
|
|
|
+ softTreeDepthTolerance = in.readOptionalDouble();
|
|
|
+ downsampleFactor = in.readOptionalDouble();
|
|
|
+ maxOptimizationRoundsPerHyperparameter = in.readOptionalVInt();
|
|
|
+ } else {
|
|
|
+ alpha = null;
|
|
|
+ etaGrowthRatePerTree = null;
|
|
|
+ softTreeDepthLimit = null;
|
|
|
+ softTreeDepthTolerance = null;
|
|
|
+ downsampleFactor = null;
|
|
|
+ maxOptimizationRoundsPerHyperparameter = null;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
public Double getLambda() {
|
|
@@ -123,6 +190,30 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
return numTopFeatureImportanceValues;
|
|
|
}
|
|
|
|
|
|
+ public Double getAlpha() {
|
|
|
+ return alpha;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Double getEtaGrowthRatePerTree() {
|
|
|
+ return etaGrowthRatePerTree;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Double getSoftTreeDepthLimit() {
|
|
|
+ return softTreeDepthLimit;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Double getSoftTreeDepthTolerance() {
|
|
|
+ return softTreeDepthTolerance;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Double getDownsampleFactor() {
|
|
|
+ return downsampleFactor;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Integer getMaxOptimizationRoundsPerHyperparameter() {
|
|
|
+ return maxOptimizationRoundsPerHyperparameter;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
out.writeOptionalDouble(lambda);
|
|
@@ -133,10 +224,21 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
|
|
|
out.writeOptionalInt(numTopFeatureImportanceValues);
|
|
|
}
|
|
|
+ if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
|
|
|
+ out.writeOptionalDouble(alpha);
|
|
|
+ out.writeOptionalDouble(etaGrowthRatePerTree);
|
|
|
+ out.writeOptionalDouble(softTreeDepthLimit);
|
|
|
+ out.writeOptionalDouble(softTreeDepthTolerance);
|
|
|
+ out.writeOptionalDouble(downsampleFactor);
|
|
|
+ out.writeOptionalVInt(maxOptimizationRoundsPerHyperparameter);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
+ if (alpha != null) {
|
|
|
+ builder.field(ALPHA.getPreferredName(), alpha);
|
|
|
+ }
|
|
|
if (lambda != null) {
|
|
|
builder.field(LAMBDA.getPreferredName(), lambda);
|
|
|
}
|
|
@@ -146,6 +248,9 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
if (eta != null) {
|
|
|
builder.field(ETA.getPreferredName(), eta);
|
|
|
}
|
|
|
+ if (etaGrowthRatePerTree != null) {
|
|
|
+ builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
|
|
|
+ }
|
|
|
if (maxTrees != null) {
|
|
|
builder.field(MAX_TREES.getPreferredName(), maxTrees);
|
|
|
}
|
|
@@ -155,6 +260,18 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
if (numTopFeatureImportanceValues != null) {
|
|
|
builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
|
|
}
|
|
|
+ if (softTreeDepthLimit != null) {
|
|
|
+ builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
|
|
|
+ }
|
|
|
+ if (softTreeDepthTolerance != null) {
|
|
|
+ builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
|
|
|
+ }
|
|
|
+ if (downsampleFactor != null) {
|
|
|
+ builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
|
|
|
+ }
|
|
|
+ if (maxOptimizationRoundsPerHyperparameter != null) {
|
|
|
+ builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
|
|
+ }
|
|
|
return builder;
|
|
|
}
|
|
|
|
|
@@ -178,6 +295,24 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
if (numTopFeatureImportanceValues != null) {
|
|
|
params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
|
|
}
|
|
|
+ if (alpha != null) {
|
|
|
+ params.put(ALPHA.getPreferredName(), alpha);
|
|
|
+ }
|
|
|
+ if (etaGrowthRatePerTree != null) {
|
|
|
+ params.put(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
|
|
|
+ }
|
|
|
+ if (softTreeDepthLimit != null) {
|
|
|
+ params.put(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
|
|
|
+ }
|
|
|
+ if (softTreeDepthTolerance != null) {
|
|
|
+ params.put(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
|
|
|
+ }
|
|
|
+ if (downsampleFactor != null) {
|
|
|
+ params.put(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
|
|
|
+ }
|
|
|
+ if (maxOptimizationRoundsPerHyperparameter != null) {
|
|
|
+ params.put(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
|
|
|
+ }
|
|
|
return params;
|
|
|
}
|
|
|
|
|
@@ -191,12 +326,19 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
&& Objects.equals(eta, that.eta)
|
|
|
&& Objects.equals(maxTrees, that.maxTrees)
|
|
|
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
|
|
- && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
|
|
|
+ && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
|
|
|
+ && Objects.equals(alpha, that.alpha)
|
|
|
+ && Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
|
|
|
+ && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
|
|
|
+ && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance)
|
|
|
+ && Objects.equals(downsampleFactor, that.downsampleFactor)
|
|
|
+ && Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues);
|
|
|
+ return Objects.hash(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues, alpha, etaGrowthRatePerTree,
|
|
|
+ softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
|
|
|
}
|
|
|
|
|
|
public static Builder builder() {
|
|
@@ -211,6 +353,12 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
private Integer maxTrees;
|
|
|
private Double featureBagFraction;
|
|
|
private Integer numTopFeatureImportanceValues;
|
|
|
+ private Double alpha;
|
|
|
+ private Double etaGrowthRatePerTree;
|
|
|
+ private Double softTreeDepthLimit;
|
|
|
+ private Double softTreeDepthTolerance;
|
|
|
+ private Double downsampleFactor;
|
|
|
+ private Integer maxOptimizationRoundsPerHyperparameter;
|
|
|
|
|
|
private Builder() {}
|
|
|
|
|
@@ -221,6 +369,12 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
this.maxTrees = params.maxTrees;
|
|
|
this.featureBagFraction = params.featureBagFraction;
|
|
|
this.numTopFeatureImportanceValues = params.numTopFeatureImportanceValues;
|
|
|
+ this.alpha = params.alpha;
|
|
|
+ this.etaGrowthRatePerTree = params.etaGrowthRatePerTree;
|
|
|
+ this.softTreeDepthLimit = params.softTreeDepthLimit;
|
|
|
+ this.softTreeDepthTolerance = params.softTreeDepthTolerance;
|
|
|
+ this.downsampleFactor = params.downsampleFactor;
|
|
|
+ this.maxOptimizationRoundsPerHyperparameter = params.maxOptimizationRoundsPerHyperparameter;
|
|
|
}
|
|
|
|
|
|
public Builder setLambda(Double lambda) {
|
|
@@ -253,8 +407,39 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable {
|
|
|
return this;
|
|
|
}
|
|
|
|
|
|
+ public Builder setAlpha(Double alpha) {
|
|
|
+ this.alpha = alpha;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Builder setEtaGrowthRatePerTree(Double etaGrowthRatePerTree) {
|
|
|
+ this.etaGrowthRatePerTree = etaGrowthRatePerTree;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Builder setSoftTreeDepthLimit(Double softTreeDepthLimit) {
|
|
|
+ this.softTreeDepthLimit = softTreeDepthLimit;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Builder setSoftTreeDepthTolerance(Double softTreeDepthTolerance) {
|
|
|
+ this.softTreeDepthTolerance = softTreeDepthTolerance;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Builder setDownsampleFactor(Double downsampleFactor) {
|
|
|
+ this.downsampleFactor = downsampleFactor;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Builder setMaxOptimizationRoundsPerHyperparameter(Integer maxOptimizationRoundsPerHyperparameter) {
|
|
|
+ this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
public BoostedTreeParams build() {
|
|
|
- return new BoostedTreeParams(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues);
|
|
|
+ return new BoostedTreeParams(lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues, alpha,
|
|
|
+ etaGrowthRatePerTree, softTreeDepthLimit, softTreeDepthTolerance, downsampleFactor, maxOptimizationRoundsPerHyperparameter);
|
|
|
}
|
|
|
}
|
|
|
}
|