|
|
@@ -31,23 +31,21 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
|
|
|
public class Hyperparameters implements ToXContentObject {
|
|
|
|
|
|
public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
|
|
|
+ public static final ParseField ALPHA = new ParseField("alpha");
|
|
|
public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
|
|
|
public static final ParseField ETA = new ParseField("eta");
|
|
|
public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
|
|
|
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
|
|
+ public static final ParseField GAMMA = new ParseField("gamma");
|
|
|
+ public static final ParseField LAMBDA = new ParseField("lambda");
|
|
|
public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
|
|
|
public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
|
|
|
"max_optimization_rounds_per_hyperparameter");
|
|
|
public static final ParseField MAX_TREES = new ParseField("max_trees");
|
|
|
public static final ParseField NUM_FOLDS = new ParseField("num_folds");
|
|
|
public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
|
|
|
- public static final ParseField REGULARIZATION_DEPTH_PENALTY_MULTIPLIER = new ParseField("regularization_depth_penalty_multiplier");
|
|
|
- public static final ParseField REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER
|
|
|
- = new ParseField("regularization_leaf_weight_penalty_multiplier");
|
|
|
- public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_LIMIT = new ParseField("regularization_soft_tree_depth_limit");
|
|
|
- public static final ParseField REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE = new ParseField("regularization_soft_tree_depth_tolerance");
|
|
|
- public static final ParseField REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER =
|
|
|
- new ParseField("regularization_tree_size_penalty_multiplier");
|
|
|
+ public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
|
|
|
+ public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");
|
|
|
|
|
|
public static ConstructingObjectParser<Hyperparameters, Void> PARSER = new ConstructingObjectParser<>("classification_hyperparameters",
|
|
|
true,
|
|
|
@@ -57,88 +55,92 @@ public class Hyperparameters implements ToXContentObject {
|
|
|
(Double) a[2],
|
|
|
(Double) a[3],
|
|
|
(Double) a[4],
|
|
|
- (Integer) a[5],
|
|
|
- (Integer) a[6],
|
|
|
- (Integer) a[7],
|
|
|
+ (Double) a[5],
|
|
|
+ (Double) a[6],
|
|
|
+ (Double) a[7],
|
|
|
(Integer) a[8],
|
|
|
(Integer) a[9],
|
|
|
- (Double) a[10],
|
|
|
- (Double) a[11],
|
|
|
- (Double) a[12],
|
|
|
+ (Integer) a[10],
|
|
|
+ (Integer) a[11],
|
|
|
+ (Integer) a[12],
|
|
|
(Double) a[13],
|
|
|
(Double) a[14]
|
|
|
));
|
|
|
|
|
|
static {
|
|
|
PARSER.declareString(optionalConstructorArg(), CLASS_ASSIGNMENT_OBJECTIVE);
|
|
|
+ PARSER.declareDouble(optionalConstructorArg(), ALPHA);
|
|
|
PARSER.declareDouble(optionalConstructorArg(), DOWNSAMPLE_FACTOR);
|
|
|
PARSER.declareDouble(optionalConstructorArg(), ETA);
|
|
|
PARSER.declareDouble(optionalConstructorArg(), ETA_GROWTH_RATE_PER_TREE);
|
|
|
PARSER.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
|
|
+ PARSER.declareDouble(optionalConstructorArg(), GAMMA);
|
|
|
+ PARSER.declareDouble(optionalConstructorArg(), LAMBDA);
|
|
|
PARSER.declareInt(optionalConstructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
|
|
|
PARSER.declareInt(optionalConstructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
|
|
|
PARSER.declareInt(optionalConstructorArg(), MAX_TREES);
|
|
|
PARSER.declareInt(optionalConstructorArg(), NUM_FOLDS);
|
|
|
PARSER.declareInt(optionalConstructorArg(), NUM_SPLITS_PER_FEATURE);
|
|
|
- PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_DEPTH_PENALTY_MULTIPLIER);
|
|
|
- PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER);
|
|
|
- PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_LIMIT);
|
|
|
- PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE);
|
|
|
- PARSER.declareDouble(optionalConstructorArg(), REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER);
|
|
|
+ PARSER.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_LIMIT);
|
|
|
+ PARSER.declareDouble(optionalConstructorArg(), SOFT_TREE_DEPTH_TOLERANCE);
|
|
|
}
|
|
|
|
|
|
private final String classAssignmentObjective;
|
|
|
+ private final Double alpha;
|
|
|
private final Double downsampleFactor;
|
|
|
private final Double eta;
|
|
|
private final Double etaGrowthRatePerTree;
|
|
|
private final Double featureBagFraction;
|
|
|
+ private final Double gamma;
|
|
|
+ private final Double lambda;
|
|
|
private final Integer maxAttemptsToAddTree;
|
|
|
private final Integer maxOptimizationRoundsPerHyperparameter;
|
|
|
private final Integer maxTrees;
|
|
|
private final Integer numFolds;
|
|
|
private final Integer numSplitsPerFeature;
|
|
|
- private final Double regularizationDepthPenaltyMultiplier;
|
|
|
- private final Double regularizationLeafWeightPenaltyMultiplier;
|
|
|
- private final Double regularizationSoftTreeDepthLimit;
|
|
|
- private final Double regularizationSoftTreeDepthTolerance;
|
|
|
- private final Double regularizationTreeSizePenaltyMultiplier;
|
|
|
+ private final Double softTreeDepthLimit;
|
|
|
+ private final Double softTreeDepthTolerance;
|
|
|
|
|
|
public Hyperparameters(String classAssignmentObjective,
|
|
|
+ Double alpha,
|
|
|
Double downsampleFactor,
|
|
|
Double eta,
|
|
|
Double etaGrowthRatePerTree,
|
|
|
Double featureBagFraction,
|
|
|
+ Double gamma,
|
|
|
+ Double lambda,
|
|
|
Integer maxAttemptsToAddTree,
|
|
|
Integer maxOptimizationRoundsPerHyperparameter,
|
|
|
Integer maxTrees,
|
|
|
Integer numFolds,
|
|
|
Integer numSplitsPerFeature,
|
|
|
- Double regularizationDepthPenaltyMultiplier,
|
|
|
- Double regularizationLeafWeightPenaltyMultiplier,
|
|
|
- Double regularizationSoftTreeDepthLimit,
|
|
|
- Double regularizationSoftTreeDepthTolerance,
|
|
|
- Double regularizationTreeSizePenaltyMultiplier) {
|
|
|
+ Double softTreeDepthLimit,
|
|
|
+ Double softTreeDepthTolerance) {
|
|
|
this.classAssignmentObjective = classAssignmentObjective;
|
|
|
+ this.alpha = alpha;
|
|
|
this.downsampleFactor = downsampleFactor;
|
|
|
this.eta = eta;
|
|
|
this.etaGrowthRatePerTree = etaGrowthRatePerTree;
|
|
|
this.featureBagFraction = featureBagFraction;
|
|
|
+ this.gamma = gamma;
|
|
|
+ this.lambda = lambda;
|
|
|
this.maxAttemptsToAddTree = maxAttemptsToAddTree;
|
|
|
this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
|
|
|
this.maxTrees = maxTrees;
|
|
|
this.numFolds = numFolds;
|
|
|
this.numSplitsPerFeature = numSplitsPerFeature;
|
|
|
- this.regularizationDepthPenaltyMultiplier = regularizationDepthPenaltyMultiplier;
|
|
|
- this.regularizationLeafWeightPenaltyMultiplier = regularizationLeafWeightPenaltyMultiplier;
|
|
|
- this.regularizationSoftTreeDepthLimit = regularizationSoftTreeDepthLimit;
|
|
|
- this.regularizationSoftTreeDepthTolerance = regularizationSoftTreeDepthTolerance;
|
|
|
- this.regularizationTreeSizePenaltyMultiplier = regularizationTreeSizePenaltyMultiplier;
|
|
|
+ this.softTreeDepthLimit = softTreeDepthLimit;
|
|
|
+ this.softTreeDepthTolerance = softTreeDepthTolerance;
|
|
|
}
|
|
|
|
|
|
public String getClassAssignmentObjective() {
|
|
|
return classAssignmentObjective;
|
|
|
}
|
|
|
|
|
|
+ public Double getAlpha() {
|
|
|
+ return alpha;
|
|
|
+ }
|
|
|
+
|
|
|
public Double getDownsampleFactor() {
|
|
|
return downsampleFactor;
|
|
|
}
|
|
|
@@ -155,6 +157,14 @@ public class Hyperparameters implements ToXContentObject {
|
|
|
return featureBagFraction;
|
|
|
}
|
|
|
|
|
|
+ public Double getGamma() {
|
|
|
+ return gamma;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Double getLambda() {
|
|
|
+ return lambda;
|
|
|
+ }
|
|
|
+
|
|
|
public Integer getMaxAttemptsToAddTree() {
|
|
|
return maxAttemptsToAddTree;
|
|
|
}
|
|
|
@@ -175,24 +185,12 @@ public class Hyperparameters implements ToXContentObject {
|
|
|
return numSplitsPerFeature;
|
|
|
}
|
|
|
|
|
|
- public Double getRegularizationDepthPenaltyMultiplier() {
|
|
|
- return regularizationDepthPenaltyMultiplier;
|
|
|
- }
|
|
|
-
|
|
|
- public Double getRegularizationLeafWeightPenaltyMultiplier() {
|
|
|
- return regularizationLeafWeightPenaltyMultiplier;
|
|
|
- }
|
|
|
-
|
|
|
- public Double getRegularizationSoftTreeDepthLimit() {
|
|
|
- return regularizationSoftTreeDepthLimit;
|
|
|
- }
|
|
|
-
|
|
|
- public Double getRegularizationSoftTreeDepthTolerance() {
|
|
|
- return regularizationSoftTreeDepthTolerance;
|
|
|
+ public Double getSoftTreeDepthLimit() {
|
|
|
+ return softTreeDepthLimit;
|
|
|
}
|
|
|
|
|
|
- public Double getRegularizationTreeSizePenaltyMultiplier() {
|
|
|
- return regularizationTreeSizePenaltyMultiplier;
|
|
|
+ public Double getSoftTreeDepthTolerance() {
|
|
|
+ return softTreeDepthTolerance;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
@@ -201,6 +199,9 @@ public class Hyperparameters implements ToXContentObject {
|
|
|
if (classAssignmentObjective != null) {
|
|
|
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
|
|
|
}
|
|
|
+ if (alpha != null) {
|
|
|
+ builder.field(ALPHA.getPreferredName(), alpha);
|
|
|
+ }
|
|
|
if (downsampleFactor != null) {
|
|
|
builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
|
|
|
}
|
|
|
@@ -213,6 +214,12 @@ public class Hyperparameters implements ToXContentObject {
|
|
|
if (featureBagFraction != null) {
|
|
|
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
|
|
}
|
|
|
+ if (gamma != null) {
|
|
|
+ builder.field(GAMMA.getPreferredName(), gamma);
|
|
|
+ }
|
|
|
+ if (lambda != null) {
|
|
|
+ builder.field(LAMBDA.getPreferredName(), lambda);
|
|
|
+ }
|
|
|
if (maxAttemptsToAddTree != null) {
|
|
|
builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
|
|
|
}
|
|
|
@@ -228,20 +235,11 @@ public class Hyperparameters implements ToXContentObject {
|
|
|
if (numSplitsPerFeature != null) {
|
|
|
builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
|
|
|
}
|
|
|
- if (regularizationDepthPenaltyMultiplier != null) {
|
|
|
- builder.field(REGULARIZATION_DEPTH_PENALTY_MULTIPLIER.getPreferredName(), regularizationDepthPenaltyMultiplier);
|
|
|
- }
|
|
|
- if (regularizationLeafWeightPenaltyMultiplier != null) {
|
|
|
- builder.field(REGULARIZATION_LEAF_WEIGHT_PENALTY_MULTIPLIER.getPreferredName(), regularizationLeafWeightPenaltyMultiplier);
|
|
|
- }
|
|
|
- if (regularizationSoftTreeDepthLimit != null) {
|
|
|
- builder.field(REGULARIZATION_SOFT_TREE_DEPTH_LIMIT.getPreferredName(), regularizationSoftTreeDepthLimit);
|
|
|
- }
|
|
|
- if (regularizationSoftTreeDepthTolerance != null) {
|
|
|
- builder.field(REGULARIZATION_SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), regularizationSoftTreeDepthTolerance);
|
|
|
+ if (softTreeDepthLimit != null) {
|
|
|
+ builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
|
|
|
}
|
|
|
- if (regularizationTreeSizePenaltyMultiplier != null) {
|
|
|
- builder.field(REGULARIZATION_TREE_SIZE_PENALTY_MULTIPLIER.getPreferredName(), regularizationTreeSizePenaltyMultiplier);
|
|
|
+ if (softTreeDepthTolerance != null) {
|
|
|
+ builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
|
|
|
}
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
|
@@ -254,40 +252,40 @@ public class Hyperparameters implements ToXContentObject {
|
|
|
|
|
|
Hyperparameters that = (Hyperparameters) o;
|
|
|
return Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
|
|
|
+ && Objects.equals(alpha, that.alpha)
|
|
|
&& Objects.equals(downsampleFactor, that.downsampleFactor)
|
|
|
&& Objects.equals(eta, that.eta)
|
|
|
&& Objects.equals(etaGrowthRatePerTree, that.etaGrowthRatePerTree)
|
|
|
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
|
|
+ && Objects.equals(gamma, that.gamma)
|
|
|
+ && Objects.equals(lambda, that.lambda)
|
|
|
&& Objects.equals(maxAttemptsToAddTree, that.maxAttemptsToAddTree)
|
|
|
&& Objects.equals(maxOptimizationRoundsPerHyperparameter, that.maxOptimizationRoundsPerHyperparameter)
|
|
|
&& Objects.equals(maxTrees, that.maxTrees)
|
|
|
&& Objects.equals(numFolds, that.numFolds)
|
|
|
&& Objects.equals(numSplitsPerFeature, that.numSplitsPerFeature)
|
|
|
- && Objects.equals(regularizationDepthPenaltyMultiplier, that.regularizationDepthPenaltyMultiplier)
|
|
|
- && Objects.equals(regularizationLeafWeightPenaltyMultiplier, that.regularizationLeafWeightPenaltyMultiplier)
|
|
|
- && Objects.equals(regularizationSoftTreeDepthLimit, that.regularizationSoftTreeDepthLimit)
|
|
|
- && Objects.equals(regularizationSoftTreeDepthTolerance, that.regularizationSoftTreeDepthTolerance)
|
|
|
- && Objects.equals(regularizationTreeSizePenaltyMultiplier, that.regularizationTreeSizePenaltyMultiplier);
|
|
|
+ && Objects.equals(softTreeDepthLimit, that.softTreeDepthLimit)
|
|
|
+ && Objects.equals(softTreeDepthTolerance, that.softTreeDepthTolerance);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
return Objects.hash(
|
|
|
classAssignmentObjective,
|
|
|
+ alpha,
|
|
|
downsampleFactor,
|
|
|
eta,
|
|
|
etaGrowthRatePerTree,
|
|
|
featureBagFraction,
|
|
|
+ gamma,
|
|
|
+ lambda,
|
|
|
maxAttemptsToAddTree,
|
|
|
maxOptimizationRoundsPerHyperparameter,
|
|
|
maxTrees,
|
|
|
numFolds,
|
|
|
numSplitsPerFeature,
|
|
|
- regularizationDepthPenaltyMultiplier,
|
|
|
- regularizationLeafWeightPenaltyMultiplier,
|
|
|
- regularizationSoftTreeDepthLimit,
|
|
|
- regularizationSoftTreeDepthTolerance,
|
|
|
- regularizationTreeSizePenaltyMultiplier
|
|
|
+ softTreeDepthLimit,
|
|
|
+ softTreeDepthTolerance
|
|
|
);
|
|
|
}
|
|
|
}
|