|
@@ -91,7 +91,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
|
|
Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ?
|
|
|
null : randomFrom(Classification.ClassAssignmentObjective.values());
|
|
|
- Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000);
|
|
|
+ Integer numTopClasses = randomBoolean() ? null : randomIntBetween(-1, 1000);
|
|
|
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, false);
|
|
|
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
|
|
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
|
@@ -218,18 +218,18 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a positive double in (0, 100]"));
|
|
|
}
|
|
|
|
|
|
- public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
|
|
+ public void testConstructor_GivenNumTopClassesIsLessThanMinusOne() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -2, 1.0, randomLong(), null));
|
|
|
|
|
|
- assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
|
|
+ assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
() -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
|
|
|
|
|
- assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
|
|
+ assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000] or a special value -1"));
|
|
|
}
|
|
|
|
|
|
public void testGetPredictionFieldName() {
|
|
@@ -258,6 +258,10 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(7));
|
|
|
|
|
|
+ // Special value: num_top_classes == -1
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null);
|
|
|
+ assertThat(classification.getNumTopClasses(), equalTo(-1));
|
|
|
+
|
|
|
// Boundary condition: num_top_classes == 0
|
|
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(0));
|