|
@@ -54,17 +54,20 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
String dependentVariableName = randomAlphaOfLength(10);
|
|
|
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
|
|
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
|
|
+ Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
|
|
|
+ null : randomFrom(Classification.ClassAssignmentObjective.values());
|
|
|
Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
|
|
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
|
|
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
|
|
- return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent,
|
|
|
- randomizeSeed);
|
|
|
+ return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
|
|
+ numTopClasses, trainingPercent, randomizeSeed);
|
|
|
}
|
|
|
|
|
|
public static Classification mutateForVersion(Classification instance, Version version) {
|
|
|
return new Classification(instance.getDependentVariable(),
|
|
|
BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version),
|
|
|
instance.getPredictionFieldName(),
|
|
|
+ version.onOrAfter(Version.V_8_0_0) ? instance.getClassAssignmentObjective() : null,
|
|
|
instance.getNumTopClasses(),
|
|
|
instance.getTrainingPercent(),
|
|
|
instance.getRandomizeSeed());
|
|
@@ -80,12 +83,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(),
|
|
|
bwcSerializedObject.getBoostedTreeParams(),
|
|
|
bwcSerializedObject.getPredictionFieldName(),
|
|
|
+ bwcSerializedObject.getClassAssignmentObjective(),
|
|
|
bwcSerializedObject.getNumTopClasses(),
|
|
|
bwcSerializedObject.getTrainingPercent(),
|
|
|
42L);
|
|
|
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
|
|
testInstance.getBoostedTreeParams(),
|
|
|
testInstance.getPredictionFieldName(),
|
|
|
+ testInstance.getClassAssignmentObjective(),
|
|
|
testInstance.getNumTopClasses(),
|
|
|
testInstance.getTrainingPercent(),
|
|
|
42L);
|
|
@@ -99,71 +104,85 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
|
|
|
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong()));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong()));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong()));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong()));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
|
|
}
|
|
|
|
|
|
public void testGetPredictionFieldName() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
|
|
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
|
|
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong());
|
|
|
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
|
|
}
|
|
|
|
|
|
+ public void testClassAssignmentObjective() {
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
|
|
+ Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong());
|
|
|
+ assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
|
|
|
+
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
|
|
+ Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong());
|
|
|
+ assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
|
|
+
|
|
|
+ // class_assignment_objective == null, default applied
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
|
|
+ assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
|
|
+ }
|
|
|
+
|
|
|
public void testGetNumTopClasses() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong());
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(7));
|
|
|
|
|
|
// Boundary condition: num_top_classes == 0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong());
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(0));
|
|
|
|
|
|
// Boundary condition: num_top_classes == 1000
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong());
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
|
|
|
|
|
// num_top_classes == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong());
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(2));
|
|
|
}
|
|
|
|
|
|
public void testGetTrainingPercent() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong());
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
|
|
|
|
|
// Boundary condition: training_percent == 1.0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong());
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
|
|
|
|
|
// Boundary condition: training_percent == 100.0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong());
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
|
|
|
|
|
// training_percent == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong());
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
|
|
}
|
|
|
|
|
@@ -178,6 +197,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
equalTo(
|
|
|
Map.of(
|
|
|
"dependent_variable", "foo",
|
|
|
+ "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
|
|
|
"num_top_classes", 2,
|
|
|
"prediction_field_name", "foo_prediction",
|
|
|
"prediction_field_type", "bool")));
|
|
@@ -186,6 +206,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
equalTo(
|
|
|
Map.of(
|
|
|
"dependent_variable", "bar",
|
|
|
+ "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
|
|
|
"num_top_classes", 2,
|
|
|
"prediction_field_name", "bar_prediction",
|
|
|
"prediction_field_type", "int")));
|
|
@@ -194,6 +215,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
equalTo(
|
|
|
Map.of(
|
|
|
"dependent_variable", "baz",
|
|
|
+ "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL,
|
|
|
"num_top_classes", 2,
|
|
|
"prediction_field_name", "baz_prediction",
|
|
|
"prediction_field_type", "string")));
|