|
@@ -8,16 +8,22 @@
|
|
|
package org.elasticsearch.client;
|
|
|
|
|
|
import org.apache.logging.log4j.Logger;
|
|
|
+import org.apache.logging.log4j.message.ParameterizedMessage;
|
|
|
+import org.elasticsearch.action.ingest.DeletePipelineRequest;
|
|
|
+import org.elasticsearch.client.core.PageParams;
|
|
|
import org.elasticsearch.client.ml.CloseJobRequest;
|
|
|
import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest;
|
|
|
import org.elasticsearch.client.ml.DeleteDatafeedRequest;
|
|
|
import org.elasticsearch.client.ml.DeleteJobRequest;
|
|
|
+import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
|
|
|
import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest;
|
|
|
import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse;
|
|
|
import org.elasticsearch.client.ml.GetDatafeedRequest;
|
|
|
import org.elasticsearch.client.ml.GetDatafeedResponse;
|
|
|
import org.elasticsearch.client.ml.GetJobRequest;
|
|
|
import org.elasticsearch.client.ml.GetJobResponse;
|
|
|
+import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
|
|
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
|
|
|
import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest;
|
|
|
import org.elasticsearch.client.ml.StopDatafeedRequest;
|
|
|
import org.elasticsearch.client.ml.datafeed.DatafeedConfig;
|
|
@@ -25,26 +31,77 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
|
|
|
import org.elasticsearch.client.ml.job.config.Job;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
+import java.io.UncheckedIOException;
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.Set;
|
|
|
+import java.util.stream.Collectors;
|
|
|
+import java.util.stream.Stream;
|
|
|
|
|
|
/**
|
|
|
* Cleans up and ML resources created during tests
|
|
|
*/
|
|
|
public class MlTestStateCleaner {
|
|
|
|
|
|
+ private static final Set<String> NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1");
|
|
|
private final Logger logger;
|
|
|
private final MachineLearningClient mlClient;
|
|
|
+ private final RestHighLevelClient client;
|
|
|
|
|
|
- public MlTestStateCleaner(Logger logger, MachineLearningClient mlClient) {
|
|
|
+ public MlTestStateCleaner(Logger logger, RestHighLevelClient client) {
|
|
|
this.logger = logger;
|
|
|
- this.mlClient = mlClient;
|
|
|
+ this.mlClient = client.machineLearning();
|
|
|
+ this.client = client;
|
|
|
}
|
|
|
|
|
|
public void clearMlMetadata() throws IOException {
|
|
|
+ deleteAllTrainedModels();
|
|
|
deleteAllDatafeeds();
|
|
|
deleteAllJobs();
|
|
|
deleteAllDataFrameAnalytics();
|
|
|
}
|
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ private void deleteAllTrainedModels() throws IOException {
|
|
|
+ Set<String> pipelinesWithModels = mlClient.getTrainedModelsStats(
|
|
|
+ new GetTrainedModelsStatsRequest("_all").setPageParams(new PageParams(0, 10_000)), RequestOptions.DEFAULT
|
|
|
+ ).getTrainedModelStats()
|
|
|
+ .stream()
|
|
|
+ .flatMap(stats -> {
|
|
|
+ Map<String, Object> ingestStats = stats.getIngestStats();
|
|
|
+ if (ingestStats == null || ingestStats.isEmpty()) {
|
|
|
+ return Stream.empty();
|
|
|
+ }
|
|
|
+ Map<String, Object> pipelines = (Map<String, Object>)ingestStats.get("pipelines");
|
|
|
+ if (pipelines == null || pipelines.isEmpty()) {
|
|
|
+ return Stream.empty();
|
|
|
+ }
|
|
|
+ return pipelines.keySet().stream();
|
|
|
+ })
|
|
|
+ .collect(Collectors.toSet());
|
|
|
+ for (String pipelineId : pipelinesWithModels) {
|
|
|
+ try {
|
|
|
+ client.ingest().deletePipeline(new DeletePipelineRequest(pipelineId), RequestOptions.DEFAULT);
|
|
|
+ } catch (Exception ex) {
|
|
|
+ logger.warn(() -> new ParameterizedMessage("failed to delete pipeline [{}]", pipelineId), ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ mlClient.getTrainedModels(
|
|
|
+ GetTrainedModelsRequest.getAllTrainedModelConfigsRequest().setPageParams(new PageParams(0, 10_000)),
|
|
|
+ RequestOptions.DEFAULT)
|
|
|
+ .getTrainedModels()
|
|
|
+ .stream()
|
|
|
+ .filter(trainedModelConfig -> NOT_DELETED_TRAINED_MODELS.contains(trainedModelConfig.getModelId()) == false)
|
|
|
+ .forEach(config -> {
|
|
|
+ try {
|
|
|
+ mlClient.deleteTrainedModel(new DeleteTrainedModelRequest(config.getModelId()), RequestOptions.DEFAULT);
|
|
|
+ } catch (IOException ex) {
|
|
|
+ throw new UncheckedIOException(ex);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
private void deleteAllDatafeeds() throws IOException {
|
|
|
stopAllDatafeeds();
|
|
|
|