|
@@ -46,6 +46,7 @@ public class Classification implements DataFrameAnalysis {
|
|
|
static final ParseField ETA = new ParseField("eta");
|
|
|
static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
|
|
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
|
|
+ static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
|
|
|
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
|
|
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
|
|
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
|
|
@@ -62,10 +63,11 @@ public class Classification implements DataFrameAnalysis {
|
|
|
(Double) a[3],
|
|
|
(Integer) a[4],
|
|
|
(Double) a[5],
|
|
|
- (String) a[6],
|
|
|
- (Double) a[7],
|
|
|
- (Integer) a[8],
|
|
|
- (Long) a[9]));
|
|
|
+ (Integer) a[6],
|
|
|
+ (String) a[7],
|
|
|
+ (Double) a[8],
|
|
|
+ (Integer) a[9],
|
|
|
+ (Long) a[10]));
|
|
|
|
|
|
static {
|
|
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
|
@@ -74,6 +76,7 @@ public class Classification implements DataFrameAnalysis {
|
|
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
|
|
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
|
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
|
|
+ PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
|
|
|
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
|
|
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
|
|
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
|
|
@@ -86,13 +89,15 @@ public class Classification implements DataFrameAnalysis {
|
|
|
private final Double eta;
|
|
|
private final Integer maximumNumberTrees;
|
|
|
private final Double featureBagFraction;
|
|
|
+ private final Integer numTopFeatureImportanceValues;
|
|
|
private final String predictionFieldName;
|
|
|
private final Double trainingPercent;
|
|
|
private final Integer numTopClasses;
|
|
|
private final Long randomizeSeed;
|
|
|
|
|
|
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
|
|
- @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
|
|
+ @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction,
|
|
|
+ @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
|
|
|
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
|
|
|
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
|
|
this.lambda = lambda;
|
|
@@ -100,6 +105,7 @@ public class Classification implements DataFrameAnalysis {
|
|
|
this.eta = eta;
|
|
|
this.maximumNumberTrees = maximumNumberTrees;
|
|
|
this.featureBagFraction = featureBagFraction;
|
|
|
+ this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
|
|
this.predictionFieldName = predictionFieldName;
|
|
|
this.trainingPercent = trainingPercent;
|
|
|
this.numTopClasses = numTopClasses;
|
|
@@ -135,6 +141,10 @@ public class Classification implements DataFrameAnalysis {
|
|
|
return featureBagFraction;
|
|
|
}
|
|
|
|
|
|
+ public Integer getNumTopFeatureImportanceValues() {
|
|
|
+ return numTopFeatureImportanceValues;
|
|
|
+ }
|
|
|
+
|
|
|
public String getPredictionFieldName() {
|
|
|
return predictionFieldName;
|
|
|
}
|
|
@@ -170,6 +180,9 @@ public class Classification implements DataFrameAnalysis {
|
|
|
if (featureBagFraction != null) {
|
|
|
builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
|
|
|
}
|
|
|
+ if (numTopFeatureImportanceValues != null) {
|
|
|
+ builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
|
|
|
+ }
|
|
|
if (predictionFieldName != null) {
|
|
|
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
|
|
}
|
|
@@ -188,8 +201,8 @@ public class Classification implements DataFrameAnalysis {
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
|
|
- trainingPercent, randomizeSeed, numTopClasses);
|
|
|
+ return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues,
|
|
|
+ predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -203,6 +216,7 @@ public class Classification implements DataFrameAnalysis {
|
|
|
&& Objects.equals(eta, that.eta)
|
|
|
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
|
|
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
|
|
+ && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues)
|
|
|
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
|
|
&& Objects.equals(trainingPercent, that.trainingPercent)
|
|
|
&& Objects.equals(randomizeSeed, that.randomizeSeed)
|
|
@@ -221,6 +235,7 @@ public class Classification implements DataFrameAnalysis {
|
|
|
private Double eta;
|
|
|
private Integer maximumNumberTrees;
|
|
|
private Double featureBagFraction;
|
|
|
+ private Integer numTopFeatureImportanceValues;
|
|
|
private String predictionFieldName;
|
|
|
private Double trainingPercent;
|
|
|
private Integer numTopClasses;
|
|
@@ -255,6 +270,11 @@ public class Classification implements DataFrameAnalysis {
|
|
|
return this;
|
|
|
}
|
|
|
|
|
|
+ public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
|
|
|
+ this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
public Builder setPredictionFieldName(String predictionFieldName) {
|
|
|
this.predictionFieldName = predictionFieldName;
|
|
|
return this;
|
|
@@ -276,8 +296,8 @@ public class Classification implements DataFrameAnalysis {
|
|
|
}
|
|
|
|
|
|
public Classification build() {
|
|
|
- return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
|
|
- trainingPercent, numTopClasses, randomizeSeed);
|
|
|
+ return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
|
|
|
+ numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
|
|
|
}
|
|
|
}
|
|
|
}
|