Browse Source

Implement testStopAndRestart for ClassificationIT (#50585)

Przemysław Witek 5 years ago
parent
commit
6e0d8619e6

+ 61 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -19,12 +19,14 @@ import org.elasticsearch.action.index.IndexAction;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
@@ -44,6 +46,7 @@ import java.util.Set;
 import static java.util.stream.Collectors.toList;
 import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
 import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -243,6 +246,64 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "boolean");
     }
 
+    public void testStopAndRestart() throws Exception {
+        initialize("classification_stop_and_restart");
+        String predictedClassField = KEYWORD_FIELD + "_prediction";
+        indexData(sourceIndex, 350, 0, KEYWORD_FIELD);
+
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        assertIsStopped(jobId);
+        assertProgress(jobId, 0, 0, 0, 0);
+
+        startAnalytics(jobId);
+
+        // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
+        assertBusy(() -> {
+            DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
+            assertThat(
+                state,
+                is(anyOf(
+                    equalTo(DataFrameAnalyticsState.REINDEXING),
+                    equalTo(DataFrameAnalyticsState.ANALYZING),
+                    equalTo(DataFrameAnalyticsState.STOPPED))));
+        });
+        stopAnalytics(jobId);
+        waitUntilAnalyticsIsStopped(jobId);
+
+        // Now let's start it again
+        try {
+            startAnalytics(jobId);
+        } catch (Exception e) {
+            if (e.getMessage().equals("Cannot start because the job has already finished")) {
+                // That means the job had managed to complete
+            } else {
+                throw e;
+            }
+        }
+
+        waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));
+
+        SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
+        for (SearchHit hit : sourceData.getHits()) {
+            Map<String, Object> destDoc = getDestDoc(config, hit);
+            Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
+            assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
+            assertThat(getFieldValue(resultsObject, "is_training"), is(true));
+            assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
+        }
+
+        assertProgress(jobId, 100, 100, 100, 100);
+        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
+        assertModelStatePersisted(stateDocId());
+        assertInferenceModelPersisted(jobId);
+        assertMlResultsFieldMappings(predictedClassField, "keyword");
+        assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
+
+    }
+
     public void testDependentVariableCardinalityTooHighError() throws Exception {
         initialize("cardinality_too_high");
         indexData(sourceIndex, 6, 5, KEYWORD_FIELD);