|
@@ -97,6 +97,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
|
|
|
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
|
|
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
|
|
+ Boolean earlyStoppingEnabled = randomBoolean() ? null : randomBoolean();
|
|
|
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
|
|
numTopClasses, trainingPercent, randomizeSeed,
|
|
|
randomBoolean() ?
|
|
@@ -105,7 +106,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
OneHotEncodingTests.createRandom(true),
|
|
|
TargetMeanEncodingTests.createRandom(true)))
|
|
|
.limit(randomIntBetween(0, 5))
|
|
|
- .collect(Collectors.toList()));
|
|
|
+ .collect(Collectors.toList()),
|
|
|
+ earlyStoppingEnabled);
|
|
|
}
|
|
|
|
|
|
public static Classification mutateForVersion(Classification instance, Version version) {
|
|
@@ -116,7 +118,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
instance.getNumTopClasses(),
|
|
|
instance.getTrainingPercent(),
|
|
|
instance.getRandomizeSeed(),
|
|
|
- version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList());
|
|
|
+ version.onOrAfter(Version.V_7_10_0) ? instance.getFeatureProcessors() : Collections.emptyList(),
|
|
|
+ version.onOrAfter(Version.V_8_0_0) ? instance.getEarlyStoppingEnabled() : null);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -133,7 +136,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
bwcSerializedObject.getNumTopClasses(),
|
|
|
bwcSerializedObject.getTrainingPercent(),
|
|
|
42L,
|
|
|
- bwcSerializedObject.getFeatureProcessors());
|
|
|
+ bwcSerializedObject.getFeatureProcessors(),
|
|
|
+ bwcSerializedObject.getEarlyStoppingEnabled());
|
|
|
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
|
|
testInstance.getBoostedTreeParams(),
|
|
|
testInstance.getPredictionFieldName(),
|
|
@@ -141,7 +145,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
testInstance.getNumTopClasses(),
|
|
|
testInstance.getTrainingPercent(),
|
|
|
42L,
|
|
|
- testInstance.getFeatureProcessors());
|
|
|
+ testInstance.getFeatureProcessors(),
|
|
|
+ testInstance.getEarlyStoppingEnabled());
|
|
|
super.assertOnBWCObject(newBwc, newInstance, version);
|
|
|
}
|
|
|
|
|
@@ -202,96 +207,96 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
|
|
|
public void testConstructor_GivenTrainingPercentIsZero() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.0, randomLong(), null, null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenTrainingPercentIsLessThanZero() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, -1.0, randomLong(), null, null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null, null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null, null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null, null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
|
|
}
|
|
|
|
|
|
public void testGetPredictionFieldName() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
|
|
|
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
|
|
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null, null);
|
|
|
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
|
|
}
|
|
|
|
|
|
public void testClassAssignmentObjective() {
|
|
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
|
|
- Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
|
|
|
+ Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
|
|
|
|
|
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
|
|
- Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
|
|
|
+ Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
|
|
|
|
|
// class_assignment_objective == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
|
|
}
|
|
|
|
|
|
public void testGetNumTopClasses() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(7));
|
|
|
|
|
|
// Special value: num_top_classes == -1
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(-1));
|
|
|
|
|
|
// Boundary condition: num_top_classes == 0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(0));
|
|
|
|
|
|
// Boundary condition: num_top_classes == 1000
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
|
|
|
|
|
// num_top_classes == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(2));
|
|
|
}
|
|
|
|
|
|
public void testGetTrainingPercent() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null, null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
|
|
|
|
|
// Boundary condition: training_percent == 1.0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null, null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
|
|
|
|
|
// Boundary condition: training_percent == 100.0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null, null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
|
|
|
|
|
// training_percent == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null, null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
|
|
}
|
|
|
|
|
@@ -316,7 +321,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
"prediction_field_name", "foo_prediction",
|
|
|
"prediction_field_type", "bool",
|
|
|
"num_classes", 10L,
|
|
|
- "training_percent", 100.0)));
|
|
|
+ "training_percent", 100.0,
|
|
|
+ "early_stopping_enabled", true)));
|
|
|
assertThat(
|
|
|
new Classification("bar").getParams(fieldInfo),
|
|
|
equalTo(
|
|
@@ -327,7 +333,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
"prediction_field_name", "bar_prediction",
|
|
|
"prediction_field_type", "int",
|
|
|
"num_classes", 20L,
|
|
|
- "training_percent", 100.0)));
|
|
|
+ "training_percent", 100.0,
|
|
|
+ "early_stopping_enabled", true)));
|
|
|
assertThat(
|
|
|
new Classification("baz",
|
|
|
BoostedTreeParams.builder().build() ,
|
|
@@ -336,6 +343,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
null,
|
|
|
50.0,
|
|
|
null,
|
|
|
+ null,
|
|
|
null).getParams(fieldInfo),
|
|
|
equalTo(
|
|
|
Map.of(
|
|
@@ -345,7 +353,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
"prediction_field_name", "baz_prediction",
|
|
|
"prediction_field_type", "string",
|
|
|
"num_classes", 30L,
|
|
|
- "training_percent", 50.0)));
|
|
|
+ "training_percent", 50.0,
|
|
|
+ "early_stopping_enabled", true)));
|
|
|
}
|
|
|
|
|
|
public void testRequiredFieldsIsNonEmpty() {
|