|
@@ -179,8 +179,8 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|
|
|
|
|
double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction;
|
|
|
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
|
|
|
- assertThat(trainingDocsCount, greaterThanOrEqualTo((long) Math.floor(expectedTotalTrainingCount - 1)));
|
|
|
- assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount + 1)));
|
|
|
+ assertThat(trainingDocsCount, greaterThanOrEqualTo((long) (expectedTotalTrainingCount - 2)));
|
|
|
+ assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount) + 2));
|
|
|
|
|
|
for (String classValue : classCardinalities.keySet()) {
|
|
|
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
|
|
@@ -221,7 +221,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|
|
// should be close to the training percent, which is set to 0.5
|
|
|
for (int rowTrainingCount : trainingCountPerRow) {
|
|
|
double meanCount = rowTrainingCount / (double) runCount;
|
|
|
- assertThat(meanCount, is(closeTo(0.5, 0.1)));
|
|
|
+ assertThat(meanCount, is(closeTo(0.5, 0.12)));
|
|
|
}
|
|
|
}
|
|
|
|