Browse Source

[ML] Relax assertion to expect at least on trained model (#67710)

... for data frame analytics native tests. This fixes assertion
failures seen in #67581.

This commit also enables debug logging for `RegressionIT` and
unmutes `test_stop_and_restart` in order to collect more information
about why some times progress gets stuck after `loading_data`.
Dimitris Athanasiou 4 years ago
parent
commit
2468639b07

+ 10 - 10
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -166,7 +166,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
@@ -222,7 +222,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
@@ -272,7 +272,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
@@ -338,7 +338,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
@@ -428,7 +428,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
@@ -527,7 +527,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertAtLeastOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
@@ -588,7 +588,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
@@ -606,7 +606,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
@@ -624,7 +624,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
@@ -751,7 +751,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
 
         // Call _delete_expired_data API and check nothing was deleted
         assertThat(deleteExpiredData().isDeleted(), is(true));

+ 1 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java

@@ -152,7 +152,7 @@ public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalytics
 
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertModelStatePersisted(stateDocId());
     }
 

+ 13 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

@@ -255,11 +255,22 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
             .get();
     }
 
-    protected void assertInferenceModelPersisted(String jobId) {
+    protected void assertExactlyOneInferenceModelPersisted(String jobId) {
+        assertInferenceModelPersisted(jobId, equalTo(1));
+    }
+
+    protected void assertAtLeastOneInferenceModelPersisted(String jobId) {
+        assertInferenceModelPersisted(jobId, greaterThanOrEqualTo(1));
+    }
+
+    private void assertInferenceModelPersisted(String jobId, Matcher<? super Integer> modelHitsArraySizeMatcher) {
         SearchResponse searchResponse = client().prepareSearch(InferenceIndexConstants.LATEST_INDEX_NAME)
             .setQuery(QueryBuilders.boolQuery().filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), jobId)))
             .get();
-        assertThat("Hits were: " + Strings.toString(searchResponse.getHits()), searchResponse.getHits().getHits(), arrayWithSize(1));
+        // If the job is stopped during writing_results phase and it is then restarted, there is a chance two trained models
+        // were persisted as there is no way currently for the process to be certain the model was persisted.
+        assertThat("Hits were: " + Strings.toString(searchResponse.getHits()), searchResponse.getHits().getHits(),
+            is(arrayWithSize(modelHitsArraySizeMatcher)));
     }
 
     protected Collection<PersistentTasksCustomMetadata.PersistentTask<?>> analyticsTaskList() {

+ 23 - 9
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.junit.After;
+import org.junit.Before;
 
 import java.io.IOException;
 import java.time.Instant;
@@ -68,9 +69,23 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
     private String sourceIndex;
     private String destIndex;
 
+    @Before
+    public void setupLogging() {
+        client().admin().cluster()
+            .prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .put("logger.org.elasticsearch.xpack.ml.dataframe", "DEBUG"))
+            .get();
+    }
+
     @After
     public void cleanup() {
         cleanUp();
+        client().admin().cluster()
+            .prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .putNull("logger.org.elasticsearch.xpack.ml.dataframe"))
+            .get();
     }
 
     @Override
@@ -159,7 +174,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
@@ -208,7 +223,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
 
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
@@ -273,7 +288,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
@@ -289,7 +304,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             "Finished analysis");
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/67581")
     public void testStopAndRestart() throws Exception {
         initialize("regression_stop_and_restart");
         String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
@@ -335,7 +349,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertAtLeastOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
     }
 
@@ -391,7 +405,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
 
         // Call _delete_expired_data API and check nothing was deleted
@@ -480,7 +494,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
@@ -587,7 +601,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictionField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
@@ -644,7 +658,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertProgressComplete(jobId);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
-        assertInferenceModelPersisted(jobId);
+        assertExactlyOneInferenceModelPersisted(jobId);
         assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",