Przeglądaj źródła

Extract indexData method out of RegressionIT tests (#49306)

Przemysław Witek 6 lat temu
rodzic
commit
cdd4b6784f

+ 29 - 77
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -48,34 +48,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
         initialize("regression_single_numeric_feature_and_mixed_data_set");
-
-        {  // Index 350 rows, 300 of them being training rows.
-            client().admin().indices().prepareCreate(sourceIndex)
-                .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
-                .get();
-
-            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-            for (int i = 0; i < 300; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-                Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            for (int i = 300; i < 350; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            BulkResponse bulkResponse = bulkRequestBuilder.get();
-            if (bulkResponse.hasFailures()) {
-                fail("Failed to index data: " + bulkResponse.buildFailureMessage());
-            }
-        }
+        indexData(sourceIndex, 300, 50);
 
         DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
         registerAnalytics(config);
@@ -120,23 +93,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
         initialize("regression_only_training_data_and_training_percent_is_100");
-
-        {  // Index 350 rows, all of them being training rows.
-            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-            for (int i = 0; i < 350; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-                Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            BulkResponse bulkResponse = bulkRequestBuilder.get();
-            if (bulkResponse.hasFailures()) {
-                fail("Failed to index data: " + bulkResponse.buildFailureMessage());
-            }
-        }
+        indexData(sourceIndex, 350, 0);
 
         DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
         registerAnalytics(config);
@@ -173,23 +130,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
         initialize("regression_only_training_data_and_training_percent_is_50");
-
-        {  // Index 350 rows, all of them being training rows.
-            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-            for (int i = 0; i < 350; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-                Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            BulkResponse bulkResponse = bulkRequestBuilder.get();
-            if (bulkResponse.hasFailures()) {
-                fail("Failed to index data: " + bulkResponse.buildFailureMessage());
-            }
-        }
+        indexData(sourceIndex, 350, 0);
 
         DataFrameAnalyticsConfig config =
             buildAnalytics(
@@ -242,21 +183,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
     @AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/49095")
     public void testStopAndRestart() throws Exception {
         initialize("regression_stop_and_restart");
-
-        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-        for (int i = 0; i < 350; i++) {
-            Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-            Double value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
-
-            IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                .source("feature", field, "variable", value);
-            bulkRequestBuilder.add(indexRequest);
-        }
-        BulkResponse bulkResponse = bulkRequestBuilder.get();
-        if (bulkResponse.hasFailures()) {
-            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
-        }
+        indexData(sourceIndex, 350, 0);
 
         DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
         registerAnalytics(config);
@@ -310,6 +237,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         this.destIndex = sourceIndex + "_results";
     }
 
+    /**
+     * Creates {@code sourceIndex} and bulk-indexes the regression test data set into it.
+     *
+     * The index is created with an explicit mapping that declares both
+     * {@code NUMERICAL_FEATURE_FIELD} and {@code DEPENDENT_VARIABLE_FIELD} as {@code double}.
+     * All documents are sent in a single bulk request using
+     * {@link WriteRequest.RefreshPolicy#IMMEDIATE}. If the bulk response reports failures,
+     * {@code fail(...)} is called with the failure message built by the response.
+     *
+     * @param sourceIndex        name of the index to create and populate
+     * @param numTrainingRows    number of documents that contain both the feature field and the dependent variable field
+     * @param numNonTrainingRows number of documents that contain only the feature field
+     */
+    private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
+        client().admin().indices().prepareCreate(sourceIndex)
+            .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
+            .get();
+
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        // Training rows: both fields are present; values cycle through the constant value lists via modulo.
+        for (int i = 0; i < numTrainingRows; i++) {
+            List<Object> source = List.of(
+                NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
+                DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
+            IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
+            bulkRequestBuilder.add(indexRequest);
+        }
+        // Non-training rows: the dependent variable is omitted. The counter continues from numTrainingRows,
+        // so the cycle through the feature values carries on instead of restarting at index 0.
+        for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
+            List<Object> source = List.of(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()));
+            IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
+            bulkRequestBuilder.add(indexRequest);
+        }
+        BulkResponse bulkResponse = bulkRequestBuilder.get();
+        if (bulkResponse.hasFailures()) {
+            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
+        }
+    }
+
     private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, SearchHit hit) {
         GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
         assertThat(destDocGetResponse.isExists(), is(true));