|
@@ -19,12 +19,14 @@ import org.elasticsearch.action.index.IndexAction;
|
|
|
import org.elasticsearch.action.index.IndexRequest;
|
|
|
import org.elasticsearch.action.search.SearchResponse;
|
|
|
import org.elasticsearch.action.support.WriteRequest;
|
|
|
+import org.elasticsearch.common.unit.TimeValue;
|
|
|
import org.elasticsearch.index.query.QueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.search.SearchHit;
|
|
|
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
|
@@ -44,6 +46,7 @@ import java.util.Set;
|
|
|
import static java.util.stream.Collectors.toList;
|
|
|
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
|
|
import static org.hamcrest.Matchers.allOf;
|
|
|
+import static org.hamcrest.Matchers.anyOf;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.greaterThan;
|
|
|
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
|
@@ -243,6 +246,64 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "boolean");
|
|
|
}
|
|
|
|
|
|
+ public void testStopAndRestart() throws Exception {
|
|
|
+ initialize("classification_stop_and_restart");
|
|
|
+ String predictedClassField = KEYWORD_FIELD + "_prediction";
|
|
|
+ indexData(sourceIndex, 350, 0, KEYWORD_FIELD);
|
|
|
+
|
|
|
+ DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
|
|
+ registerAnalytics(config);
|
|
|
+ putAnalytics(config);
|
|
|
+
|
|
|
+ assertIsStopped(jobId);
|
|
|
+ assertProgress(jobId, 0, 0, 0, 0);
|
|
|
+
|
|
|
+ startAnalytics(jobId);
|
|
|
+
|
|
|
+ // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
|
|
|
+ assertBusy(() -> {
|
|
|
+ DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
|
|
|
+ assertThat(
|
|
|
+ state,
|
|
|
+ is(anyOf(
|
|
|
+ equalTo(DataFrameAnalyticsState.REINDEXING),
|
|
|
+ equalTo(DataFrameAnalyticsState.ANALYZING),
|
|
|
+ equalTo(DataFrameAnalyticsState.STOPPED))));
|
|
|
+ });
|
|
|
+ stopAnalytics(jobId);
|
|
|
+ waitUntilAnalyticsIsStopped(jobId);
|
|
|
+
|
|
|
+ // Now let's start it again
|
|
|
+ try {
|
|
|
+ startAnalytics(jobId);
|
|
|
+ } catch (Exception e) {
|
|
|
+ if (e.getMessage().equals("Cannot start because the job has already finished")) {
|
|
|
+ // That means the job had managed to complete
|
|
|
+ } else {
|
|
|
+ throw e;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));
|
|
|
+
|
|
|
+ SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
|
|
+ for (SearchHit hit : sourceData.getHits()) {
|
|
|
+ Map<String, Object> destDoc = getDestDoc(config, hit);
|
|
|
+ Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
|
|
+ assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
|
|
+ assertThat(getFieldValue(resultsObject, "is_training"), is(true));
|
|
|
+ assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
|
|
+ }
|
|
|
+
|
|
|
+ assertProgress(jobId, 100, 100, 100, 100);
|
|
|
+ assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
|
|
+ assertModelStatePersisted(stateDocId());
|
|
|
+ assertInferenceModelPersisted(jobId);
|
|
|
+ assertMlResultsFieldMappings(predictedClassField, "keyword");
|
|
|
+ assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
public void testDependentVariableCardinalityTooHighError() throws Exception {
|
|
|
initialize("cardinality_too_high");
|
|
|
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
|