Bläddra i källkod

[ML] integrating feature reset for trained model deployments (#76126)

this integrates removing all model deployments in the ML feature reset action.
Benjamin Trent 4 år sedan
förälder
incheckning
f1d8593b7a
12 ändrade filer med 313 tillägg och 20 borttagningar
  1. 0 6
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java
  2. 3 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
  3. 13 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/annotations/AnnotationIndex.java
  4. 15 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java
  5. 2 2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  6. 108 2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java
  7. 22 5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  8. 44 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java
  9. 5 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java
  10. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java
  11. 52 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java
  12. 48 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java

+ 0 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

@@ -126,12 +126,6 @@ public final class MlTasks {
         return tasks == null ? null : tasks.getTask(snapshotUpgradeTaskId(jobId, snapshotId));
     }
 
-    @Nullable
-    public static PersistentTasksCustomMetadata.PersistentTask<?> getTrainedModelDeploymentTask(
-            String modelId, @Nullable PersistentTasksCustomMetadata tasks) {
-        return tasks == null ? null : tasks.getTask(trainedModelDeploymentTaskId(modelId));
-    }
-
     /**
      * Note that the return value of this method does NOT take node relocations into account.
      * Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most

+ 3 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -25,7 +25,6 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.tasks.Task;
-import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -35,6 +34,8 @@ import java.io.IOException;
 import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 
+import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelDeploymentTaskId;
+
 public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAllocationAction.Response> {
 
     public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();
@@ -237,7 +238,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
                 if (Strings.isAllOrWildcard(expectedId)) {
                     return true;
                 }
-                String expectedDescription = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + expectedId;
+                String expectedDescription = trainedModelDeploymentTaskId(expectedId);
                 return expectedDescription.equals(task.getDescription());
             }
             return false;

+ 13 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/annotations/AnnotationIndex.java

@@ -6,6 +6,9 @@
  */
 package org.elasticsearch.xpack.core.ml.annotations;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
@@ -34,6 +37,8 @@ import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
 
 public class AnnotationIndex {
 
+    private static final Logger logger = LogManager.getLogger(AnnotationIndex.class);
+
     public static final String READ_ALIAS_NAME = ".ml-annotations-read";
     public static final String WRITE_ALIAS_NAME = ".ml-annotations-write";
     // Exposed for testing, but always use the aliases in non-test code
@@ -100,6 +105,14 @@ public class AnnotationIndex {
 
             // Create the annotations index if it doesn't exist already.
             if (mlLookup.containsKey(INDEX_NAME) == false) {
+                logger.debug(
+                    () -> new ParameterizedMessage(
+                        "Creating [{}] because [{}] exists; trace {}",
+                        INDEX_NAME,
+                        mlLookup.firstKey(),
+                        org.elasticsearch.ExceptionsHelper.formatStackTrace(Thread.currentThread().getStackTrace())
+                    )
+                );
 
                 CreateIndexRequest createIndexRequest =
                     new CreateIndexRequest(INDEX_NAME)

+ 15 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java

@@ -89,6 +89,7 @@ import org.elasticsearch.xpack.ilm.IndexLifecycle;
 import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
 import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.transform.Transform;
 
 import java.io.IOException;
@@ -280,6 +281,20 @@ abstract class MlNativeIntegTestCase extends ESIntegTestCase {
         if (cluster() != null && cluster().size() > 0) {
             List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(ClusterModule.getNamedWriteables());
             entries.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
+            entries.add(
+                new NamedWriteableRegistry.Entry(
+                    Metadata.Custom.class,
+                    TrainedModelAllocationMetadata.NAME,
+                    TrainedModelAllocationMetadata::new
+                )
+            );
+            entries.add(
+                new NamedWriteableRegistry.Entry(
+                    NamedDiff.class,
+                    TrainedModelAllocationMetadata.NAME,
+                    TrainedModelAllocationMetadata::readDiffFrom
+                )
+            );
             entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new));
             entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom));
             entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new));

+ 2 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -83,7 +83,7 @@ public class PyTorchModelIT extends ESRestTestCase {
 
     private static final String MODEL_INDEX = "model_store";
     private static final String MODEL_ID ="simple_model_to_evaluate";
-    private static final String BASE_64_ENCODED_MODEL =
+    static final String BASE_64_ENCODED_MODEL =
         "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp" +
             "TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA" +
             "AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW" +
@@ -106,7 +106,7 @@ public class PyTorchModelIT extends ESRestTestCase {
             "EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe" +
             "Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE" +
             "AAJIEAAAAAA==";
-    private static final int RAW_MODEL_SIZE; // size of the model before base64 encoding
+    static final int RAW_MODEL_SIZE; // size of the model before base64 encoding
     static {
         RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
     }

+ 108 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java

@@ -13,35 +13,49 @@ import org.elasticsearch.action.ingest.DeletePipelineAction;
 import org.elasticsearch.action.ingest.DeletePipelineRequest;
 import org.elasticsearch.action.ingest.PutPipelineAction;
 import org.elasticsearch.action.ingest.PutPipelineRequest;
+import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.tasks.TaskInfo;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
 import org.junit.After;
 
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor.Factory.countNumberInferenceProcessors;
 import static org.elasticsearch.xpack.ml.integration.ClassificationIT.KEYWORD_FIELD;
 import static org.elasticsearch.xpack.ml.integration.MlNativeDataFrameAnalyticsIntegTestCase.buildAnalytics;
+import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
+import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.RAW_MODEL_SIZE;
 import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createDatafeed;
 import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.createScheduledJob;
 import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.getDataCounts;
 import static org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase.indexDocs;
 import static org.hamcrest.Matchers.containsString;
-import static org.hamcrest.Matchers.emptyArray;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
@@ -51,6 +65,7 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
     private final Set<String> createdPipelines = new HashSet<>();
     private final Set<String> jobIds = new HashSet<>();
     private final Set<String> datafeedIds = new HashSet<>();
+    private static final String TRAINED_MODEL_ID = "trained-model-to-reset";
 
     void cleanupDatafeed(String datafeedId) {
         try {
@@ -122,7 +137,10 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
             ResetFeatureStateAction.INSTANCE,
             new ResetFeatureStateRequest()
         ).actionGet();
-        assertBusy(() -> assertThat(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices(), emptyArray()));
+        assertBusy(() -> {
+            List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
+            assertThat(indices.toString(), indices, is(empty()));
+        });
         assertThat(isResetMode(), is(false));
         // If we have succeeded, clear the jobs and datafeeds so that the delete API doesn't recreate the notifications index
         jobIds.clear();
@@ -147,6 +165,94 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
         assertThat(isResetMode(), is(false));
     }
 
+    public void testMLFeatureResetWithModelDeployment() throws Exception {
+        createModelDeployment();
+        client().execute(
+            ResetFeatureStateAction.INSTANCE,
+            new ResetFeatureStateRequest()
+        ).actionGet();
+        assertBusy(() -> {
+            List<String> indices = Arrays.asList(client().admin().indices().prepareGetIndex().addIndices(".ml*").get().indices());
+            assertThat(indices.toString(), indices, is(empty()));
+        });
+        assertThat(isResetMode(), is(false));
+        List<String> tasksNames = client().admin()
+            .cluster()
+            .prepareListTasks()
+            .setActions("xpack/ml/*")
+            .get()
+            .getTasks()
+            .stream()
+            .map(TaskInfo::getAction)
+            .collect(Collectors.toList());
+        assertThat(tasksNames, is(empty()));
+    }
+
+    void createModelDeployment() {
+        String indexname = "model_store";
+        client().admin().indices().prepareCreate(indexname).setMapping(
+            "    {\"properties\": {\n" +
+                "        \"doc_type\":    { \"type\": \"keyword\"  },\n" +
+                "        \"model_id\":    { \"type\": \"keyword\"  },\n" +
+                "        \"definition_length\":     { \"type\": \"long\"  },\n" +
+                "        \"total_definition_length\":     { \"type\": \"long\"  },\n" +
+                "        \"compression_version\":     { \"type\": \"long\"  },\n" +
+                "        \"definition\":     { \"type\": \"binary\"  },\n" +
+                "        \"eos\":      { \"type\": \"boolean\" },\n" +
+                "        \"task_type\":      { \"type\": \"keyword\" },\n" +
+                "        \"vocab\":      { \"type\": \"keyword\" },\n" +
+                "        \"with_special_tokens\":      { \"type\": \"boolean\" },\n" +
+                "        \"do_lower_case\":      { \"type\": \"boolean\" }\n" +
+                "      }\n" +
+                "    }}"
+        ).get();
+        client().prepareIndex(indexname)
+            .setId(TRAINED_MODEL_ID + "_task_config")
+            .setSource(
+                "{  " +
+                    "\"task_type\": \"bert_pass_through\",\n" +
+                    "\"with_special_tokens\": false," +
+                    "\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" +
+                    "}",
+                XContentType.JSON
+            ).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .get();
+        client().prepareIndex(indexname)
+            .setId("trained_model_definition_doc-" + TRAINED_MODEL_ID + "-0")
+            .setSource(
+                "{  " +
+                    "\"doc_type\": \"trained_model_definition_doc\"," +
+                    "\"model_id\": \"" + TRAINED_MODEL_ID +"\"," +
+                    "\"doc_num\": 0," +
+                    "\"definition_length\":" + RAW_MODEL_SIZE + "," +
+                    "\"total_definition_length\":" + RAW_MODEL_SIZE + "," +
+                    "\"compression_version\": 1," +
+                    "\"definition\": \""  + BASE_64_ENCODED_MODEL + "\"," +
+                    "\"eos\": true" +
+                    "}",
+                XContentType.JSON
+            ).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .get();
+        client()
+            .execute(
+                PutTrainedModelAction.INSTANCE,
+                new PutTrainedModelAction.Request(
+                    TrainedModelConfig.builder()
+                        .setModelType(TrainedModelType.PYTORCH)
+                        .setInferenceConfig(new ClassificationConfig(1))
+                        .setInput(new TrainedModelInput(Arrays.asList("text_field")))
+                        .setLocation(new IndexLocation(TRAINED_MODEL_ID, indexname))
+                        .setModelId(TRAINED_MODEL_ID)
+                        .build()
+                )
+            )
+            .actionGet();
+        client().execute(
+            StartTrainedModelDeploymentAction.INSTANCE,
+            new StartTrainedModelDeploymentAction.Request(TRAINED_MODEL_ID)
+        ).actionGet();
+    }
+
     private boolean isResetMode() {
         ClusterState state = client().admin().cluster().prepareState().get().getState();
         return MlMetadata.getMlMetadata(state).isResetMode();

+ 22 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -567,6 +567,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
     private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
     private final SetOnce<MlAutoscalingDeciderService> mlAutoscalingDeciderService = new SetOnce<>();
     private final SetOnce<DeploymentManager> deploymentManager = new SetOnce<>();
+    private final SetOnce<TrainedModelAllocationClusterService> trainedModelAllocationClusterServiceSetOnce = new SetOnce<>();
 
     public MachineLearning(Settings settings, Path configPath) {
         this.settings = settings;
@@ -870,11 +871,11 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
             clusterService,
             threadPool
         );
-        final TrainedModelAllocationClusterService trainedModelAllocationClusterService = new TrainedModelAllocationClusterService(
+        trainedModelAllocationClusterServiceSetOnce.set(new TrainedModelAllocationClusterService(
             settings,
             clusterService,
             new NodeLoadDetector(memoryTracker)
-        );
+        ));
 
         mlAutoscalingDeciderService.set(new MlAutoscalingDeciderService(memoryTracker, settings, clusterService));
 
@@ -905,7 +906,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                 modelLoadingService,
                 trainedModelProvider,
                 trainedModelAllocationService,
-                trainedModelAllocationClusterService,
+                trainedModelAllocationClusterServiceSetOnce.get(),
                 deploymentManager.get()
         );
     }
@@ -1375,7 +1376,10 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
 
         ActionListener<ResetFeatureStateResponse.ResetFeatureStateStatus> unsetResetModeListener = ActionListener.wrap(
             success -> client.execute(SetResetModeAction.INSTANCE, SetResetModeActionRequest.disabled(true), ActionListener.wrap(
-                resetSuccess -> finalListener.onResponse(success),
+                resetSuccess -> {
+                    finalListener.onResponse(success);
+                    logger.info("Finished machine learning feature reset");
+                },
                 resetFailure -> {
                     logger.error("failed to disable reset mode after state otherwise successful machine learning reset", resetFailure);
                     finalListener.onFailure(
@@ -1434,6 +1438,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                 client.admin()
                     .cluster()
                     .prepareListTasks()
+                    // This waits for all xpack actions including: allocations, anomaly detections, analytics
                     .setActions("xpack/ml/*")
                     .setWaitForCompletion(true)
                     .execute(ActionListener.wrap(
@@ -1504,7 +1509,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         }, unsetResetModeListener::onFailure);
 
         // Stop data feeds
-        ActionListener<AcknowledgedResponse> pipelineValidation = ActionListener.wrap(
+        ActionListener<AcknowledgedResponse> stopDeploymentsListener = ActionListener.wrap(
             acknowledgedResponse -> {
                 StopDatafeedAction.Request stopDatafeedsReq = new StopDatafeedAction.Request("_all")
                     .setAllowNoMatch(true);
@@ -1519,6 +1524,18 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
             unsetResetModeListener::onFailure
         );
 
+        // Stop all model deployments
+        ActionListener<AcknowledgedResponse> pipelineValidation = ActionListener.wrap(
+            acknowledgedResponse -> {
+                if (trainedModelAllocationClusterServiceSetOnce.get() == null) {
+                    stopDeploymentsListener.onResponse(AcknowledgedResponse.TRUE);
+                    return;
+                }
+                trainedModelAllocationClusterServiceSetOnce.get().removeAllModelAllocations(stopDeploymentsListener);
+            },
+            unsetResetModeListener::onFailure
+        );
+
         // validate no pipelines are using machine learning models
         ActionListener<AcknowledgedResponse> afterResetModeSet = ActionListener.wrap(
             acknowledgedResponse -> {

+ 44 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.ml.inference.allocation;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
@@ -27,6 +28,8 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.gateway.GatewayService;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
@@ -192,6 +195,26 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         });
     }
 
+    // Used by the reset action directly
+    public void removeAllModelAllocations(ActionListener<AcknowledgedResponse> listener) {
+        clusterService.submitStateUpdateTask("delete all model allocations", new ClusterStateUpdateTask() {
+            @Override
+            public ClusterState execute(ClusterState currentState) {
+                return removeAllAllocations(currentState);
+            }
+
+            @Override
+            public void onFailure(String source, Exception e) {
+                listener.onFailure(e);
+            }
+
+            @Override
+            public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
+                listener.onResponse(AcknowledgedResponse.TRUE);
+            }
+        });
+    }
+
     private static ClusterState update(ClusterState currentState, TrainedModelAllocationMetadata.Builder modelAllocations) {
         if (modelAllocations.isChanged()) {
             return ClusterState.builder(currentState)
@@ -205,9 +228,16 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
     }
 
     ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) {
+        if (MlMetadata.getMlMetadata(currentState).isResetMode()) {
+            throw new ElasticsearchStatusException(
+                "cannot create new allocation for model [{}] while feature reset is in progress.",
+                RestStatus.CONFLICT,
+                params.getModelId()
+            );
+        }
         TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
         if (builder.hasModel(params.getModelId())) {
-            throw new ResourceAlreadyExistsException("allocation for model with id [" + params.getModelId() + "] already exist");
+            throw new ResourceAlreadyExistsException("allocation for model with id [{}] already exist", params.getModelId());
         }
 
         Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
@@ -288,6 +318,19 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         return update(currentState, builder.removeAllocation(modelId));
     }
 
+    static ClusterState removeAllAllocations(ClusterState currentState) {
+        if (TrainedModelAllocationMetadata.fromState(currentState).modelAllocations().isEmpty()) {
+            return currentState;
+        };
+        return ClusterState.builder(currentState)
+            .metadata(
+                Metadata.builder(currentState.metadata())
+                    .putCustom(TrainedModelAllocationMetadata.NAME, TrainedModelAllocationMetadata.Builder.empty().build())
+                    .build()
+            )
+            .build();
+    }
+
     ClusterState addRemoveAllocationNodes(ClusterState currentState) {
         TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
         TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);

+ 5 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

@@ -26,6 +26,7 @@ import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
@@ -248,6 +249,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
     @Override
     public void clusterChanged(ClusterChangedEvent event) {
         if (event.metadataChanged()) {
+            final boolean isResetMode = MlMetadata.getMlMetadata(event.state()).isResetMode();
             TrainedModelAllocationMetadata modelAllocationMetadata = TrainedModelAllocationMetadata.fromState(event.state());
             final String currentNode = event.state().nodes().getLocalNodeId();
             for (TrainedModelAllocation trainedModelAllocation : modelAllocationMetadata.modelAllocations().values()) {
@@ -257,7 +259,9 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
                     // periodic retries should be handled in a separate thread think
                     && routingStateAndReason.getState().equals(RoutingState.STARTING)
                     // This means we don't already have a task and should attempt creating one and starting the model loading
-                    && modelIdToTask.containsKey(trainedModelAllocation.getTaskParams().getModelId()) == false) {
+                    && modelIdToTask.containsKey(trainedModelAllocation.getTaskParams().getModelId()) == false
+                    // If we are in reset mode, don't start loading a new model on this node.
+                    && isResetMode == false) {
                     prepareModelToLoad(trainedModelAllocation.getTaskParams());
                 }
                 // This mode is not routed to the current node at all

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

@@ -39,7 +39,7 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
         TaskParams taskParams,
         TrainedModelAllocationNodeService trainedModelAllocationNodeService
     ) {
-        super(id, type, action, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_ID_PREFIX + taskParams.getModelId(), parentTask, headers);
+        super(id, type, action, MlTasks.trainedModelDeploymentTaskId(taskParams.getModelId()), parentTask, headers);
         this.params = taskParams;
         this.trainedModelAllocationNodeService = ExceptionsHelper.requireNonNull(
             trainedModelAllocationNodeService,

+ 52 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.ml.inference.allocation;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.Version;
@@ -26,6 +27,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
@@ -42,6 +44,7 @@ import java.util.Set;
 import java.util.function.Function;
 
 import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasKey;
@@ -174,7 +177,29 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
 
         ClusterState modified = TrainedModelAllocationClusterService.removeAllocation(clusterStateWithAllocation, modelId);
         assertThat(TrainedModelAllocationMetadata.fromState(modified).getModelAllocation(modelId), is(nullValue()));
+    }
 
+    public void testRemoveAllAllocations() {
+        ClusterState clusterStateWithoutAllocation = ClusterState.builder(new ClusterName("testRemoveAllAllocations"))
+            .metadata(Metadata.builder().build())
+            .build();
+        assertThat(
+            TrainedModelAllocationClusterService.removeAllAllocations(clusterStateWithoutAllocation),
+            equalTo(clusterStateWithoutAllocation)
+        );
+
+        ClusterState clusterStateWithAllocations = ClusterState.builder(new ClusterName("testRemoveAllAllocations"))
+            .metadata(
+                Metadata.builder()
+                    .putCustom(
+                        TrainedModelAllocationMetadata.NAME,
+                        TrainedModelAllocationMetadataTests.randomInstance()
+                    )
+                    .build()
+            )
+            .build();
+        ClusterState modified = TrainedModelAllocationClusterService.removeAllAllocations(clusterStateWithAllocations);
+        assertThat(TrainedModelAllocationMetadata.fromState(modified).modelAllocations(), is(anEmptyMap()));
     }
 
     public void testCreateAllocation() {
@@ -212,6 +237,33 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         );
     }
 
+    public void testCreateAllocationWhileResetModeIsTrue() {
+        ClusterState currentState = ClusterState.builder(new ClusterName("testCreateAllocation"))
+            .nodes(
+                DiscoveryNodes.builder()
+                    .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes()))
+                    .build()
+            )
+            .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isResetMode(true).build()))
+            .build();
+        TrainedModelAllocationClusterService trainedModelAllocationClusterService = createClusterService();
+        expectThrows(
+            ElasticsearchStatusException.class,
+            () -> trainedModelAllocationClusterService.createModelAllocation(currentState, newParams("new-model", 150))
+        );
+
+        ClusterState stateWithoutReset = ClusterState.builder(new ClusterName("testCreateAllocation"))
+            .nodes(
+                DiscoveryNodes.builder()
+                    .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes()))
+                    .build()
+            )
+            .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isResetMode(false).build()))
+            .build();
+        // Shouldn't throw
+        trainedModelAllocationClusterService.createModelAllocation(stateWithoutReset, newParams("new-model", 150));
+    }
+
     public void testAddRemoveAllocationNodes() {
         ClusterState currentState = ClusterState.builder(new ClusterName("testAddRemoveAllocationNodes"))
             .nodes(

+ 48 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java

@@ -26,6 +26,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
@@ -216,6 +217,53 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
         verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService);
     }
 
+    public void testClusterChangedWithResetMode() {
+        final TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService();
+        final DiscoveryNodes nodes = DiscoveryNodes.builder()
+            .localNodeId(NODE_ID)
+            .add(
+                new DiscoveryNode(
+                    NODE_ID,
+                    NODE_ID,
+                    buildNewFakeTransportAddress(),
+                    Collections.emptyMap(),
+                    DiscoveryNodeRole.roles(),
+                    Version.CURRENT
+                )
+            )
+            .build();
+        String modelOne = "model-1";
+        String modelTwo = "model-2";
+        String notUsedModel = "model-3";
+        ClusterChangedEvent event = new ClusterChangedEvent(
+            "testClusterChanged",
+            ClusterState.builder(new ClusterName("testClusterChanged"))
+                .nodes(nodes)
+                .metadata(
+                    Metadata.builder()
+                        .putCustom(
+                            TrainedModelAllocationMetadata.NAME,
+                            TrainedModelAllocationMetadata.Builder.empty()
+                                .addNewAllocation(newParams(modelOne))
+                                .addNode(modelOne, NODE_ID)
+                                .addNewAllocation(newParams(modelTwo))
+                                .addNode(modelTwo, NODE_ID)
+                                .addNewAllocation(newParams(notUsedModel))
+                                .addNode(notUsedModel, "some-other-node")
+                                .build()
+                        )
+                        .putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isResetMode(true).build())
+                        .build()
+                )
+                .build(),
+            ClusterState.EMPTY_STATE
+        );
+
+        trainedModelAllocationNodeService.clusterChanged(event);
+        trainedModelAllocationNodeService.loadQueuedModels();
+        verifyNoMoreInteractions(deploymentManager, trainedModelAllocationService);
+    }
+
     public void testClusterChanged() throws Exception {
         final TrainedModelAllocationNodeService trainedModelAllocationNodeService = createService();
         final DiscoveryNodes nodes = DiscoveryNodes.builder()