Browse Source

[ML] Avoid classification integ test training on single class (#50072)

The `ClassificationIT.testTwoJobsWithSameRandomizeSeedUseSameTrainingSet`
test was previously set up to just have 10 rows. With `training_percent`
of 50%, only 5 rows will be used for training. There is a good chance that
all 5 rows will be of one class, which results in a failure.

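This commit increases the rows to 100. Now 50 rows should be used for training
and the chance of failure should be very small. Because a search returns only
10 hits by default, `getTrainingRowsIds` now also requests up to 10000 hits so
that all of the rows are inspected.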
Dimitris Athanasiou 5 years ago
parent
commit
cdcf132678

+ 4 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -274,7 +274,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
     public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
         String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source";
         String dependentVariable = KEYWORD_FIELD;
-        indexData(sourceIndex, 10, 0, dependentVariable);
+
+        // We use 100 rows as we can't set this too low. If too low it is possible
+        // we only train with rows of one of the two classes which leads to a failure.
+        indexData(sourceIndex, 100, 0, dependentVariable);
 
         String firstJobId = "classification_two_jobs_with_same_randomize_seed_1";
         String firstJobDestIndex = firstJobId + "_dest";

+ 1 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

@@ -259,7 +259,7 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
 
     protected static Set<String> getTrainingRowsIds(String index) {
         Set<String> trainingRowsIds = new HashSet<>();
-        SearchResponse hits = client().prepareSearch(index).get();
+        SearchResponse hits = client().prepareSearch(index).setSize(10000).get();
         for (SearchHit hit : hits.getHits()) {
             Map<String, Object> sourceAsMap = hit.getSourceAsMap();
             assertThat(sourceAsMap.containsKey("ml"), is(true));