|
@@ -13,35 +13,49 @@ import org.elasticsearch.action.ingest.DeletePipelineAction;
|
|
|
import org.elasticsearch.action.ingest.DeletePipelineRequest;
|
|
|
import org.elasticsearch.action.ingest.PutPipelineAction;
|
|
|
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
|
|
+import org.elasticsearch.action.support.WriteRequest;
|
|
|
import org.elasticsearch.cluster.ClusterState;
|
|
|
import org.elasticsearch.common.bytes.BytesArray;
|
|
|
import org.elasticsearch.common.xcontent.XContentType;
|
|
|
+import org.elasticsearch.tasks.TaskInfo;
|
|
|
import org.elasticsearch.xpack.core.ml.MlMetadata;
|
|
|
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
|
|
|
import org.elasticsearch.xpack.core.ml.job.config.Job;
|
|
|
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
|
|
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
|
|
|
import org.junit.After;
|
|
|
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashSet;
|
|
|
+import java.util.List;
|
|
|
import java.util.Set;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
+import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
|
|
|
import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD;
|
|
|
import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics;
|
|
|
+import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
|
|
|
+import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.RAW_MODEL_SIZE;
|
|
|
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createDatafeed;
|
|
|
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob;
|
|
|
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts;
|
|
|
import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs;
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
-import static org.hamcrest.Matchers.emptyArray;
|
|
|
+import static org.hamcrest.Matchers.empty;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.greaterThan;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
@@ -51,6 +65,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
|
|
|
private final Set<String> createdPipelines = new HashSet<>();
|
|
|
private final Set<String> jobIds = new HashSet<>();
|
|
|
private final Set<String> datafeedIds = new HashSet<>();
|
|
|
+ private static final String TRAINED_MODEL_ID = "trained-model-to-reset";
|
|
|
|
|
|
void cleanupDatafeed(String datafeedId) {
|
|
|
try {
|
|
@@ -122,7 +137,10 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
|
|
|
ResetFeatureStateAction.INSTANCE,
|
|
|
new ResetFeatureStateRequest()
|
|
|
).actionGet();
|
|
|
- assertBusy(() -> assertThat(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices(), emptyArray()));
|
|
|
+ assertBusy(() -> {
|
|
|
+ List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
|
|
|
+ assertThat(indices.toString(), indices, is(empty()));
|
|
|
+ });
|
|
|
assertThat(isResetMode(), is(false));
|
|
|
// If we have succeeded, clear the jobs and datafeeds so that the delete API doesn't recreate the notifications index
|
|
|
jobIds.clear();
|
|
@@ -147,6 +165,94 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
|
|
|
assertThat(isResetMode(), is(false));
|
|
|
}
|
|
|
|
|
|
+ public void testMLFeatureResetWithModelDeployment() throws Exception {
|
|
|
+ createModelDeployment();
|
|
|
+ client().execute(
|
|
|
+ ResetFeatureStateAction.INSTANCE,
|
|
|
+ new ResetFeatureStateRequest()
|
|
|
+ ).actionGet();
|
|
|
+ assertBusy(() -> {
|
|
|
+ List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
|
|
|
+ assertThat(indices.toString(), indices, is(empty()));
|
|
|
+ });
|
|
|
+ assertThat(isResetMode(), is(false));
|
|
|
+ List<String> tasksNames = client().admin()
|
|
|
+ .cluster()
|
|
|
+ .prepareListTasks()
|
|
|
+ .setActions("xpack/ml/*")
|
|
|
+ .get()
|
|
|
+ .getTasks()
|
|
|
+ .stream()
|
|
|
+ .map(TaskInfo::getAction)
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ assertThat(tasksNames, is(empty()));
|
|
|
+ }
|
|
|
+
|
|
|
+ void createModelDeployment() {
|
|
|
+ String indexname = "model_store";
|
|
|
+ client().admin().indices().prepareCreate(indexname).setMapping(
|
|
|
+ " {\"properties\": {\n" +
|
|
|
+ " \"doc_type\": { \"type\": \"keyword\" },\n" +
|
|
|
+ " \"model_id\": { \"type\": \"keyword\" },\n" +
|
|
|
+ " \"definition_length\": { \"type\": \"long\" },\n" +
|
|
|
+ " \"total_definition_length\": { \"type\": \"long\" },\n" +
|
|
|
+ " \"compression_version\": { \"type\": \"long\" },\n" +
|
|
|
+ " \"definition\": { \"type\": \"binary\" },\n" +
|
|
|
+ " \"eos\": { \"type\": \"boolean\" },\n" +
|
|
|
+ " \"task_type\": { \"type\": \"keyword\" },\n" +
|
|
|
+ " \"vocab\": { \"type\": \"keyword\" },\n" +
|
|
|
+ " \"with_special_tokens\": { \"type\": \"boolean\" },\n" +
|
|
|
+ " \"do_lower_case\": { \"type\": \"boolean\" }\n" +
|
|
|
+ " }\n" +
|
|
|
+ " }}"
|
|
|
+ ).get();
|
|
|
+ client().prepareIndex(indexname)
|
|
|
+ .setId(TRAINED_MODEL_ID + "_task_config")
|
|
|
+ .setSource(
|
|
|
+ "{ " +
|
|
|
+ "\"task_type\": \"bert_pass_through\",\n" +
|
|
|
+ "\"with_special_tokens\": false," +
|
|
|
+ "\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" +
|
|
|
+ "}",
|
|
|
+ XContentType.JSON
|
|
|
+ ).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
+ .get();
|
|
|
+ client().prepareIndex(indexname)
|
|
|
+ .setId("trained_model_definition_doc-" + TRAINED_MODEL_ID + "-0")
|
|
|
+ .setSource(
|
|
|
+ "{ " +
|
|
|
+ "\"doc_type\": \"trained_model_definition_doc\"," +
|
|
|
+ "\"model_id\": \"" + TRAINED_MODEL_ID +"\"," +
|
|
|
+ "\"doc_num\": 0," +
|
|
|
+ "\"definition_length\":" + RAW_MODEL_SIZE + "," +
|
|
|
+ "\"total_definition_length\":" + RAW_MODEL_SIZE + "," +
|
|
|
+ "\"compression_version\": 1," +
|
|
|
+ "\"definition\": \"" + BASE_64_ENCODED_MODEL + "\"," +
|
|
|
+ "\"eos\": true" +
|
|
|
+ "}",
|
|
|
+ XContentType.JSON
|
|
|
+ ).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
+ .get();
|
|
|
+ client()
|
|
|
+ .execute(
|
|
|
+ PutTrainedModelAction.INSTANCE,
|
|
|
+ new PutTrainedModelAction.Request(
|
|
|
+ TrainedModelConfig.builder()
|
|
|
+ .setModelType(TrainedModelType.PYTORCH)
|
|
|
+ .setInferenceConfig(new ClassificationConfig(1))
|
|
|
+ .setInput(new TrainedModelInput(Arrays.asList("text_field")))
|
|
|
+ .setLocation(new IndexLocation(TRAINED_MODEL_ID, indexname))
|
|
|
+ .setModelId(TRAINED_MODEL_ID)
|
|
|
+ .build()
|
|
|
+ )
|
|
|
+ )
|
|
|
+ .actionGet();
|
|
|
+ client().execute(
|
|
|
+ StartTrainedModelDeploymentAction.INSTANCE,
|
|
|
+ new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID)
|
|
|
+ ).actionGet();
|
|
|
+ }
|
|
|
+
|
|
|
private boolean isResetMode() {
|
|
|
ClusterState state = client().admin().cluster().prepareState().get().getState();
|
|
|
return MlMetadata.getMlMetadata(state).isResetMode();
|