Browse Source

[ML] fixing feature reset integration tests (#71081)

previously created pipelines referencing ML models were not being appropriately deleted in upstream tests.

This commit ensures that machine learning removes relevant pipelines from cluster state after tests complete

closes #71072
Benjamin Trent 4 years ago
parent
commit
ec9d0624c9

+ 0 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/FeaturesIT.java

@@ -31,7 +31,6 @@ public class FeaturesIT extends ESRestHighLevelClientTestCase {
         assertTrue(response.getFeatures().stream().anyMatch(feature -> "tasks".equals(feature.getFeatureName())));
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/71072")
     public void testResetFeatures() throws IOException {
         ResetFeaturesRequest request = new ResetFeaturesRequest();
 

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningGetResultsIT.java

@@ -193,7 +193,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
 
     @After
     public void deleteJob() throws IOException {
-        new MlTestStateCleaner(logger, highLevelClient().machineLearning()).clearMlMetadata();
+        new MlTestStateCleaner(logger, highLevelClient()).clearMlMetadata();
     }
 
     public void testGetModelSnapshots() throws IOException {

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -225,7 +225,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
     @After
     public void cleanUp() throws IOException {
-        new MlTestStateCleaner(logger, highLevelClient().machineLearning()).clearMlMetadata();
+        new MlTestStateCleaner(logger, highLevelClient()).clearMlMetadata();
     }
 
     public void testPutJob() throws Exception {

+ 59 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java

@@ -8,16 +8,22 @@
 package org.elasticsearch.client;
 
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ingest.DeletePipelineRequest;
+import org.elasticsearch.client.core.PageParams;
 import org.elasticsearch.client.ml.CloseJobRequest;
 import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest;
 import org.elasticsearch.client.ml.DeleteDatafeedRequest;
 import org.elasticsearch.client.ml.DeleteJobRequest;
+import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
 import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest;
 import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse;
 import org.elasticsearch.client.ml.GetDatafeedRequest;
 import org.elasticsearch.client.ml.GetDatafeedResponse;
 import org.elasticsearch.client.ml.GetJobRequest;
 import org.elasticsearch.client.ml.GetJobResponse;
+import org.elasticsearch.client.ml.GetTrainedModelsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
 import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest;
 import org.elasticsearch.client.ml.StopDatafeedRequest;
 import org.elasticsearch.client.ml.datafeed.DatafeedConfig;
@@ -25,26 +31,77 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.client.ml.job.config.Job;
 
 import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * Cleans up and ML resources created during tests
  */
 public class MlTestStateCleaner {
 
+    private static final Set<String> NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1");
     private final Logger logger;
     private final MachineLearningClient mlClient;
+    private final RestHighLevelClient client;
 
-    public MlTestStateCleaner(Logger logger, MachineLearningClient mlClient) {
+    public MlTestStateCleaner(Logger logger, RestHighLevelClient client) {
         this.logger = logger;
-        this.mlClient = mlClient;
+        this.mlClient = client.machineLearning();
+        this.client = client;
     }
 
     public void clearMlMetadata() throws IOException {
+        deleteAllTrainedModels();
         deleteAllDatafeeds();
         deleteAllJobs();
         deleteAllDataFrameAnalytics();
     }
 
+    @SuppressWarnings("unchecked")
+    private void deleteAllTrainedModels() throws IOException {
+        Set<String> pipelinesWithModels = mlClient.getTrainedModelsStats(
+            new GetTrainedModelsStatsRequest("_all").setPageParams(new PageParams(0, 10_000)), RequestOptions.DEFAULT
+        ).getTrainedModelStats()
+            .stream()
+            .flatMap(stats -> {
+                Map<String, Object> ingestStats = stats.getIngestStats();
+                if (ingestStats == null || ingestStats.isEmpty()) {
+                    return Stream.empty();
+                }
+                Map<String, Object> pipelines = (Map<String, Object>)ingestStats.get("pipelines");
+                if (pipelines == null || pipelines.isEmpty()) {
+                    return Stream.empty();
+                }
+                return pipelines.keySet().stream();
+            })
+            .collect(Collectors.toSet());
+        for (String pipelineId : pipelinesWithModels) {
+            try {
+                client.ingest().deletePipeline(new DeletePipelineRequest(pipelineId), RequestOptions.DEFAULT);
+            } catch (Exception ex) {
+                logger.warn(() -> new ParameterizedMessage("failed to delete pipeline [{}]", pipelineId), ex);
+            }
+        }
+
+        mlClient.getTrainedModels(
+            GetTrainedModelsRequest.getAllTrainedModelConfigsRequest().setPageParams(new PageParams(0, 10_000)),
+            RequestOptions.DEFAULT)
+            .getTrainedModels()
+            .stream()
+            .filter(trainedModelConfig -> NOT_DELETED_TRAINED_MODELS.contains(trainedModelConfig.getModelId()) == false)
+            .forEach(config -> {
+                try {
+                    mlClient.deleteTrainedModel(new DeleteTrainedModelRequest(config.getModelId()), RequestOptions.DEFAULT);
+                } catch (IOException ex) {
+                    throw new UncheckedIOException(ex);
+                }
+            });
+    }
+
     private void deleteAllDatafeeds() throws IOException {
         stopAllDatafeeds();
 

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -242,7 +242,7 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
 
     @After
     public void cleanUp() throws IOException {
-        new MlTestStateCleaner(logger, highLevelClient().machineLearning()).clearMlMetadata();
+        new MlTestStateCleaner(logger, highLevelClient()).clearMlMetadata();
     }
 
     public void testCreateJob() throws Exception {

+ 1 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java

@@ -310,6 +310,7 @@ public class MlMetadata implements Metadata.Custom {
                 jobs = new TreeMap<>(previous.jobs);
                 datafeeds = new TreeMap<>(previous.datafeeds);
                 upgradeMode = previous.upgradeMode;
+                resetMode = previous.resetMode;
             }
         }
 

+ 20 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.core.ml.integration;
 
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.client.RestClient;
@@ -18,6 +19,8 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
+
 
 public class MlRestTestStateCleaner {
 
@@ -40,6 +43,23 @@ public class MlRestTestStateCleaner {
 
     @SuppressWarnings("unchecked")
     private void deleteAllTrainedModels() throws IOException {
+        final Request getAllTrainedModelStats = new Request("GET", "/_ml/trained_models/_stats");
+        getAllTrainedModelStats.addParameter("size", "10000");
+        final Response trainedModelsStatsResponse = adminClient.performRequest(getAllTrainedModelStats);
+
+        final List<Map<String, Object>> pipelines = (List<Map<String, Object>>) XContentMapValues.extractValue(
+            "trained_model_stats.ingest.pipelines",
+            ESRestTestCase.entityAsMap(trainedModelsStatsResponse)
+        );
+        Set<String> pipelineIds = pipelines.stream().flatMap(m -> m.keySet().stream()).collect(Collectors.toSet());
+        for (String pipelineId : pipelineIds) {
+            try {
+                adminClient.performRequest(new Request("DELETE", "/_ingest/pipeline/" + pipelineId));
+            } catch (Exception ex) {
+                logger.warn(() -> new ParameterizedMessage("failed to delete pipeline [{}]", pipelineId), ex);
+            }
+        }
+
         final Request getTrainedModels = new Request("GET", "/_ml/trained_models");
         getTrainedModels.addParameter("size", "10000");
         final Response trainedModelsResponse = adminClient.performRequest(getTrainedModels);

+ 17 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java

@@ -6,6 +6,7 @@
  */
 package org.elasticsearch.xpack.ml.integration;
 
+import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.admin.cluster.snapshots.features.ResetFeatureStateAction;
 import org.elasticsearch.action.admin.cluster.snapshots.features.ResetFeatureStateRequest;
 import org.elasticsearch.action.ingest.DeletePipelineAction;
@@ -28,6 +29,8 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
 import org.junit.After;
 
 import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
@@ -44,20 +47,31 @@ import static org.hamcrest.Matchers.is;
 
 public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
 
+    private final Set<String> createdPipelines = new HashSet<>();
+
     @After
     public void cleanup() throws Exception {
         cleanUp();
+        for (String pipeline : createdPipelines) {
+            try {
+                client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest(pipeline)).actionGet();
+            } catch (Exception ex) {
+                logger.warn(() -> new ParameterizedMessage("error cleaning up pipeline [{}]", pipeline), ex);
+            }
+        }
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/71072")
     public void testMLFeatureReset() throws Exception {
         startRealtime("feature_reset_anomaly_job");
         startDataFrameJob("feature_reset_data_frame_analytics_job");
         putTrainedModelIngestPipeline("feature_reset_inference_pipeline");
+        createdPipelines.add("feature_reset_inference_pipeline");
         for(int i = 0; i < 100; i ++) {
             indexDocForInference("feature_reset_inference_pipeline");
         }
         client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_inference_pipeline")).actionGet();
+        createdPipelines.remove("feature_reset_inference_pipeline");
+
         assertBusy(() ->
             assertThat(countNumberInferenceProcessors(client().admin().cluster().prepareState().get().getState()), equalTo(0))
         );
@@ -71,6 +85,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
 
     public void testMLFeatureResetFailureDueToPipelines() throws Exception {
         putTrainedModelIngestPipeline("feature_reset_failure_inference_pipeline");
+        createdPipelines.add("feature_reset_failure_inference_pipeline");
         Exception ex = expectThrows(Exception.class, () -> client().execute(
             ResetFeatureStateAction.INSTANCE,
             new ResetFeatureStateRequest()
@@ -82,6 +97,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
             )
         );
         client().execute(DeletePipelineAction.INSTANCE, new DeletePipelineRequest("feature_reset_failure_inference_pipeline")).actionGet();
+        createdPipelines.remove("feature_reset_failure_inference_pipeline");
         assertThat(isResetMode(), is(false));
     }
 

+ 8 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java

@@ -81,6 +81,14 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         return new NamedXContentRegistry(searchModule.getNamedXContents());
     }
 
+    public void testBuilderClone() {
+        for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
+            MlMetadata first = createTestInstance();
+            MlMetadata cloned = MlMetadata.Builder.from(first).build();
+            assertThat(cloned, equalTo(first));
+        }
+    }
+
     public void testPutJob() {
         Job job1 = buildJobBuilder("1").build();
         Job job2 = buildJobBuilder("2").build();