Browse source code

[ML] Integrating ML with the node shutdown API (#75188)

When the node shutdown API is used to indicate that
an ML node is going to be shut down, ML will do the
following:

1. Isolate all datafeeds on the node and unassign
   their persistent tasks.
2. Tell all anomaly detection jobs on the node to go
   through the motions of closing, but not actually
   close and instead unassign their persistent tasks.
3. Report that the node is safe to shut down once all
   persistent tasks associated with anomaly detection
   jobs and model snapshot upgrades have either
   completed or been unassigned.
David Roberts 4 years ago
parent
commit
bce246dbb7
29 changed files with 1226 additions and 150 deletions
  1. 2 1
      server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java
  2. 5 1
      server/src/main/java/org/elasticsearch/persistent/PersistentTasksService.java
  3. 63 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java
  4. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDatafeedAction.java
  5. 11 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java
  6. 61 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java
  7. 2 1
      x-pack/plugin/ml/build.gradle
  8. 224 0
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlNodeShutdownIT.java
  9. 16 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  10. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlConfigMigrator.java
  11. 61 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java
  12. 56 20
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCloseJobAction.java
  13. 84 23
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java
  14. 113 26
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopDatafeedAction.java
  15. 20 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java
  16. 67 13
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedRunner.java
  17. 7 7
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java
  18. 68 16
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java
  19. 29 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/JobTask.java
  20. 5 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java
  21. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlConfigMigratorTests.java
  22. 165 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java
  23. 6 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportCloseJobActionTests.java
  24. 3 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedActionTests.java
  25. 60 14
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelectorTests.java
  26. 21 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedRunnerTests.java
  27. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java
  28. 32 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/JobTaskTests.java
  29. 39 7
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java

+ 2 - 1
server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java

@@ -317,7 +317,8 @@ public class PersistentTasksClusterService implements ClusterStateListener, Clos
         Randomness.shuffle(candidateNodes);
 
         final Assignment assignment = persistentTasksExecutor.getAssignment(taskParams, candidateNodes, currentState);
-        assert (assignment == null || isNodeShuttingDown(currentState, assignment.getExecutorNode()) == false) :
+        assert assignment != null : "getAssignment() should always return an Assignment object, containing a node or a reason why not";
+        assert (assignment.getExecutorNode() == null || isNodeShuttingDown(currentState, assignment.getExecutorNode()) == false) :
             "expected task [" + taskName + "] to be assigned to a node that is not marked as shutting down, but " +
                 assignment.getExecutorNode() + " is currently marked as shutting down";
         return assignment;

+ 5 - 1
server/src/main/java/org/elasticsearch/persistent/PersistentTasksService.java

@@ -132,7 +132,11 @@ public class PersistentTasksService {
      * or above.
      */
     public boolean isLocalAbortSupported() {
-        return clusterService.state().nodes().getMinNodeVersion().onOrAfter(LOCAL_ABORT_AVAILABLE_VERSION);
+        return isLocalAbortSupported(clusterService.state());
+    }
+
+    public static boolean isLocalAbortSupported(ClusterState state) {
+        return state.nodes().getMinNodeVersion().onOrAfter(LOCAL_ABORT_AVAILABLE_VERSION);
     }
 
     /**

+ 63 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

@@ -277,6 +277,69 @@ public final class MlTasks {
         return tasks.findTasks(JOB_TASK_NAME, task -> true);
     }
 
+    public static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> datafeedTasksOnNode(
+        @Nullable PersistentTasksCustomMetadata tasks, String nodeId) {
+        if (tasks == null) {
+            return Collections.emptyList();
+        }
+
+        return tasks.findTasks(DATAFEED_TASK_NAME, task -> nodeId.equals(task.getExecutorNode()));
+    }
+
+    public static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> jobTasksOnNode(
+        @Nullable PersistentTasksCustomMetadata tasks, String nodeId) {
+        if (tasks == null) {
+            return Collections.emptyList();
+        }
+
+        return tasks.findTasks(JOB_TASK_NAME, task -> nodeId.equals(task.getExecutorNode()));
+    }
+
+    public static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> nonFailedJobTasksOnNode(
+        @Nullable PersistentTasksCustomMetadata tasks, String nodeId) {
+        if (tasks == null) {
+            return Collections.emptyList();
+        }
+
+        return tasks.findTasks(JOB_TASK_NAME, task -> {
+            if (nodeId.equals(task.getExecutorNode()) == false) {
+                return false;
+            }
+            JobTaskState state = (JobTaskState) task.getState();
+            if (state == null) {
+                return true;
+            }
+            return state.getState() != JobState.FAILED;
+        });
+    }
+
+    public static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> snapshotUpgradeTasksOnNode(
+        @Nullable PersistentTasksCustomMetadata tasks, String nodeId) {
+        if (tasks == null) {
+            return Collections.emptyList();
+        }
+
+        return tasks.findTasks(JOB_SNAPSHOT_UPGRADE_TASK_NAME, task -> nodeId.equals(task.getExecutorNode()));
+    }
+
+    public static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> nonFailedSnapshotUpgradeTasksOnNode(
+        @Nullable PersistentTasksCustomMetadata tasks, String nodeId) {
+        if (tasks == null) {
+            return Collections.emptyList();
+        }
+
+        return tasks.findTasks(JOB_SNAPSHOT_UPGRADE_TASK_NAME, task -> {
+            if (nodeId.equals(task.getExecutorNode()) == false) {
+                return false;
+            }
+            SnapshotUpgradeTaskState taskState = (SnapshotUpgradeTaskState) task.getState();
+            if (taskState == null) {
+                return true;
+            }
+            SnapshotUpgradeState state = taskState.getState();
+            return state != SnapshotUpgradeState.FAILED;
+        });
+    }
 
     /**
      * Get the job Ids of anomaly detector job tasks that do

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDatafeedAction.java

@@ -96,7 +96,9 @@ public class StopDatafeedAction extends ActionType<StopDatafeedAction.Response>
             return resolvedStartedDatafeedIds;
         }
 
+        // This is used internally - the transport action sets it, not the user
         public void setResolvedStartedDatafeedIds(String[] resolvedStartedDatafeedIds) {
+            assert resolvedStartedDatafeedIds != null;
             this.resolvedStartedDatafeedIds = resolvedStartedDatafeedIds;
         }
 

+ 11 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java

@@ -18,6 +18,7 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.coordination.ElectionStrategy;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.IndexTemplateMetadata;
+import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator;
@@ -69,6 +70,7 @@ import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.RepositoryPlugin;
 import org.elasticsearch.plugins.ScriptPlugin;
 import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.plugins.ShutdownAwarePlugin;
 import org.elasticsearch.plugins.SystemIndexPlugin;
 import org.elasticsearch.repositories.RepositoriesService;
 import org.elasticsearch.repositories.Repository;
@@ -109,7 +111,7 @@ import static java.util.stream.Collectors.toList;
 
 public class LocalStateCompositeXPackPlugin extends XPackPlugin implements ScriptPlugin, ActionPlugin, IngestPlugin, NetworkPlugin,
         ClusterPlugin, DiscoveryPlugin, MapperPlugin, AnalysisPlugin, PersistentTaskPlugin, EnginePlugin, IndexStorePlugin,
-        SystemIndexPlugin, SearchPlugin {
+        SystemIndexPlugin, SearchPlugin, ShutdownAwarePlugin {
 
     private XPackLicenseState licenseState;
     private SSLService sslService;
@@ -588,4 +590,12 @@ public class LocalStateCompositeXPackPlugin extends XPackPlugin implements Scrip
         }
 
     }
+
+    public boolean safeToShutdown(String nodeId, SingleNodeShutdownMetadata.Type shutdownType) {
+        return filterPlugins(ShutdownAwarePlugin.class).stream().allMatch(plugin -> plugin.safeToShutdown(nodeId, shutdownType));
+    }
+
+    public void signalShutdown(Collection<String> shutdownNodeIds) {
+        filterPlugins(ShutdownAwarePlugin.class).forEach(plugin -> plugin.signalShutdown(shutdownNodeIds));
+    }
 }

+ 61 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

@@ -28,9 +28,11 @@ import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import java.net.InetAddress;
 import java.util.Arrays;
 
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasProperty;
 
 public class MlTasksTests extends ESTestCase {
 
@@ -173,6 +175,65 @@ public class MlTasksTests extends ESTestCase {
                 containsInAnyOrder("datafeed_without_assignment", "datafeed_without_node"));
     }
 
+    public void testDatafeedTasksOnNode() {
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        assertThat(MlTasks.openJobIds(tasksBuilder.build()), empty());
+
+        tasksBuilder.addTask(MlTasks.datafeedTaskId("df1"), MlTasks.DATAFEED_TASK_NAME,
+            new StartDatafeedAction.DatafeedParams("df1", 0L),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-2"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("foo-2"),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+        tasksBuilder.addTask(MlTasks.datafeedTaskId("df2"), MlTasks.DATAFEED_TASK_NAME,
+            new StartDatafeedAction.DatafeedParams("df2", 0L),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+
+        assertThat(MlTasks.datafeedTasksOnNode(tasksBuilder.build(), "node-2"), contains(hasProperty("id", equalTo("datafeed-df2"))));
+    }
+
+    public void testJobTasksOnNode() {
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        assertThat(MlTasks.openJobIds(tasksBuilder.build()), empty());
+
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-1"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("foo-1"),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        tasksBuilder.addTask(MlTasks.datafeedTaskId("df1"), MlTasks.DATAFEED_TASK_NAME,
+            new StartDatafeedAction.DatafeedParams("df1", 0L),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-2"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("foo-2"),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+        tasksBuilder.addTask(MlTasks.datafeedTaskId("df2"), MlTasks.DATAFEED_TASK_NAME,
+            new StartDatafeedAction.DatafeedParams("df2", 0L),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-3"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("foo-3"),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+
+        assertThat(MlTasks.jobTasksOnNode(tasksBuilder.build(), "node-2"),
+            containsInAnyOrder(hasProperty("id", equalTo("job-job-2")), hasProperty("id", equalTo("job-job-3"))));
+    }
+
+    public void testNonFailedJobTasksOnNode() {
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        assertThat(MlTasks.openJobIds(tasksBuilder.build()), empty());
+
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-1"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("foo-1"),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        tasksBuilder.updateTaskState(MlTasks.jobTaskId("job-1"), new JobTaskState(JobState.FAILED, 1, "testing"));
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-2"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("foo-2"),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        if (randomBoolean()) {
+            tasksBuilder.updateTaskState(MlTasks.jobTaskId("job-2"), new JobTaskState(JobState.OPENED, 2, "testing"));
+        }
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-3"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("foo-3"),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+        if (randomBoolean()) {
+            tasksBuilder.updateTaskState(MlTasks.jobTaskId("job-3"), new JobTaskState(JobState.FAILED, 3, "testing"));
+        }
+
+        assertThat(MlTasks.nonFailedJobTasksOnNode(tasksBuilder.build(), "node-1"),
+            contains(hasProperty("id", equalTo("job-job-2"))));
+    }
+
     public void testGetDataFrameAnalyticsState_GivenNullTask() {
         DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(null);
         assertThat(state, equalTo(DataFrameAnalyticsState.STOPPED));

+ 2 - 1
x-pack/plugin/ml/build.gradle

@@ -57,9 +57,10 @@ tasks.named("bundlePlugin").configure {
 dependencies {
   compileOnly project(':modules:lang-painless:spi')
   compileOnly project(path: xpackModule('core'))
+  compileOnly project(path: xpackModule('autoscaling'))
   testImplementation(testArtifact(project(xpackModule('core'))))
   testImplementation project(path: xpackModule('ilm'))
-  compileOnly project(path: xpackModule('autoscaling'))
+  testImplementation project(path: xpackModule('shutdown'))
   testImplementation project(path: xpackModule('data-streams'))
   testImplementation project(':modules:ingest-common')
   // This should not be here

+ 224 - 0
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlNodeShutdownIT.java

@@ -0,0 +1,224 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.integration;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.Build;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.xpack.core.action.util.QueryPage;
+import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
+import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
+import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
+import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
+import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.core.ml.job.config.JobState;
+import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
+import org.elasticsearch.xpack.shutdown.PutShutdownNodeAction;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.concurrent.TimeUnit;
+
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+
+public class MlNodeShutdownIT extends BaseMlIntegTestCase {
+
+    public void testJobsVacateShuttingDownNode() throws Exception {
+
+        // TODO: delete this condition when the shutdown API is always available
+        assumeTrue("shutdown API is behind a snapshot-only feature flag", Build.CURRENT.isSnapshot());
+
+        internalCluster().ensureAtLeastNumDataNodes(3);
+        ensureStableCluster();
+
+        // Index some source data for the datafeeds.
+        createSourceData();
+
+        // Open 6 jobs.  Since there are 3 nodes in the cluster we should get 2 jobs per node.
+        setupJobAndDatafeed("shutdown-job-1", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-job-2", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-job-3", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-job-4", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-job-5", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-job-6", ByteSizeValue.ofMb(2));
+
+        // Choose a node to shut down.  Choose a non-master node most of the time, as ML nodes in Cloud
+        // will never be master, and Cloud is where the node shutdown API will primarily be used.
+        String nodeNameToShutdown = rarely() ? internalCluster().getMasterName() : Arrays.stream(internalCluster().getNodeNames())
+            .filter(nodeName -> internalCluster().getMasterName().equals(nodeName) == false).findFirst().get();
+        SetOnce<String> nodeIdToShutdown = new SetOnce<>();
+
+        // Wait for the desired initial state of 2 jobs running on each node.
+        assertBusy(() -> {
+            GetJobsStatsAction.Response statsResponse =
+                client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(Metadata.ALL)).actionGet();
+            QueryPage<GetJobsStatsAction.Response.JobStats> jobStats = statsResponse.getResponse();
+            assertThat(jobStats, notNullValue());
+            long numJobsOnNodeToShutdown = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName())).count();
+            long numJobsOnOtherNodes = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName()) == false).count();
+            assertThat(numJobsOnNodeToShutdown, is(2L));
+            assertThat(numJobsOnOtherNodes, is(4L));
+            nodeIdToShutdown.set(jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName()))
+                .map(stats -> stats.getNode().getId()).findFirst().get());
+        });
+
+        // Call the shutdown API for the chosen node.
+        client().execute(PutShutdownNodeAction.INSTANCE,
+            new PutShutdownNodeAction.Request(nodeIdToShutdown.get(), randomFrom(SingleNodeShutdownMetadata.Type.values()), "just testing"))
+            .actionGet();
+
+        // Wait for the desired end state of all 6 jobs running on nodes that are not shutting down.
+        assertBusy(() -> {
+            GetJobsStatsAction.Response statsResponse =
+                client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(Metadata.ALL)).actionGet();
+            QueryPage<GetJobsStatsAction.Response.JobStats> jobStats = statsResponse.getResponse();
+            assertThat(jobStats, notNullValue());
+            long numJobsOnNodeToShutdown = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName())).count();
+            long numJobsOnOtherNodes = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName()) == false).count();
+            assertThat(numJobsOnNodeToShutdown, is(0L));
+            assertThat(numJobsOnOtherNodes, is(6L));
+        });
+    }
+
+    public void testCloseJobVacatingShuttingDownNode() throws Exception {
+
+        // TODO: delete this condition when the shutdown API is always available
+        assumeTrue("shutdown API is behind a snapshot-only feature flag", Build.CURRENT.isSnapshot());
+
+        internalCluster().ensureAtLeastNumDataNodes(3);
+        ensureStableCluster();
+
+        // Index some source data for the datafeeds.
+        createSourceData();
+
+        // Open 6 jobs.  Since there are 3 nodes in the cluster we should get 2 jobs per node.
+        setupJobAndDatafeed("shutdown-close-job-1", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-close-job-2", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-close-job-3", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-close-job-4", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-close-job-5", ByteSizeValue.ofMb(2));
+        setupJobAndDatafeed("shutdown-close-job-6", ByteSizeValue.ofMb(2));
+
+        // Choose a node to shut down, and one job on that node to close after the shutdown request has been sent.
+        // Choose a non-master node most of the time, as ML nodes in Cloud will never be master, and Cloud is where
+        // the node shutdown API will primarily be used.
+        String nodeNameToShutdown = rarely() ? internalCluster().getMasterName() : Arrays.stream(internalCluster().getNodeNames())
+            .filter(nodeName -> internalCluster().getMasterName().equals(nodeName) == false).findFirst().get();
+        SetOnce<String> nodeIdToShutdown = new SetOnce<>();
+        SetOnce<String> jobIdToClose = new SetOnce<>();
+
+        // Wait for the desired initial state of 2 jobs running on each node.
+        assertBusy(() -> {
+            GetJobsStatsAction.Response statsResponse =
+                client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(Metadata.ALL)).actionGet();
+            QueryPage<GetJobsStatsAction.Response.JobStats> jobStats = statsResponse.getResponse();
+            assertThat(jobStats, notNullValue());
+            long numJobsOnNodeToShutdown = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName())).count();
+            long numJobsOnOtherNodes = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName()) == false).count();
+            assertThat(numJobsOnNodeToShutdown, is(2L));
+            assertThat(numJobsOnOtherNodes, is(4L));
+            nodeIdToShutdown.set(jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName()))
+                .map(stats -> stats.getNode().getId()).findFirst().get());
+            jobIdToClose.set(jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName()))
+                .map(GetJobsStatsAction.Response.JobStats::getJobId).findAny().get());
+        });
+
+        // Call the shutdown API for the chosen node.
+        client().execute(PutShutdownNodeAction.INSTANCE,
+            new PutShutdownNodeAction.Request(nodeIdToShutdown.get(), randomFrom(SingleNodeShutdownMetadata.Type.values()), "just testing"))
+            .actionGet();
+
+        if (randomBoolean()) {
+            // This isn't waiting for something to happen - just adding timing variation
+            // to increase the chance of subtle race conditions occurring.
+            Thread.sleep(randomIntBetween(1, 10));
+        }
+
+        // There are several different scenarios for this request:
+        // 1. It might arrive at the original node that is shutting down before the job has transitioned into the
+        //    vacating state.  Then it's just a normal close that node shut down should not interfere with.
+        // 2. It might arrive at the original node that is shutting down while the job is vacating, but early enough
+        //    that the vacate can be promoted to a close (since the early part of the work they do is the same).
+        // 3. It might arrive at the original node that is shutting down while the job is vacating, but too late
+        //    to promote the vacate to a close (since the request to unassign the persistent task has already been
+        //    sent to the master node).  In this case fallback code in the job task should delete the persistent
+        //    task to effectively force-close the job on its new node.
+        // 4. It might arrive after the job has been unassigned from its original node after vacating but before it's
+        //    been assigned to a new node.  In this case the close job action will delete the persistent task.
+        // 5. It might arrive after the job has been assigned to its new node.  In this case it's just a normal close
+        //    on a node that isn't even shutting down.
+        client().execute(CloseJobAction.INSTANCE, new CloseJobAction.Request(jobIdToClose.get())).actionGet();
+
+        // Wait for the desired end state of the 5 jobs that were not closed running on nodes that are not shutting
+        // down, and the closed job not running anywhere.
+        assertBusy(() -> {
+            GetJobsStatsAction.Response statsResponse =
+                client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(Metadata.ALL)).actionGet();
+            QueryPage<GetJobsStatsAction.Response.JobStats> jobStats = statsResponse.getResponse();
+            assertThat(jobStats, notNullValue());
+            long numJobsOnNodeToShutdown = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName())).count();
+            long numJobsOnOtherNodes = jobStats.results().stream()
+                .filter(stats -> stats.getNode() != null && nodeNameToShutdown.equals(stats.getNode().getName()) == false).count();
+            assertThat(numJobsOnNodeToShutdown, is(0L));
+            assertThat(numJobsOnOtherNodes, is(5L)); // 5 rather than 6 because we closed one
+        });
+    }
+
+    private void setupJobAndDatafeed(String jobId, ByteSizeValue modelMemoryLimit) throws Exception {
+        Job.Builder job = createScheduledJob(jobId, modelMemoryLimit);
+        PutJobAction.Request putJobRequest = new PutJobAction.Request(job);
+        client().execute(PutJobAction.INSTANCE, putJobRequest).actionGet();
+
+        String datafeedId = jobId;
+        DatafeedConfig config = createDatafeed(datafeedId, job.getId(), Collections.singletonList("data"));
+        PutDatafeedAction.Request putDatafeedRequest = new PutDatafeedAction.Request(config);
+        client().execute(PutDatafeedAction.INSTANCE, putDatafeedRequest).actionGet();
+
+        client().execute(OpenJobAction.INSTANCE, new OpenJobAction.Request(job.getId()));
+        assertBusy(() -> {
+            GetJobsStatsAction.Response statsResponse =
+                client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId())).actionGet();
+            assertEquals(JobState.OPENED, statsResponse.getResponse().results().get(0).getState());
+        }, 30, TimeUnit.SECONDS);
+
+        StartDatafeedAction.Request startDatafeedRequest = new StartDatafeedAction.Request(config.getId(), 0L);
+        client().execute(StartDatafeedAction.INSTANCE, startDatafeedRequest).get();
+    }
+
+    private void ensureStableCluster() {
+        ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60));
+    }
+
+    private void createSourceData() {
+        client().admin().indices().prepareCreate("data")
+            .setMapping("time", "type=date")
+            .get();
+        long numDocs = randomIntBetween(50, 100);
+        long now = System.currentTimeMillis();
+        long weekAgo = now - 604800000;
+        long twoWeeksAgo = weekAgo - 604800000;
+        indexDocs(logger, "data", numDocs, twoWeeksAgo, weekAgo);
+    }
+}

+ 16 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -25,6 +25,7 @@ import org.elasticsearch.cluster.NamedDiff;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.IndexTemplateMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -65,6 +66,7 @@ import org.elasticsearch.plugins.IngestPlugin;
 import org.elasticsearch.plugins.PersistentTaskPlugin;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.plugins.ShutdownAwarePlugin;
 import org.elasticsearch.plugins.SystemIndexPlugin;
 import org.elasticsearch.repositories.RepositoriesService;
 import org.elasticsearch.rest.RestController;
@@ -414,7 +416,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                                                        CircuitBreakerPlugin,
                                                        IngestPlugin,
                                                        PersistentTaskPlugin,
-                                                       SearchPlugin {
+                                                       SearchPlugin,
+                                                       ShutdownAwarePlugin {
     public static final String NAME = "ml";
     public static final String BASE_PATH = "/_ml/";
     public static final String DATAFEED_THREAD_POOL_NAME = NAME + "_datafeed";
@@ -550,6 +553,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
     private final SetOnce<DataFrameAnalyticsAuditor> dataFrameAnalyticsAuditor = new SetOnce<>();
     private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
     private final SetOnce<ActionFilter> mlUpgradeModeActionFilter = new SetOnce<>();
+    private final SetOnce<MlLifeCycleService> mlLifeCycleService = new SetOnce<>();
     private final SetOnce<CircuitBreaker> inferenceModelBreaker = new SetOnce<>();
     private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
     private final SetOnce<MlAutoscalingDeciderService> mlAutoscalingDeciderService = new SetOnce<>();
@@ -836,6 +840,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         MlLifeCycleService mlLifeCycleService =
             new MlLifeCycleService(
                 clusterService, datafeedRunner, mlController, autodetectProcessManager, dataFrameAnalyticsManager, memoryTracker);
+        this.mlLifeCycleService.set(mlLifeCycleService);
         MlAssignmentNotifier mlAssignmentNotifier = new MlAssignmentNotifier(anomalyDetectionAuditor, dataFrameAnalyticsAuditor, threadPool,
             new MlConfigMigrator(settings, client, clusterService, indexNameExpressionResolver), clusterService);
 
@@ -1521,4 +1526,14 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
             return List.of();
         }
     }
+
+    @Override
+    public boolean safeToShutdown(String nodeId, SingleNodeShutdownMetadata.Type shutdownType) {
+        return mlLifeCycleService.get().isNodeSafeToShutdown(nodeId);
+    }
+
+    @Override
+    public void signalShutdown(Collection<String> shutdownNodeIds) {
+        mlLifeCycleService.get().signalGracefulShutdown(shutdownNodeIds);
+    }
 }

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlConfigMigrator.java

@@ -568,7 +568,7 @@ public class MlConfigMigrator {
      * @param clusterState The cluster state
      * @return The closed job configurations
      */
-    public static List<DatafeedConfig> stopppedOrUnallocatedDatafeeds(ClusterState clusterState) {
+    public static List<DatafeedConfig> stoppedOrUnallocatedDatafeeds(ClusterState clusterState) {
         PersistentTasksCustomMetadata persistentTasks = clusterState.metadata().custom(PersistentTasksCustomMetadata.TYPE);
         Set<String> startedDatafeedIds = MlTasks.startedDatafeedIds(persistentTasks);
         startedDatafeedIds.removeAll(MlTasks.unassignedDatafeedIds(persistentTasks, clusterState.nodes()));
@@ -594,7 +594,7 @@ public class MlConfigMigrator {
     }
 
     public static List<JobsAndDatafeeds> splitInBatches(ClusterState clusterState) {
-        Collection<DatafeedConfig> stoppedDatafeeds = stopppedOrUnallocatedDatafeeds(clusterState);
+        Collection<DatafeedConfig> stoppedDatafeeds = stoppedOrUnallocatedDatafeeds(clusterState);
         Map<String, Job> eligibleJobs = nonDeletingJobs(closedOrUnallocatedJobs(clusterState)).stream()
             .map(MlConfigMigrator::updateJobForMigration)
             .collect(Collectors.toMap(Job::getId, Function.identity(), (a, b) -> a));

+ 61 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlLifeCycleService.java

@@ -6,8 +6,12 @@
  */
 package org.elasticsearch.xpack.ml;
 
+import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.component.LifecycleListener;
+import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
+import org.elasticsearch.persistent.PersistentTasksService;
+import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedRunner;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
 import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
@@ -15,10 +19,12 @@ import org.elasticsearch.xpack.ml.process.MlController;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 
 import java.io.IOException;
+import java.util.Collection;
 import java.util.Objects;
 
 public class MlLifeCycleService {
 
+    private final ClusterService clusterService;
     private final DatafeedRunner datafeedRunner;
     private final MlController mlController;
     private final AutodetectProcessManager autodetectProcessManager;
@@ -28,6 +34,7 @@ public class MlLifeCycleService {
     MlLifeCycleService(ClusterService clusterService, DatafeedRunner datafeedRunner, MlController mlController,
                        AutodetectProcessManager autodetectProcessManager, DataFrameAnalyticsManager analyticsManager,
                        MlMemoryTracker memoryTracker) {
+        this.clusterService = Objects.requireNonNull(clusterService);
         this.datafeedRunner = Objects.requireNonNull(datafeedRunner);
         this.mlController = Objects.requireNonNull(mlController);
         this.autodetectProcessManager = Objects.requireNonNull(autodetectProcessManager);
@@ -47,7 +54,7 @@ public class MlLifeCycleService {
             analyticsManager.markNodeAsShuttingDown();
             // This prevents datafeeds from sending data to autodetect processes WITHOUT stopping the datafeeds, so they get reassigned.
             // We have to do this first, otherwise the datafeeds could fail if they send data to a dead autodetect process.
-            datafeedRunner.isolateAllDatafeedsOnThisNodeBeforeShutdown();
+            datafeedRunner.prepareForImmediateShutdown();
             // This kills autodetect processes WITHOUT closing the jobs, so they get reassigned.
             autodetectProcessManager.killAllProcessesOnThisNode();
             mlController.stop();
@@ -56,4 +63,57 @@ public class MlLifeCycleService {
         }
         memoryTracker.stop();
     }
+
+    /**
+     * Is it safe to shut down a particular node without any ML rework being required?
+     * @param nodeId ID of the node being shut down.
+     * @return Has all active ML work vacated the specified node?
+     */
+    public boolean isNodeSafeToShutdown(String nodeId) {
+        return isNodeSafeToShutdown(nodeId, clusterService.state());
+    }
+
+    static boolean isNodeSafeToShutdown(String nodeId, ClusterState state) {
+        // If we are in a mixed version cluster that doesn't support locally aborting persistent tasks then
+        // we cannot perform graceful shutdown, so just revert to the behaviour of previous versions where
+        // the node shutdown API didn't exist
+        if (PersistentTasksService.isLocalAbortSupported(state) == false) {
+            return true;
+        }
+        PersistentTasksCustomMetadata tasks = state.metadata().custom(PersistentTasksCustomMetadata.TYPE);
+        // TODO: currently only considering anomaly detection jobs - could extend in the future
+        // Ignore failed jobs - the persistent task still exists to remember the failure (because no
+        // persistent task means closed), but these don't need to be relocated to another node.
+        return MlTasks.nonFailedJobTasksOnNode(tasks, nodeId).isEmpty() &&
+            MlTasks.nonFailedSnapshotUpgradeTasksOnNode(tasks, nodeId).isEmpty();
+    }
+
+    /**
+     * Called when nodes have been marked for shutdown.
+     * This method will only react if the local node is in the collection provided.
+     * (The assumption is that this method will be called on every node, so each node will get to react.)
+     * If the local node is marked for shutdown then ML jobs running on it will be told to gracefully
+     * persist state and then unassigned so that they relocate to a different node.
+     * @param shutdownNodeIds IDs of all nodes being shut down.
+     */
+    public void signalGracefulShutdown(Collection<String> shutdownNodeIds) {
+        signalGracefulShutdown(clusterService.state(), shutdownNodeIds);
+    }
+
+    void signalGracefulShutdown(ClusterState state, Collection<String> shutdownNodeIds) {
+
+        // If we are in a mixed version cluster that doesn't support locally aborting persistent tasks then
+        // we cannot perform graceful shutdown, so just revert to the behaviour of previous versions where
+        // the node shutdown API didn't exist
+        if (PersistentTasksService.isLocalAbortSupported(state) == false) {
+            return;
+        }
+
+        if (shutdownNodeIds.contains(state.nodes().getLocalNodeId())) {
+
+            datafeedRunner.vacateAllDatafeedsOnThisNode(
+                "previously assigned node [" + state.nodes().getLocalNode().getName() + "] is shutting down");
+            autodetectProcessManager.vacateOpenJobsOnThisNode();
+        }
+    }
 }

+ 56 - 20
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCloseJobAction.java

@@ -22,10 +22,10 @@ import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.discovery.MasterNotDiscoveredException;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
@@ -37,6 +37,7 @@ import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
 import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
+import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
@@ -53,6 +54,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
@@ -366,14 +368,14 @@ public class TransportCloseJobAction extends TransportTasksAction<JobTask, Close
             PersistentTasksCustomMetadata.PersistentTask<?> jobTask = MlTasks.getJobTask(jobId, tasks);
             if (jobTask != null) {
                 auditor.info(jobId, Messages.JOB_AUDIT_CLOSING);
-                waitForCloseRequest.persistentTaskIds.add(jobTask.getId());
+                waitForCloseRequest.persistentTasks.add(jobTask);
                 waitForCloseRequest.jobsToFinalize.add(jobId);
             }
         }
         for (String jobId : closingJobIds) {
             PersistentTasksCustomMetadata.PersistentTask<?> jobTask = MlTasks.getJobTask(jobId, tasks);
             if (jobTask != null) {
-                waitForCloseRequest.persistentTaskIds.add(jobTask.getId());
+                waitForCloseRequest.persistentTasks.add(jobTask);
             }
         }
 
@@ -389,8 +391,7 @@ public class TransportCloseJobAction extends TransportTasksAction<JobTask, Close
             threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(new AbstractRunnable() {
                 @Override
                 public void onFailure(Exception e) {
-                    if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException
-                        && Strings.isAllOrWildcard(request.getJobId())) {
+                    if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                         logger.trace(
                             () -> new ParameterizedMessage(
                                 "[{}] [{}] failed to close job due to resource not found exception",
@@ -413,8 +414,7 @@ public class TransportCloseJobAction extends TransportTasksAction<JobTask, Close
                 }
             });
         }, e -> {
-            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException
-                && Strings.isAllOrWildcard(request.getJobId())) {
+            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                 logger.trace(
                     () -> new ParameterizedMessage(
                         "[{}] [{}] failed to update job to closing due to resource not found exception",
@@ -478,8 +478,7 @@ public class TransportCloseJobAction extends TransportTasksAction<JobTask, Close
                             @Override
                             public void onFailure(Exception e) {
                                 final int slot = counter.incrementAndGet();
-                                if ((ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException &&
-                                    Strings.isAllOrWildcard(new String[]{request.getJobId()})) == false) {
+                                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException == false) {
                                     failures.set(slot - 1, e);
                                 }
                                 if (slot == numberOfJobs) {
@@ -524,39 +523,76 @@ public class TransportCloseJobAction extends TransportTasksAction<JobTask, Close
             return;
         }
 
+        final Set<String> movedJobs = Sets.newConcurrentHashSet();
+
+        ActionListener<CloseJobAction.Response> intermediateListener = ActionListener.wrap(
+            response -> {
+                for (String jobId : movedJobs) {
+                    PersistentTasksCustomMetadata.PersistentTask<?> jobTask = MlTasks.getJobTask(jobId, tasks);
+                    persistentTasksService.sendRemoveRequest(jobTask.getId(), ActionListener.wrap(
+                        r -> logger.trace("[{}] removed persistent task for relocated job", jobId),
+                        e -> {
+                            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                                logger.debug("[{}] relocated job task already removed", jobId);
+                            } else {
+                                logger.error("[" + jobId + "] failed to remove task to stop relocated job", e);
+                            }
+                        })
+                    );
+                }
+                listener.onResponse(response);
+            }, listener::onFailure
+        );
+
         boolean noOpenJobsToClose = openJobIds.isEmpty();
         if (noOpenJobsToClose) {
             // No jobs to close but we still want to wait on closing jobs in the request
-            waitForJobClosed(request, waitForCloseRequest, new CloseJobAction.Response(true), listener);
+            waitForJobClosed(request, waitForCloseRequest, new CloseJobAction.Response(true), intermediateListener, movedJobs);
             return;
         }
 
         ActionListener<CloseJobAction.Response> finalListener =
                 ActionListener.wrap(
                         r -> waitForJobClosed(request, waitForCloseRequest,
-                        r, listener),
+                        r, intermediateListener, movedJobs),
                         listener::onFailure);
         super.doExecute(task, request, finalListener);
     }
 
     static class WaitForCloseRequest {
-        List<String> persistentTaskIds = new ArrayList<>();
+        List<PersistentTasksCustomMetadata.PersistentTask<?>> persistentTasks = new ArrayList<>();
         List<String> jobsToFinalize = new ArrayList<>();
 
         public boolean hasJobsToWaitFor() {
-            return persistentTaskIds.isEmpty() == false;
+            return persistentTasks.isEmpty() == false;
         }
     }
 
-    // Wait for job to be marked as closed in cluster state, which means the job persistent task has been removed
-    // This api returns when job has been closed, but that doesn't mean the persistent task has been removed from cluster state,
-    // so wait for that to happen here.
+    /**
+     * Wait for job to be marked as closed in cluster state, which means the job persistent task has been removed
+     * This api returns when job has been closed, but that doesn't mean the persistent task has been removed from cluster state,
+     * so wait for that to happen here.
+     *
+     * Since the close job action consists of a chain of async callbacks, it's possible that jobs have moved nodes since we decided
+     * what to do with them at the beginning of the chain.  We cannot simply wait for these, as the request to stop them will have
+     * been sent to the wrong node and ignored there, so we'll just spin until the timeout expires.
+     */
     void waitForJobClosed(CloseJobAction.Request request, WaitForCloseRequest waitForCloseRequest, CloseJobAction.Response response,
-                          ActionListener<CloseJobAction.Response> listener) {
+                          ActionListener<CloseJobAction.Response> listener, Set<String> movedJobs) {
         persistentTasksService.waitForPersistentTasksCondition(persistentTasksCustomMetadata -> {
-            for (String persistentTaskId : waitForCloseRequest.persistentTaskIds) {
-                if (persistentTasksCustomMetadata.getTask(persistentTaskId) != null) {
-                    return false;
+            for (PersistentTasksCustomMetadata.PersistentTask<?> originalPersistentTask : waitForCloseRequest.persistentTasks) {
+                String originalPersistentTaskId = originalPersistentTask.getId();
+                PersistentTasksCustomMetadata.PersistentTask<?> currentPersistentTask =
+                    persistentTasksCustomMetadata.getTask(originalPersistentTaskId);
+                if (currentPersistentTask != null) {
+                    if (Objects.equals(originalPersistentTask.getExecutorNode(), currentPersistentTask.getExecutorNode())
+                        && originalPersistentTask.getAllocationId() == currentPersistentTask.getAllocationId()) {
+                        return false;
+                    }
+                    OpenJobAction.JobParams params = (OpenJobAction.JobParams) originalPersistentTask.getParams();
+                    if (movedJobs.add(params.getJobId())) {
+                        logger.info("Job [{}] changed assignment while waiting for it to be closed", params.getJobId());
+                    }
                 }
             }
             return true;

+ 84 - 23
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java

@@ -71,6 +71,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
@@ -430,11 +431,8 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
         public PersistentTasksCustomMetadata.Assignment getAssignment(StartDatafeedAction.DatafeedParams params,
                                                                       Collection<DiscoveryNode> candidateNodes,
                                                                       ClusterState clusterState) {
-            // 'candidateNodes' is not actually used here because the assignment for the task is
-            // already filtered elsewhere (JobNodeSelector), this is only finding the node a task
-            // has already been assigned to.
             return new DatafeedNodeSelector(clusterState, resolver, params.getDatafeedId(), params.getJobId(),
-                    params.getDatafeedIndices(), params.getIndicesOptions()).selectNode();
+                    params.getDatafeedIndices(), params.getIndicesOptions()).selectNode(candidateNodes);
         }
 
         @Override
@@ -455,21 +453,28 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
             DatafeedTask datafeedTask = (DatafeedTask) allocatedPersistentTask;
             DatafeedState datafeedState = (DatafeedState) state;
 
-            // If we are "stopping" there is nothing to do
+            // If we are stopping, stopped or isolated we should not start the runner.  Due to
+            // races in the way messages pass between nodes via cluster state or direct action calls
+            // we need to detect stopped/stopping by both considering the persistent task state in
+            // cluster state and also whether an explicit request to stop has been received on this
+            // node.
             if (DatafeedState.STOPPING.equals(datafeedState)) {
                 logger.info("[{}] datafeed got reassigned while stopping. Marking as completed", params.getDatafeedId());
-                datafeedTask.markAsCompleted();
+                datafeedTask.completeOrFailIfRequired(null);
                 return;
             }
-            datafeedTask.datafeedRunner = datafeedRunner;
-            datafeedRunner.run(datafeedTask,
-                    (error) -> {
-                        if (error != null) {
-                            datafeedTask.markAsFailed(error);
-                        } else {
-                            datafeedTask.markAsCompleted();
-                        }
-                    });
+            switch (datafeedTask.setDatafeedRunner(datafeedRunner)) {
+                case NEITHER:
+                    datafeedRunner.run(datafeedTask, datafeedTask::completeOrFailIfRequired);
+                    break;
+                case ISOLATED:
+                    logger.info("[{}] datafeed isolated immediately after reassignment.", params.getDatafeedId());
+                    break;
+                case STOPPED:
+                    logger.info("[{}] datafeed stopped immediately after reassignment. Marking as completed", params.getDatafeedId());
+                    datafeedTask.completeOrFailIfRequired(null);
+                    break;
+            }
         }
 
         @Override
@@ -483,11 +488,17 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
 
     public static class DatafeedTask extends AllocatedPersistentTask implements StartDatafeedAction.DatafeedTaskMatcher {
 
+        public enum StoppedOrIsolatedBeforeRunning { NEITHER, ISOLATED, STOPPED }
+
         private final String datafeedId;
         private final long startTime;
         private final Long endTime;
-        /* only pck protected for testing */
-        volatile DatafeedRunner datafeedRunner;
+        /**
+         * This must always be set within a synchronized block that also checks
+         * the value of the {@code stoppedOrIsolatedBeforeRunning} flag.
+         */
+        private DatafeedRunner datafeedRunner;
+        private StoppedOrIsolatedBeforeRunning stoppedOrIsolatedBeforeRunning = StoppedOrIsolatedBeforeRunning.NEITHER;
 
         DatafeedTask(long id, String type, String action, TaskId parentTaskId, StartDatafeedAction.DatafeedParams params,
                      Map<String, String> headers) {
@@ -514,6 +525,20 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
             return endTime != null;
         }
 
+        /**
+         * Set the datafeed runner <em>if</em> the task has not already been told to stop or isolate.
+         * @return A {@link StoppedOrIsolatedBeforeRunning} object that indicates whether the
+         *         datafeed task had previously been told to stop or isolate.  {@code datafeedRunner}
+         *         will only be set to the supplied value if the return value of this method is
+         *         {@link StoppedOrIsolatedBeforeRunning#NEITHER}.
+         */
+        synchronized StoppedOrIsolatedBeforeRunning setDatafeedRunner(DatafeedRunner datafeedRunner) {
+            if (stoppedOrIsolatedBeforeRunning == StoppedOrIsolatedBeforeRunning.NEITHER) {
+                this.datafeedRunner = Objects.requireNonNull(datafeedRunner);
+            }
+            return stoppedOrIsolatedBeforeRunning;
+        }
+
         @Override
         protected void onCancelled() {
             // If the persistent task framework wants us to stop then we should do so immediately and
@@ -530,20 +555,52 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
         }
 
         public void stop(String reason, TimeValue timeout) {
-            if (datafeedRunner != null) {
-                datafeedRunner.stopDatafeed(this, reason, timeout);
+            synchronized (this) {
+                if (datafeedRunner == null) {
+                    stoppedOrIsolatedBeforeRunning = StoppedOrIsolatedBeforeRunning.STOPPED;
+                    return;
+                }
             }
+            datafeedRunner.stopDatafeed(this, reason, timeout);
+        }
+
+        public synchronized StoppedOrIsolatedBeforeRunning getStoppedOrIsolatedBeforeRunning() {
+            return stoppedOrIsolatedBeforeRunning;
         }
 
         public void isolate() {
-            if (datafeedRunner != null) {
-                datafeedRunner.isolateDatafeed(getAllocationId());
+            synchronized (this) {
+                if (datafeedRunner == null) {
+                    // Stopped takes precedence over isolated for what we report externally,
+                    // as stopped needs to cause the persistent task to be marked as completed
+                    // (regardless of whether it was isolated) whereas isolated but not stopped
+                    // mustn't do this.
+                    if (stoppedOrIsolatedBeforeRunning == StoppedOrIsolatedBeforeRunning.NEITHER) {
+                        stoppedOrIsolatedBeforeRunning = StoppedOrIsolatedBeforeRunning.ISOLATED;
+                    }
+                    return;
+                }
+            }
+            datafeedRunner.isolateDatafeed(getAllocationId());
+        }
+
+        void completeOrFailIfRequired(Exception error) {
+            // A task can only be completed or failed once - trying multiple times just causes log spam
+            if (isCompleted()) {
+                return;
+            }
+            if (error != null) {
+                markAsFailed(error);
+            } else {
+                markAsCompleted();
             }
         }
 
         public Optional<GetDatafeedRunningStateAction.Response.RunningState> getRunningState() {
-            if (datafeedRunner == null) {
-                return Optional.empty();
+            synchronized (this) {
+                if (datafeedRunner == null) {
+                    return Optional.empty();
+                }
             }
             return Optional.of(new GetDatafeedRunningStateAction.Response.RunningState(
                 this.endTime == null,
@@ -572,6 +629,10 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
                 if (assignment.equals(DatafeedNodeSelector.AWAITING_JOB_ASSIGNMENT)) {
                     return true;
                 }
+                // This means the node the job got assigned to was shut down in between starting the job and the datafeed - not an error
+                if (assignment.equals(DatafeedNodeSelector.AWAITING_JOB_RELOCATION)) {
+                    return true;
+                }
                 if (assignment.equals(PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT) == false && assignment.isAssigned() == false) {
                     // Assignment has failed despite passing our "fast fail" validation
                     exception = new ElasticsearchStatusException("Could not start datafeed, allocation explanation [" +

+ 113 - 26
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopDatafeedAction.java

@@ -25,6 +25,8 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.discovery.MasterNotDiscoveredException;
 import org.elasticsearch.persistent.PersistentTasksClusterService;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
@@ -51,13 +53,14 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 
 public class TransportStopDatafeedAction extends TransportTasksAction<TransportStartDatafeedAction.DatafeedTask, StopDatafeedAction.Request,
         StopDatafeedAction.Response, StopDatafeedAction.Response> {
 
+    private static final int MAX_ATTEMPTS = 10;
+
     private static final Logger logger = LogManager.getLogger(TransportStopDatafeedAction.class);
 
     private final ThreadPool threadPool;
@@ -128,6 +131,11 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
 
     @Override
     protected void doExecute(Task task, StopDatafeedAction.Request request, ActionListener<StopDatafeedAction.Response> listener) {
+        doExecute(task, request, listener, 1);
+    }
+
+    private void doExecute(Task task, StopDatafeedAction.Request request, ActionListener<StopDatafeedAction.Response> listener,
+                           int attempt) {
         final ClusterState state = clusterService.state();
         final DiscoveryNodes nodes = state.nodes();
         if (nodes.isLocalNodeElectedMaster() == false) {
@@ -155,12 +163,11 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
                             listener.onResponse(new StopDatafeedAction.Response(true));
                             return;
                         }
-                        request.setResolvedStartedDatafeedIds(startedDatafeeds.toArray(new String[startedDatafeeds.size()]));
 
                         if (request.isForce()) {
                             forceStopDatafeed(request, listener, tasks, nodes, notStoppedDatafeeds);
                         } else {
-                            normalStopDatafeed(task, request, listener, tasks, nodes, startedDatafeeds, stoppingDatafeeds);
+                            normalStopDatafeed(task, request, listener, tasks, nodes, startedDatafeeds, stoppingDatafeeds, attempt);
                         }
                     },
                     listener::onFailure
@@ -170,9 +177,11 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
 
     private void normalStopDatafeed(Task task, StopDatafeedAction.Request request, ActionListener<StopDatafeedAction.Response> listener,
                                     PersistentTasksCustomMetadata tasks, DiscoveryNodes nodes,
-                                    List<String> startedDatafeeds, List<String> stoppingDatafeeds) {
+                                    List<String> startedDatafeeds, List<String> stoppingDatafeeds, int attempt) {
         final Set<String> executorNodes = new HashSet<>();
         final List<String> startedDatafeedsJobs = new ArrayList<>();
+        final List<String> resolvedStartedDatafeeds = new ArrayList<>();
+        final List<PersistentTasksCustomMetadata.PersistentTask<?>> allDataFeedsToWaitFor = new ArrayList<>();
         for (String datafeedId : startedDatafeeds) {
             PersistentTasksCustomMetadata.PersistentTask<?> datafeedTask = MlTasks.getDatafeedTask(datafeedId, tasks);
             if (datafeedTask == null) {
@@ -182,7 +191,9 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
                 logger.error(msg);
             } else if (PersistentTasksClusterService.needsReassignment(datafeedTask.getAssignment(), nodes) == false) {
                 startedDatafeedsJobs.add(((StartDatafeedAction.DatafeedParams) datafeedTask.getParams()).getJobId());
+                resolvedStartedDatafeeds.add(datafeedId);
                 executorNodes.add(datafeedTask.getExecutorNode());
+                allDataFeedsToWaitFor.add(datafeedTask);
             } else {
                 // This is the easy case - the datafeed is not currently assigned to a valid node,
                 // so can be gracefully stopped simply by removing its persistent task.  (Usually
@@ -194,21 +205,37 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
                     r -> auditDatafeedStopped(datafeedTask),
                     e -> logger.error("[" + datafeedId + "] failed to remove task to stop unassigned datafeed", e))
                 );
+                allDataFeedsToWaitFor.add(datafeedTask);
             }
         }
 
+        for (String datafeedId : stoppingDatafeeds) {
+            PersistentTasksCustomMetadata.PersistentTask<?> datafeedTask = MlTasks.getDatafeedTask(datafeedId, tasks);
+            assert datafeedTask != null : "Requested datafeed [" + datafeedId + "] be stopped, but datafeed's task could not be found.";
+            allDataFeedsToWaitFor.add(datafeedTask);
+        }
+
+        request.setResolvedStartedDatafeedIds(resolvedStartedDatafeeds.toArray(new String[0]));
         request.setNodes(executorNodes.toArray(new String[0]));
 
-        // wait for started and stopping datafeeds
-        // Map datafeedId -> datafeed task Id.
-        List<String> allDataFeedsToWaitFor = Stream.concat(
-                startedDatafeeds.stream().map(MlTasks::datafeedTaskId),
-                stoppingDatafeeds.stream().map(MlTasks::datafeedTaskId))
-                .collect(Collectors.toList());
+        final Set<String> movedDatafeeds = Sets.newConcurrentHashSet();
 
         ActionListener<StopDatafeedAction.Response> finalListener = ActionListener.wrap(
-                r -> waitForDatafeedStopped(allDataFeedsToWaitFor, request, r, ActionListener.wrap(
+                response -> waitForDatafeedStopped(allDataFeedsToWaitFor, request, response, ActionListener.wrap(
                     finished -> {
+                        for (String datafeedId : movedDatafeeds) {
+                            PersistentTasksCustomMetadata.PersistentTask<?> datafeedTask = MlTasks.getDatafeedTask(datafeedId, tasks);
+                            persistentTasksService.sendRemoveRequest(datafeedTask.getId(), ActionListener.wrap(
+                                r -> auditDatafeedStopped(datafeedTask),
+                                e -> {
+                                    if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                                        logger.debug("[{}] relocated datafeed task already removed", datafeedId);
+                                    } else {
+                                        logger.error("[" + datafeedId + "] failed to remove task to stop relocated datafeed", e);
+                                    }
+                                })
+                            );
+                        }
                         if (startedDatafeedsJobs.isEmpty()) {
                             listener.onResponse(finished);
                             return;
@@ -232,15 +259,40 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
                             ));
                     },
                     listener::onFailure
-                )),
+                ), movedDatafeeds),
                 e -> {
-                    if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) {
+                    Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
+                    if (unwrapped instanceof FailedNodeException) {
                         // A node has dropped out of the cluster since we started executing the requests.
                         // Since stopping an already stopped datafeed is not an error we can try again.
                         // The datafeeds that were running on the node that dropped out of the cluster
                         // will just have their persistent tasks cancelled.  Datafeeds that were stopped
                         // by the previous attempt will be noops in the subsequent attempt.
-                        doExecute(task, request, listener);
+                        if (attempt <= MAX_ATTEMPTS) {
+                            logger.warn("Node [{}] failed while processing stop datafeed request - retrying",
+                                ((FailedNodeException) unwrapped).nodeId());
+                            doExecute(task, request, listener, attempt + 1);
+                        } else {
+                            listener.onFailure(e);
+                        }
+                    } else if (unwrapped instanceof RetryStopDatafeedException) {
+                        // This is for the case where a local task wasn't yet running at the moment a
+                        // request to stop it arrived at its node.  This can happen when the cluster
+                        // state says a persistent task should be running on a particular node but that
+                        // node hasn't yet had time to start the corresponding local task.
+                        if (attempt <= MAX_ATTEMPTS) {
+                            logger.info("Insufficient responses while processing stop datafeed request [{}] - retrying",
+                                unwrapped.getMessage());
+                            // Unlike the failed node case above, in this case we should wait a little
+                            // before retrying because we need to allow time for the local task to
+                            // start on the node it's supposed to be running on.
+                            threadPool.schedule(() -> doExecute(task, request, listener, attempt + 1),
+                                TimeValue.timeValueMillis(100L * attempt), ThreadPool.Names.SAME);
+                        } else {
+                            listener.onFailure(ExceptionsHelper.serverError("Failed to stop datafeed [" + request.getDatafeedId()
+                                + "] after " + MAX_ATTEMPTS
+                                + " attempts due to inconsistencies between local and persistent tasks within the cluster"));
+                        }
                     } else {
                         listener.onFailure(e);
                     }
@@ -363,16 +415,34 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
         listener.onFailure(e);
     }
 
-    // Wait for datafeed to be marked as stopped in cluster state, which means the datafeed persistent task has been removed
-    // This api returns when task has been cancelled, but that doesn't mean the persistent task has been removed from cluster state,
-    // so wait for that to happen here.
-    void waitForDatafeedStopped(List<String> datafeedPersistentTaskIds, StopDatafeedAction.Request request,
+    /**
+     * Wait for datafeed to be marked as stopped in cluster state, which means the datafeed persistent task has been removed.
+     * This api returns when task has been cancelled, but that doesn't mean the persistent task has been removed from cluster state,
+     * so wait for that to happen here.
+     *
+     * Since the stop datafeed action consists of a chain of async callbacks, it's possible that datafeeds have moved nodes since we
+     * decided what to do with them at the beginning of the chain.  We cannot simply wait for these, as the request to stop them will
+     * have been sent to the wrong node and ignored there, meaning the wait would spin until the timeout expired.
+     */
+    void waitForDatafeedStopped(List<PersistentTasksCustomMetadata.PersistentTask<?>> datafeedPersistentTasks,
+                                StopDatafeedAction.Request request,
                                 StopDatafeedAction.Response response,
-                                ActionListener<StopDatafeedAction.Response> listener) {
+                                ActionListener<StopDatafeedAction.Response> listener,
+                                Set<String> movedDatafeeds) {
         persistentTasksService.waitForPersistentTasksCondition(persistentTasksCustomMetadata -> {
-            for (String persistentTaskId: datafeedPersistentTaskIds) {
-                if (persistentTasksCustomMetadata.getTask(persistentTaskId) != null) {
-                    return false;
+            for (PersistentTasksCustomMetadata.PersistentTask<?> originalPersistentTask : datafeedPersistentTasks) {
+                String originalPersistentTaskId = originalPersistentTask.getId();
+                PersistentTasksCustomMetadata.PersistentTask<?> currentPersistentTask =
+                    persistentTasksCustomMetadata.getTask(originalPersistentTaskId);
+                if (currentPersistentTask != null) {
+                    if (Objects.equals(originalPersistentTask.getExecutorNode(), currentPersistentTask.getExecutorNode())
+                        && originalPersistentTask.getAllocationId() == currentPersistentTask.getAllocationId()) {
+                        return false;
+                    }
+                    StartDatafeedAction.DatafeedParams params = (StartDatafeedAction.DatafeedParams) originalPersistentTask.getParams();
+                    if (movedDatafeeds.add(params.getDatafeedId())) {
+                        logger.info("Datafeed [{}] changed assignment while waiting for it to be stopped", params.getDatafeedId());
+                    }
                 }
             }
             return true;
@@ -383,7 +453,7 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
     protected StopDatafeedAction.Response newResponse(StopDatafeedAction.Request request, List<StopDatafeedAction.Response> tasks,
                                                       List<TaskOperationFailure> taskOperationFailures,
                                                       List<FailedNodeException> failedNodeExceptions) {
-        // number of resolved data feeds should be equal to the number of
+        // number of resolved (i.e. running on a node) started data feeds should be equal to the number of
         // tasks, otherwise something went wrong
         if (request.getResolvedStartedDatafeedIds().length != tasks.size()) {
             if (taskOperationFailures.isEmpty() == false) {
@@ -393,13 +463,30 @@ public class TransportStopDatafeedAction extends TransportTasksAction<TransportS
                 throw org.elasticsearch.ExceptionsHelper
                         .convertToElastic(failedNodeExceptions.get(0));
             } else {
-                // This can happen when the actual task in the node no longer exists,
-                // which means the datafeed(s) have already been stopped.
-                return new StopDatafeedAction.Response(true);
+                // This can happen when the local task in the node no longer exists,
+                // which means the datafeed(s) have already been stopped.  It can
+                // also happen if the local task hadn't yet been created when the
+                // stop request hit the executor node.  In this second case we need
+                // to retry, otherwise the wait for completion will wait until it
+                // times out.  We cannot tell which case it is, but it doesn't hurt
+                // to retry in both cases since stopping a stopped datafeed is a
+                // no-op.
+                throw new RetryStopDatafeedException(request.getResolvedStartedDatafeedIds().length, tasks.size());
             }
         }
 
         return new StopDatafeedAction.Response(tasks.stream().allMatch(StopDatafeedAction.Response::isStopped));
     }
 
+    /**
+     * A special exception to indicate that we should retry stopping the datafeeds.
+     * This exception is not transportable, so should only be thrown in situations
+     * where it will be caught on the same node.
+     */
+    static class RetryStopDatafeedException extends RuntimeException {
+
+        RetryStopDatafeedException(int numResponsesExpected, int numResponsesReceived) {
+            super("expected " + numResponsesExpected + " responses, got " + numResponsesReceived);
+        }
+    }
 }

+ 20 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java

@@ -13,6 +13,7 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.common.Strings;
@@ -25,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
+import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
 
@@ -37,6 +39,8 @@ public class DatafeedNodeSelector {
 
     public static final PersistentTasksCustomMetadata.Assignment AWAITING_JOB_ASSIGNMENT =
         new PersistentTasksCustomMetadata.Assignment(null, "datafeed awaiting job assignment.");
+    public static final PersistentTasksCustomMetadata.Assignment AWAITING_JOB_RELOCATION =
+        new PersistentTasksCustomMetadata.Assignment(null, "datafeed awaiting job relocation.");
 
     private final String datafeedId;
     private final String jobId;
@@ -75,7 +79,15 @@ public class DatafeedNodeSelector {
         }
     }
 
-    public PersistentTasksCustomMetadata.Assignment selectNode() {
+    /**
+     * Select which node to run the datafeed on.  The logic is to always choose the same node that the job
+     * is already running on <em>unless</em> this node is not permitted for some reason or there is some
+     * problem in the cluster that would stop the datafeed working.
+     * @param candidateNodes Only nodes in this collection may be chosen as the executor node.
+     * @return The assignment for the datafeed, containing either an executor node or a reason why an
+     *         executor node was not returned.
+     */
+    public PersistentTasksCustomMetadata.Assignment selectNode(Collection<DiscoveryNode> candidateNodes) {
         if (MlMetadata.getMlMetadata(clusterState).isUpgradeMode()) {
             return AWAITING_UPGRADE;
         }
@@ -89,6 +101,13 @@ public class DatafeedNodeSelector {
             if (jobNode == null) {
                 return AWAITING_JOB_ASSIGNMENT;
             }
+            // During node shutdown the datafeed will have been unassigned but the job will still be gracefully persisting state.
+            // During this time the datafeed will be trying to select the job's node, but we must disallow this.  Instead the
+            // datafeed must remain in limbo until the job has finished persisting state and can move to a different node.
+            // Nodes that are shutting down will have been excluded from the candidate nodes.
+            if (candidateNodes.stream().anyMatch(candidateNode -> candidateNode.getId().equals(jobNode)) == false) {
+                return AWAITING_JOB_RELOCATION;
+            }
             return new PersistentTasksCustomMetadata.Assignment(jobNode, "");
         }
         LOGGER.debug(assignmentFailure.reason);

+ 67 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedRunner.java

@@ -34,12 +34,14 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction;
+import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction.DatafeedTask.StoppedOrIsolatedBeforeRunning;
 import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
 import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
 
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Locale;
 import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -85,10 +87,10 @@ public class DatafeedRunner {
 
     public void run(TransportStartDatafeedAction.DatafeedTask task, Consumer<Exception> finishHandler) {
         ActionListener<DatafeedJob> datafeedJobHandler = ActionListener.wrap(
-                datafeedJob -> {
-                    String jobId = datafeedJob.getJobId();
-                    Holder holder = new Holder(task, task.getDatafeedId(), datafeedJob,
-                            new ProblemTracker(auditor, jobId), finishHandler);
+            datafeedJob -> {
+                String jobId = datafeedJob.getJobId();
+                Holder holder = new Holder(task, task.getDatafeedId(), datafeedJob, new ProblemTracker(auditor, jobId), finishHandler);
+                if (task.getStoppedOrIsolatedBeforeRunning() == StoppedOrIsolatedBeforeRunning.NEITHER) {
                     runningDatafeedsOnThisNode.put(task.getAllocationId(), holder);
                     task.updatePersistentTaskState(DatafeedState.STARTED, new ActionListener<PersistentTask<?>>() {
                         @Override
@@ -101,16 +103,31 @@ public class DatafeedRunner {
                             if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                                 // The task was stopped in the meantime, no need to do anything
                                 logger.info("[{}] Aborting as datafeed has been stopped", task.getDatafeedId());
+                                runningDatafeedsOnThisNode.remove(task.getAllocationId());
+                                finishHandler.accept(null);
                             } else {
                                 finishHandler.accept(e);
                             }
                         }
                     });
-                }, finishHandler
+                } else {
+                    logger.info("[{}] Datafeed has been {} before running", task.getDatafeedId(),
+                        task.getStoppedOrIsolatedBeforeRunning().toString().toLowerCase(Locale.ROOT));
+                    finishHandler.accept(null);
+                }
+            }, finishHandler
         );
 
         ActionListener<DatafeedContext> datafeedContextListener = ActionListener.wrap(
-            datafeedContext -> datafeedJobBuilder.build(task, datafeedContext, datafeedJobHandler),
+            datafeedContext -> {
+                if (task.getStoppedOrIsolatedBeforeRunning() == StoppedOrIsolatedBeforeRunning.NEITHER) {
+                    datafeedJobBuilder.build(task, datafeedContext, datafeedJobHandler);
+                } else {
+                    logger.info("[{}] Datafeed has been {} while building context", task.getDatafeedId(),
+                        task.getStoppedOrIsolatedBeforeRunning().toString().toLowerCase(Locale.ROOT));
+                    finishHandler.accept(null);
+                }
+            },
             finishHandler
         );
 
@@ -140,16 +157,16 @@ public class DatafeedRunner {
     }
 
     /**
-     * This is used before the JVM is killed.  It differs from stopAllDatafeedsOnThisNode in that it leaves
-     * the datafeed tasks in the "started" state, so that they get restarted on a different node.
+     * This is used before the JVM is killed.  It differs from {@link #stopAllDatafeedsOnThisNode} in that it
+     * leaves the datafeed tasks in the "started" state, so that they get restarted on a different node.  It
+     * differs from {@link #vacateAllDatafeedsOnThisNode} in that it does not proactively relocate the persistent
+     * tasks.  With this method the assumption is that the JVM is going to be killed almost immediately, whereas
+     * {@link #vacateAllDatafeedsOnThisNode} is used with more graceful shutdowns.
      */
-    public void isolateAllDatafeedsOnThisNodeBeforeShutdown() {
+    public void prepareForImmediateShutdown() {
         Iterator<Holder> iter = runningDatafeedsOnThisNode.values().iterator();
         while (iter.hasNext()) {
-            Holder next = iter.next();
-            next.isolateDatafeed();
-            // TODO: it's not ideal that this "isolate" method does something a bit different to the one below
-            next.setNodeIsShuttingDown();
+            iter.next().setNodeIsShuttingDown();
             iter.remove();
         }
     }
@@ -163,6 +180,25 @@ public class DatafeedRunner {
         }
     }
 
+    /**
+     * Like {@link #prepareForImmediateShutdown} this is used when the node is
+     * going to shut down.  However, the difference is that in this case it's going to be a
+     * graceful shutdown, which could take a lot longer than the second or two expected in the
+     * case where {@link #prepareForImmediateShutdown} is called.  Therefore,
+     * in this case we actively ask for the datafeed persistent tasks to be unassigned, so that
+     * they can restart on a different node as soon as <em>their</em> corresponding job has
+     * persisted its state.  This means the small jobs can potentially restart sooner than if
+     * nothing relocated until <em>all</em> graceful shutdown activities on the node were
+     * complete.
+     */
+    public void vacateAllDatafeedsOnThisNode(String reason) {
+        for (Holder holder : runningDatafeedsOnThisNode.values()) {
+            if (holder.isIsolated() == false) {
+                holder.vacateNode(reason);
+            }
+        }
+    }
+
     public boolean finishedLookBack(TransportStartDatafeedAction.DatafeedTask task) {
         Holder holder = runningDatafeedsOnThisNode.get(task.getAllocationId());
         return holder != null && holder.isLookbackFinished();
@@ -421,10 +457,28 @@ public class DatafeedRunner {
             datafeedJob.isolate();
         }
 
+        /**
+         * This method tells the datafeed to do as little work as possible from now on, but does not
+         * do anything to clean up local or persistent tasks, or other data structures.  The assumption
+         * is that cleanup will be achieved when the JVM stops running, and that is going to happen
+         * very soon.
+         */
         public void setNodeIsShuttingDown() {
+            isolateDatafeed();
             isNodeShuttingDown = true;
         }
 
+        /**
+         * Tell the datafeed to do as little work as possible, and tell the master node to move its
+         * persistent task to a different node in the cluster.  This method should be called when it
+         * is known the node will shut down relatively soon, but all tasks are being gracefully
+         * migrated away first.
+         */
+        public void vacateNode(String reason) {
+            isolateDatafeed();
+            task.markAsLocallyAborted(reason);
+        }
+
         public boolean isLookbackFinished() {
             return lookbackFinished;
         }

+ 7 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java

@@ -56,7 +56,7 @@ import java.util.function.BiConsumer;
 
 public class AutodetectCommunicator implements Closeable {
 
-    private static final Logger LOGGER = LogManager.getLogger(AutodetectCommunicator.class);
+    private static final Logger logger = LogManager.getLogger(AutodetectCommunicator.class);
     private static final Duration FLUSH_PROCESS_CHECK_FREQUENCY = Duration.ofSeconds(1);
 
     private final Job job;
@@ -160,7 +160,7 @@ public class AutodetectCommunicator implements Closeable {
             } finally {
                 onFinishHandler.accept(null, true);
             }
-            LOGGER.info("[{}] job closed", job.getId());
+            logger.info("[{}] autodetect connection for job closed", job.getId());
             return null;
         });
         try {
@@ -195,7 +195,7 @@ public class AutodetectCommunicator implements Closeable {
                 try {
                     autodetectResultProcessor.awaitCompletion();
                 } catch (TimeoutException e) {
-                    LOGGER.warn(new ParameterizedMessage("[{}] Timed out waiting for killed job", job.getId()), e);
+                    logger.warn(new ParameterizedMessage("[{}] Timed out waiting for killed job", job.getId()), e);
                 }
             }
         } finally {
@@ -278,7 +278,7 @@ public class AutodetectCommunicator implements Closeable {
 
     @Nullable
     FlushAcknowledgement waitFlushToCompletion(String flushId, boolean waitForNormalization) throws Exception {
-        LOGGER.debug("[{}] waiting for flush", job.getId());
+        logger.debug("[{}] waiting for flush", job.getId());
 
         FlushAcknowledgement flushAcknowledgement;
         try {
@@ -296,11 +296,11 @@ public class AutodetectCommunicator implements Closeable {
             // We also have to wait for the normalizer to become idle so that we block
             // clients from querying results in the middle of normalization.
             if (waitForNormalization) {
-                LOGGER.debug("[{}] Initial flush completed, waiting until renormalizer is idle.", job.getId());
+                logger.debug("[{}] Initial flush completed, waiting until renormalizer is idle.", job.getId());
                 autodetectResultProcessor.waitUntilRenormalizerIsIdle();
             }
 
-            LOGGER.debug("[{}] Flush completed", job.getId());
+            logger.debug("[{}] Flush completed", job.getId());
         }
 
         return flushAcknowledgement;
@@ -360,7 +360,7 @@ public class AutodetectCommunicator implements Closeable {
                     handler.accept(null, ExceptionsHelper.conflictStatusException(
                             "[{}] Could not submit operation to process as it has been killed", job.getId()));
                 } else {
-                    LOGGER.error(new ParameterizedMessage("[{}] Unexpected exception writing to process", job.getId()), e);
+                    logger.error(new ParameterizedMessage("[{}] Unexpected exception writing to process", job.getId()), e);
                     handler.accept(null, e);
                 }
             }

+ 68 - 16
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterChangedEvent;
@@ -157,6 +158,7 @@ public class AutodetectProcessManager implements ClusterStateListener {
         this.maxAllowedRunningJobs = maxAllowedRunningJobs;
     }
 
+    // The primary use of this is for license expiry
     public synchronized void closeAllJobsOnThisNode(String reason) {
         // Note, snapshot upgrader processes could still be running, but those are short lived
         // Leaving them running is OK.
@@ -164,8 +166,10 @@ public class AutodetectProcessManager implements ClusterStateListener {
         if (numJobs != 0) {
             logger.info("Closing [{}] jobs, because [{}]", numJobs, reason);
 
-            for (ProcessContext process : processByAllocation.values()) {
-                closeJob(process.getJobTask(), reason);
+            for (ProcessContext processContext : processByAllocation.values()) {
+                JobTask jobTask = processContext.getJobTask();
+                setJobState(jobTask, JobState.CLOSING, reason);
+                jobTask.closeJob(reason);
             }
         }
     }
@@ -218,6 +222,32 @@ public class AutodetectProcessManager implements ClusterStateListener {
         }
     }
 
+    /**
+     * Makes open jobs on this node go through the motions of closing but
+     * without completing the persistent task and instead telling the
+     * master node to assign the persistent task to a different node.
+     * The intended user of this functionality is the node shutdown API.
+     * Jobs that are already closing continue to close.
+     */
+    public synchronized void vacateOpenJobsOnThisNode() {
+
+        for (ProcessContext processContext : processByAllocation.values()) {
+
+            // We ignore jobs that either don't have a running process yet or are already closing.
+            // - The ones that don't yet have a running process will get picked up on a subsequent call to this
+            //   method.  This is simpler than trying to interact with a job before its process is started,
+            //   and importantly, when it eventually does get picked up it will be fast to shut down again
+            //   since it will only just have been started.
+            // - For jobs that are already closing we might as well let them close on the current node
+            //   rather than trying to vacate them to a different node first.
+            if (processContext.getState() == ProcessContext.ProcessStateName.RUNNING && processContext.getJobTask().triggerVacate()) {
+                // We need to fork here, as persisting state is a potentially long-running operation
+                threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(
+                    () -> closeProcessAndTask(processContext, processContext.getJobTask(), "node is shutting down"));
+            }
+        }
+    }
+
     /**
      * Initiate background persistence of the job
      * @param jobTask The job task
@@ -460,6 +490,11 @@ public class AutodetectProcessManager implements ClusterStateListener {
     public void openJob(JobTask jobTask, ClusterState clusterState, TimeValue masterNodeTimeout,
                         BiConsumer<Exception, Boolean> closeHandler) {
         String jobId = jobTask.getJobId();
+        if (jobTask.isClosing()) {
+            logger.info("Aborting opening of job [{}] as it is being closed", jobId);
+            jobTask.markAsCompleted();
+            return;
+        }
         logger.info("Opening job [{}]", jobId);
 
         // Start the process
@@ -553,7 +588,19 @@ public class AutodetectProcessManager implements ClusterStateListener {
                                 return;
                             }
                             processContext.getAutodetectCommunicator().restoreState(params.modelSnapshot());
-                            setJobState(jobTask, JobState.OPENED);
+                            setJobState(jobTask, JobState.OPENED, null, e -> {
+                                if (e != null) {
+                                    logSetJobStateFailure(JobState.OPENED, job.getId(), e);
+                                    if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                                        // Don't leave a process with no persistent task hanging around
+                                        processContext.newKillBuilder()
+                                            .setAwaitCompletion(false)
+                                            .setFinish(false)
+                                            .kill();
+                                        processByAllocation.remove(jobTask.getAllocationId());
+                                    }
+                                }
+                            });
                         }
                     } catch (Exception e1) {
                         // No need to log here as the persistent task framework will log it
@@ -743,7 +790,7 @@ public class AutodetectProcessManager implements ClusterStateListener {
         processContext.tryLock();
         try {
             if (processContext.setDying() == false) {
-                logger.debug("Cannot close job [{}] as it has been marked as dying", jobId);
+                logger.debug("Cannot {} job [{}] as it has been marked as dying", jobTask.isVacating() ? "vacate" : "close", jobId);
                 // The only way we can get here is if 2 close requests are made very close together.
                 // The other close has done the work so it's safe to return here without doing anything.
                 return;
@@ -755,15 +802,17 @@ public class AutodetectProcessManager implements ClusterStateListener {
             if (jobKilled) {
                 logger.debug("[{}] Cleaning up job opened after kill", jobId);
             } else if (reason == null) {
-                logger.info("Closing job [{}]", jobId);
+                logger.info("{} job [{}]", jobTask.isVacating() ? "Vacating" : "Closing", jobId);
             } else {
-                logger.info("Closing job [{}], because [{}]", jobId, reason);
+                logger.info("{} job [{}], because [{}]", jobTask.isVacating() ? "Vacating" : "Closing", jobId, reason);
             }
 
             AutodetectCommunicator communicator = processContext.getAutodetectCommunicator();
             if (communicator == null) {
                 assert jobKilled == false
                     : "Job " + jobId + " killed before process started yet still had no communicator during cleanup after process started";
+                assert jobTask.isVacating() == false
+                    : "Job " + jobId + " was vacated before it had a communicator - should not be possible";
                 logger.debug("Job [{}] is being closed before its process is started", jobId);
                 jobTask.markAsCompleted();
                 processByAllocation.remove(allocationId);
@@ -782,11 +831,12 @@ public class AutodetectProcessManager implements ClusterStateListener {
             // If the close failed because the process has explicitly been killed by us then just pass on that exception.
             // (Note that jobKilled may be false in this case, if the kill is executed while communicator.close() is running.)
             if (e instanceof ElasticsearchStatusException && ((ElasticsearchStatusException) e).status() == RestStatus.CONFLICT) {
-                logger.trace("[{}] Conflict between kill and close during autodetect process cleanup - job {} before cleanup started",
-                    jobId, jobKilled ? "killed" : "not killed");
+                logger.trace("[{}] Conflict between kill and {} during autodetect process cleanup - job {} before cleanup started",
+                    jobId, jobTask.isVacating() ? "vacate" : "close", jobKilled ? "killed" : "not killed");
                 throw (ElasticsearchStatusException) e;
             }
-            String msg = jobKilled ? "Exception cleaning up autodetect process started after kill" : "Exception closing autodetect process";
+            String msg = jobKilled ? "Exception cleaning up autodetect process started after kill"
+                : "Exception " + (jobTask.isVacating() ? "vacating" : "closing") + " autodetect process";
             logger.warn("[" + jobId + "] " + msg, e);
             setJobState(jobTask, JobState.FAILED, e.getMessage());
             throw ExceptionsHelper.serverError(msg, e);
@@ -805,8 +855,8 @@ public class AutodetectProcessManager implements ClusterStateListener {
     }
 
     /**
-     * Stop the running job and mark it as finished.
-     *
+     * Stop the running job and mark it as finished.  For consistency with the job task,
+     * other than for testing this method should only be called via {@link JobTask#closeJob}.
      * @param jobTask The job to stop
      * @param reason  The reason for closing the job
      */
@@ -866,14 +916,16 @@ public class AutodetectProcessManager implements ClusterStateListener {
         JobTaskState jobTaskState = new JobTaskState(state, jobTask.getAllocationId(), reason);
         jobTask.updatePersistentTaskState(jobTaskState, ActionListener.wrap(
             persistentTask -> logger.info("Successfully set job state to [{}] for job [{}]", state, jobTask.getJobId()),
-            e -> logger.error(
-                () -> new ParameterizedMessage("Could not set job state to [{}] for job [{}]", state, jobTask.getJobId()),
-                e)
+            e -> logSetJobStateFailure(state, jobTask.getJobId(), e)
         ));
     }
 
-    void setJobState(JobTask jobTask, JobState state) {
-        setJobState(jobTask, state, null);
+    private void logSetJobStateFailure(JobState state, String jobId, Exception e) {
+        if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+            logger.debug("Could not set job state to [{}] for job [{}] as it has been closed", state, jobId);
+        } else {
+            logger.error(() -> new ParameterizedMessage("Could not set job state to [{}] for job [{}]", state, jobId), e);
+        }
     }
 
     void setJobState(JobTask jobTask, JobState state, String reason, CheckedConsumer<Exception, IOException> handler) {

+ 29 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/JobTask.java

@@ -16,14 +16,22 @@ import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
 
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
 
 public class JobTask extends AllocatedPersistentTask implements OpenJobAction.JobTaskMatcher {
 
-    private static final Logger LOGGER = LogManager.getLogger(JobTask.class);
+    /**
+     * We should only progress forwards through these states: close takes precedence over vacate
+     */
+    enum ClosingOrVacating {
+        NEITHER, VACATING, CLOSING
+    }
+
+    private static final Logger logger = LogManager.getLogger(JobTask.class);
 
     private final String jobId;
+    private final AtomicReference<ClosingOrVacating> closingOrVacating = new AtomicReference<>(ClosingOrVacating.NEITHER);
     private volatile AutodetectProcessManager autodetectProcessManager;
-    private volatile boolean isClosing = false;
 
     JobTask(String jobId, long id, String type, String action, TaskId parentTask, Map<String, String> headers) {
         super(id, type, action, "job-" + jobId, parentTask, headers);
@@ -37,27 +45,40 @@ public class JobTask extends AllocatedPersistentTask implements OpenJobAction.Jo
     @Override
     protected void onCancelled() {
         String reason = getReasonCancelled();
-        LOGGER.trace(() -> new ParameterizedMessage("[{}] Cancelling job task because: {}", jobId, reason));
-        isClosing = true;
+        logger.trace(() -> new ParameterizedMessage("[{}] Cancelling job task because: {}", jobId, reason));
+        closingOrVacating.set(ClosingOrVacating.CLOSING);
         autodetectProcessManager.killProcess(this, false, reason);
     }
 
     public boolean isClosing() {
-        return isClosing;
+        return closingOrVacating.get() == ClosingOrVacating.CLOSING;
+    }
+
+    public boolean triggerVacate() {
+        return closingOrVacating.compareAndSet(ClosingOrVacating.NEITHER, ClosingOrVacating.VACATING);
+    }
+
+    public boolean isVacating() {
+        return closingOrVacating.get() == ClosingOrVacating.VACATING;
     }
 
     public void closeJob(String reason) {
-        isClosing = true;
+        // If a job is vacating the node when a close request arrives, convert that vacate to a close.
+        // This may be too late, if the vacate operation has already gone past the point of unassigning
+        // the persistent task instead of completing it.  But in general a close should take precedence
+        // over a vacate.
+        if (closingOrVacating.getAndSet(ClosingOrVacating.CLOSING) == ClosingOrVacating.VACATING) {
+            logger.info("[{}] Close request for job while it was vacating the node", jobId);
+        }
         autodetectProcessManager.closeJob(this, reason);
     }
 
     public void killJob(String reason) {
-        isClosing = true;
+        closingOrVacating.set(ClosingOrVacating.CLOSING);
         autodetectProcessManager.killProcess(this, true, reason);
     }
 
     void setAutodetectProcessManager(AutodetectProcessManager autodetectProcessManager) {
         this.autodetectProcessManager = autodetectProcessManager;
     }
-
 }

+ 5 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutor.java

@@ -326,7 +326,11 @@ public class OpenJobPersistentTasksExecutor extends AbstractJobPersistentTasksEx
         String jobId = jobTask.getJobId();
         autodetectProcessManager.openJob(jobTask, clusterState, PERSISTENT_TASK_MASTER_NODE_TIMEOUT, (e2, shouldFinalizeJob) -> {
             if (e2 == null) {
-                if (shouldFinalizeJob) {
+                // Beyond this point it's too late to change our minds about whether we're closing or vacating
+                if (jobTask.isVacating()) {
+                    jobTask.markAsLocallyAborted(
+                        "previously assigned node [" + clusterState.nodes().getLocalNode().getName() + "] is shutting down");
+                } else if (shouldFinalizeJob) {
                     FinalizeJobExecutionAction.Request finalizeRequest = new FinalizeJobExecutionAction.Request(new String[]{jobId});
                     finalizeRequest.masterNodeTimeout(PERSISTENT_TASK_MASTER_NODE_TIMEOUT);
                     executeAsyncWithOrigin(client, ML_ORIGIN, FinalizeJobExecutionAction.INSTANCE, finalizeRequest,

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlConfigMigratorTests.java

@@ -140,7 +140,7 @@ public class MlConfigMigratorTests extends ESTestCase {
                 .nodes(nodes)
                 .build();
 
-        assertThat(MlConfigMigrator.stopppedOrUnallocatedDatafeeds(clusterState),
+        assertThat(MlConfigMigrator.stoppedOrUnallocatedDatafeeds(clusterState),
                 containsInAnyOrder(stopppedDatafeed, datafeedWithoutAllocation));
     }
 

+ 165 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java

@@ -0,0 +1,165 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.MlTasks;
+import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
+import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
+import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
+import org.elasticsearch.xpack.core.ml.job.config.JobState;
+import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
+import org.elasticsearch.xpack.ml.datafeed.DatafeedRunner;
+import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
+import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
+import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskParams;
+import org.elasticsearch.xpack.ml.process.MlController;
+import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
+import org.junit.Before;
+
+import java.net.InetAddress;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+
+public class MlLifeCycleServiceTests extends ESTestCase {
+
+    private ClusterService clusterService;
+    private DatafeedRunner datafeedRunner;
+    private MlController mlController;
+    private AutodetectProcessManager autodetectProcessManager;
+    private DataFrameAnalyticsManager analyticsManager;
+    private MlMemoryTracker memoryTracker;
+
+    @Before
+    public void setupMocks() {
+        clusterService = mock(ClusterService.class);
+        datafeedRunner = mock(DatafeedRunner.class);
+        mlController = mock(MlController.class);
+        autodetectProcessManager = mock(AutodetectProcessManager.class);
+        analyticsManager = mock(DataFrameAnalyticsManager.class);
+        memoryTracker = mock(MlMemoryTracker.class);
+    }
+
+    public void testIsNodeSafeToShutdown() {
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-1"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("job-1"),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        tasksBuilder.addTask(MlTasks.datafeedTaskId("df1"), MlTasks.DATAFEED_TASK_NAME,
+            new StartDatafeedAction.DatafeedParams("df1", 0L),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        tasksBuilder.addTask(MlTasks.dataFrameAnalyticsTaskId("job-2"), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
+            new StartDataFrameAnalyticsAction.TaskParams("foo-2", Version.CURRENT, true),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+        tasksBuilder.addTask(MlTasks.snapshotUpgradeTaskId("job-3", "snapshot-3"), MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME,
+            new SnapshotUpgradeTaskParams("job-3", "snapshot-3"),
+            new PersistentTasksCustomMetadata.Assignment("node-3", "test assignment"));
+
+        Metadata metadata = Metadata.builder().putCustom(PersistentTasksCustomMetadata.TYPE, tasksBuilder.build()).build();
+        ClusterState clusterState = ClusterState.builder(ClusterState.EMPTY_STATE).metadata(metadata).build();
+
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-1", clusterState), is(false)); // has AD job
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-2", clusterState), is(true)); // has DFA job
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-3", clusterState), is(false)); // has snapshot upgrade
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-4", clusterState), is(true)); // has no ML tasks
+    }
+
+    public void testIsNodeSafeToShutdownGivenFailedTasks() {
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+
+        tasksBuilder.addTask(MlTasks.jobTaskId("job-1"), MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams("job-1"),
+            new PersistentTasksCustomMetadata.Assignment("node-1", "test assignment"));
+        tasksBuilder.updateTaskState(MlTasks.jobTaskId("job-1"), new JobTaskState(JobState.FAILED, 1, "testing"));
+        tasksBuilder.addTask(MlTasks.dataFrameAnalyticsTaskId("job-2"), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
+            new StartDataFrameAnalyticsAction.TaskParams("foo-2", Version.CURRENT, true),
+            new PersistentTasksCustomMetadata.Assignment("node-2", "test assignment"));
+        tasksBuilder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId("job-2"),
+            new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, 2, "testing"));
+        tasksBuilder.addTask(MlTasks.snapshotUpgradeTaskId("job-3", "snapshot-3"), MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME,
+            new SnapshotUpgradeTaskParams("job-3", "snapshot-3"),
+            new PersistentTasksCustomMetadata.Assignment("node-3", "test assignment"));
+        tasksBuilder.updateTaskState(MlTasks.snapshotUpgradeTaskId("job-3", "snapshot-3"),
+            new SnapshotUpgradeTaskState(SnapshotUpgradeState.FAILED, 3, "testing"));
+
+        Metadata metadata = Metadata.builder().putCustom(PersistentTasksCustomMetadata.TYPE, tasksBuilder.build()).build();
+        ClusterState clusterState = ClusterState.builder(ClusterState.EMPTY_STATE).metadata(metadata).build();
+
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-1", clusterState), is(true)); // has failed AD job
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-2", clusterState), is(true)); // has failed DFA job
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-3", clusterState), is(true)); // has failed snapshot upgrade
+        assertThat(MlLifeCycleService.isNodeSafeToShutdown("node-4", clusterState), is(true)); // has no ML tasks
+    }
+
+    public void testSignalGracefulShutdownIncludingLocalNode() {
+
+        MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(clusterService, datafeedRunner, mlController,
+            autodetectProcessManager, analyticsManager, memoryTracker);
+
+        DiscoveryNodes.Builder nodesBuilder = DiscoveryNodes.builder()
+            .add(new DiscoveryNode("node-1-name", "node-1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
+                Collections.emptyMap(), DiscoveryNodeRole.roles(), Version.CURRENT))
+            .add(new DiscoveryNode("node-2-name", "node-2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301),
+                Collections.emptyMap(), DiscoveryNodeRole.roles(), Version.CURRENT))
+            .add(new DiscoveryNode("node-3-name", "node-3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302),
+                Collections.emptyMap(), DiscoveryNodeRole.roles(), Version.CURRENT))
+            .masterNodeId("node-1")
+            .localNodeId("node-2");
+        ClusterState clusterState = ClusterState.builder(ClusterState.EMPTY_STATE).nodes(nodesBuilder).build();
+
+        Collection<String> shutdownNodeIds =
+            randomBoolean() ? Collections.singleton("node-2") : Arrays.asList("node-1", "node-2", "node-3");
+
+        mlLifeCycleService.signalGracefulShutdown(clusterState, shutdownNodeIds);
+
+        verify(datafeedRunner).vacateAllDatafeedsOnThisNode("previously assigned node [node-2-name] is shutting down");
+        verify(autodetectProcessManager).vacateOpenJobsOnThisNode();
+        verifyNoMoreInteractions(datafeedRunner, mlController, autodetectProcessManager, analyticsManager, memoryTracker);
+    }
+
+    public void testSignalGracefulShutdownExcludingLocalNode() {
+
+        MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(clusterService, datafeedRunner, mlController,
+            autodetectProcessManager, analyticsManager, memoryTracker);
+        DiscoveryNodes.Builder nodesBuilder = DiscoveryNodes.builder()
+            .add(new DiscoveryNode("node-1-name", "node-1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
+                Collections.emptyMap(), DiscoveryNodeRole.roles(), Version.CURRENT))
+            .add(new DiscoveryNode("node-2-name", "node-2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301),
+                Collections.emptyMap(), DiscoveryNodeRole.roles(), Version.CURRENT))
+            .add(new DiscoveryNode("node-3-name", "node-3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302),
+                Collections.emptyMap(), DiscoveryNodeRole.roles(), Version.CURRENT))
+            .masterNodeId("node-1")
+            .localNodeId("node-2");
+        ClusterState clusterState = ClusterState.builder(ClusterState.EMPTY_STATE).nodes(nodesBuilder).build();
+
+        Collection<String> shutdownNodeIds =
+            randomBoolean() ? Collections.singleton("node-1") : Arrays.asList("node-1", "node-3");
+
+        mlLifeCycleService.signalGracefulShutdown(clusterState, shutdownNodeIds);
+
+        verifyNoMoreInteractions(datafeedRunner, mlController, autodetectProcessManager, analyticsManager, memoryTracker);
+    }
+}

+ 6 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportCloseJobActionTests.java

@@ -55,6 +55,8 @@ import static org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasProperty;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyBoolean;
@@ -307,8 +309,10 @@ public class TransportCloseJobActionTests extends ESTestCase {
             TransportCloseJobAction.buildWaitForCloseRequest(
                 openJobIds, closingJobIds, tasksBuilder.build(), mock(AnomalyDetectionAuditor.class));
         assertEquals(waitForCloseRequest.jobsToFinalize, Arrays.asList("openjob1", "openjob2"));
-        assertEquals(waitForCloseRequest.persistentTaskIds,
-                Arrays.asList("job-openjob1", "job-openjob2", "job-closingjob1"));
+        assertThat(waitForCloseRequest.persistentTasks, containsInAnyOrder(
+            hasProperty("id", equalTo("job-openjob1")),
+            hasProperty("id", equalTo("job-openjob2")),
+            hasProperty("id", equalTo("job-closingjob1"))));
         assertTrue(waitForCloseRequest.hasJobsToWaitFor());
 
         waitForCloseRequest = TransportCloseJobAction.buildWaitForCloseRequest(Collections.emptyList(), Collections.emptyList(),

+ 3 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedActionTests.java

@@ -34,6 +34,7 @@ import static org.elasticsearch.persistent.PersistentTasksCustomMetadata.INITIAL
 import static org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutorTests.addJobTask;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
@@ -163,7 +164,8 @@ public class TransportStartDatafeedActionTests extends ESTestCase {
                                                                                DatafeedRunner datafeedRunner) {
         TransportStartDatafeedAction.DatafeedTask task = new TransportStartDatafeedAction.DatafeedTask(id, type, action, parentTaskId,
                 params, Collections.emptyMap());
-        task.datafeedRunner = datafeedRunner;
+        assertThat(task.setDatafeedRunner(datafeedRunner),
+            is(TransportStartDatafeedAction.DatafeedTask.StoppedOrIsolatedBeforeRunning.NEITHER));
         return task;
     }
 }

+ 60 - 14
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelectorTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
@@ -43,6 +44,7 @@ import org.junit.Before;
 import java.net.InetAddress;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Date;
 import java.util.List;
@@ -89,7 +91,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertEquals("node_id", result.getExecutorNode());
         new DatafeedNodeSelector(clusterState,
             resolver,
@@ -117,7 +119,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertEquals("node_id", result.getExecutorNode());
         new DatafeedNodeSelector(clusterState,
             resolver,
@@ -142,7 +144,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertEquals("node_id", result.getExecutorNode());
         new DatafeedNodeSelector(clusterState,
             resolver,
@@ -167,7 +169,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertNull(result.getExecutorNode());
         assertThat(result.getExplanation(), equalTo("cannot start datafeed [datafeed_id], because the job's [job_id] state is " +
                 "[closed] while state [opened] is required"));
@@ -199,7 +201,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertNull(result.getExecutorNode());
         assertEquals("cannot start datafeed [datafeed_id], because the job's [job_id] state is [" + jobState +
                 "] while state [opened] is required", result.getExplanation());
@@ -236,7 +238,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertNull(result.getExecutorNode());
         assertThat(result.getExplanation(), equalTo("cannot start datafeed [datafeed_id] because index [foo] " +
                 "does not have all primary shards active yet."));
@@ -270,7 +272,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertNull(result.getExecutorNode());
         assertThat(result.getExplanation(), equalTo("cannot start datafeed [datafeed_id] because index [foo] " +
                 "does not have all primary shards active yet."));
@@ -298,7 +300,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertNull(result.getExecutorNode());
         assertThat(result.getExplanation(),
             equalTo("cannot start datafeed [datafeed_id] because it failed resolving indices given [not_foo] and " +
@@ -337,7 +339,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertEquals("node_id", result.getExecutorNode());
         new DatafeedNodeSelector(clusterState,
             resolver,
@@ -362,7 +364,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertNotNull(result.getExecutorNode());
     }
 
@@ -379,12 +381,14 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
 
         givenClusterState("foo", 1, 0);
 
+        Collection<DiscoveryNode> candidateNodes = makeCandidateNodes("node_id1", "node_id2", "node_id3");
+
         PersistentTasksCustomMetadata.Assignment result = new DatafeedNodeSelector(clusterState,
             resolver,
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(candidateNodes);
         assertNull(result.getExecutorNode());
         assertEquals("cannot start datafeed [datafeed_id], because the job's [job_id] state is stale",
                 result.getExplanation());
@@ -408,7 +412,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(candidateNodes);
         assertEquals("node_id1", result.getExecutorNode());
         new DatafeedNodeSelector(clusterState,
             resolver,
@@ -462,7 +466,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertThat(result, equalTo(MlTasks.AWAITING_UPGRADE));
     }
 
@@ -482,7 +486,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
             df.getId(),
             df.getJobId(),
             df.getIndices(),
-            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode();
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("node_id", "other_node_id"));
         assertThat(result, equalTo(MlTasks.RESET_IN_PROGRESS));
     }
 
@@ -507,6 +511,38 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
         assertThat(e.getMessage(), equalTo("Could not start datafeed [datafeed_id] as indices are being upgraded"));
     }
 
+    public void testSelectNode_GivenJobIsOpenedAndNodeIsShuttingDown() {
+        Job job = createScheduledJob("job_id").build(new Date());
+        DatafeedConfig df = createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo"));
+
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        addJobTask(job.getId(), "node_id", JobState.OPENED, tasksBuilder);
+        tasks = tasksBuilder.build();
+
+        givenClusterState("foo", 1, 0);
+
+        PersistentTasksCustomMetadata.Assignment result = new DatafeedNodeSelector(clusterState,
+            resolver,
+            df.getId(),
+            df.getJobId(),
+            df.getIndices(),
+            SearchRequest.DEFAULT_INDICES_OPTIONS).selectNode(makeCandidateNodes("other_node_id"));
+        assertNull(result.getExecutorNode());
+        assertEquals("datafeed awaiting job relocation.", result.getExplanation());
+
+        // This is different to the pattern of the other tests - we allow the datafeed task to be
+        // created even though it cannot be assigned.  The reason is that it would be perverse for
+        // start datafeed to throw an error just because a user was unlucky: they opened a job just
+        // before a node began shutting down, so their subsequent call to start its datafeed arrived
+        // while that node was shutting down.
+        new DatafeedNodeSelector(clusterState,
+            resolver,
+            df.getId(),
+            df.getJobId(),
+            df.getIndices(),
+            SearchRequest.DEFAULT_INDICES_OPTIONS).checkDatafeedTaskCanBeCreated();
+    }
+
     private void givenClusterState(String index, int numberOfShards, int numberOfReplicas) {
         List<Tuple<Integer, ShardRoutingState>> states = new ArrayList<>(1);
         states.add(new Tuple<>(0, ShardRoutingState.STARTED));
@@ -585,4 +621,14 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
 
         return new RoutingTable.Builder().add(rtBuilder).build();
     }
+
+    Collection<DiscoveryNode> makeCandidateNodes(String... nodeIds) {
+        List<DiscoveryNode> candidateNodes = new ArrayList<>();
+        int port = 9300;
+        for (String nodeId : nodeIds) {
+            candidateNodes.add(new DiscoveryNode(nodeId + "-name", nodeId, new TransportAddress(InetAddress.getLoopbackAddress(), port++),
+                Collections.emptyMap(), DiscoveryNodeRole.roles(), Version.CURRENT));
+        }
+        return candidateNodes;
+    }
 }

+ 21 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/DatafeedRunnerTests.java

@@ -77,6 +77,7 @@ public class DatafeedRunnerTests extends ESTestCase {
     private DatafeedJob datafeedJob;
     private DatafeedRunner datafeedRunner;
     private DatafeedContextProvider datafeedContextProvider;
+    private DatafeedJobBuilder datafeedJobBuilder;
     private long currentTime = 120000;
     private AnomalyDetectionAuditor auditor;
     private ArgumentCaptor<ClusterStateListener> capturedClusterStateListener = ArgumentCaptor.forClass(ClusterStateListener.class);
@@ -122,7 +123,7 @@ public class DatafeedRunnerTests extends ESTestCase {
         when(datafeedJob.stop()).thenReturn(true);
         when(datafeedJob.getJobId()).thenReturn(job.getId());
         when(datafeedJob.getMaxEmptySearches()).thenReturn(null);
-        DatafeedJobBuilder datafeedJobBuilder = mock(DatafeedJobBuilder.class);
+        datafeedJobBuilder = mock(DatafeedJobBuilder.class);
         doAnswer(invocationOnMock -> {
             @SuppressWarnings("rawtypes")
             ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
@@ -423,6 +424,24 @@ public class DatafeedRunnerTests extends ESTestCase {
         verify(threadPool, never()).executor(MachineLearning.DATAFEED_THREAD_POOL_NAME);
     }
 
+    public void testDatafeedGetsStoppedWhileStarting() {
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        addJobTask(JOB_ID, "node_id", JobState.OPENED, tasksBuilder);
+        ClusterState cs = ClusterState.builder(clusterService.state())
+            .metadata(new Metadata.Builder().putCustom(PersistentTasksCustomMetadata.TYPE, tasksBuilder.build())).build();
+        when(clusterService.state()).thenReturn(cs);
+
+        Consumer<Exception> handler = mockConsumer();
+        DatafeedTask task = createDatafeedTask(DATAFEED_ID, 0L, 60000L);
+        when(task.getStoppedOrIsolatedBeforeRunning()).thenReturn(DatafeedTask.StoppedOrIsolatedBeforeRunning.STOPPED);
+        datafeedRunner.run(task, handler);
+
+        // Verify datafeed aborted after creating context but before doing anything else
+        verify(datafeedContextProvider).buildDatafeedContext(eq(DATAFEED_ID), any());
+        verify(datafeedJobBuilder, never()).build(any(), any(), any());
+        verify(threadPool, never()).executor(MachineLearning.DATAFEED_THREAD_POOL_NAME);
+    }
+
     public static DatafeedConfig.Builder createDatafeedConfig(String datafeedId, String jobId) {
         DatafeedConfig.Builder datafeedConfig = new DatafeedConfig.Builder(datafeedId, jobId);
         datafeedConfig.setIndices(Collections.singletonList("myIndex"));
@@ -453,6 +472,7 @@ public class DatafeedRunnerTests extends ESTestCase {
             listener.onResponse(mock(PersistentTask.class));
             return null;
         }).when(task).updatePersistentTaskState(any(), any());
+        when(task.getStoppedOrIsolatedBeforeRunning()).thenReturn(DatafeedTask.StoppedOrIsolatedBeforeRunning.NEITHER);
         return task;
     }
 

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java

@@ -515,7 +515,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         manager.processData(jobTask, analysisRegistry, createInputStream(""), randomFrom(XContentType.values()), mock(DataLoadParams.class),
                 (dataCounts1, e) -> {
                 });
-        verify(manager).setJobState(any(), eq(JobState.OPENED));
+        verify(manager).setJobState(any(), eq(JobState.OPENED), any(), any());
         // job is created
         assertEquals(1, manager.numberOfOpenJobs());
         expectThrows(ElasticsearchException.class, () -> manager.closeJob(jobTask, null));

+ 32 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/JobTaskTests.java

@@ -44,4 +44,36 @@ public class JobTaskTests extends ESTestCase {
         assertThat(jobTask.isClosing(), is(true));
         verify(processManager).killProcess(jobTask, true, "test");
     }
+
+    public void testCloseOrVacateTransitions() {
+
+        JobTask jobTask = new JobTask("transition-test-task", 0, "persistent", "", null, null);
+
+        assertThat(jobTask.isClosing(), is(false));
+        assertThat(jobTask.isVacating(), is(false));
+
+        AutodetectProcessManager processManager = mock(AutodetectProcessManager.class);
+        jobTask.setAutodetectProcessManager(processManager);
+
+        assertThat(jobTask.isClosing(), is(false));
+        assertThat(jobTask.isVacating(), is(false));
+
+        // we can transition from neither closing nor vacating to vacating
+        assertThat(jobTask.triggerVacate(), is(true));
+
+        assertThat(jobTask.isClosing(), is(false));
+        assertThat(jobTask.isVacating(), is(true));
+
+        jobTask.closeJob("just testing");
+        verify(processManager).closeJob(jobTask, "just testing");
+
+        assertThat(jobTask.isClosing(), is(true));
+        assertThat(jobTask.isVacating(), is(false));
+
+        // we cannot transition from closing back to vacating
+        assertThat(jobTask.triggerVacate(), is(false));
+
+        assertThat(jobTask.isClosing(), is(true));
+        assertThat(jobTask.isVacating(), is(false));
+    }
 }

+ 39 - 7
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java

@@ -7,8 +7,12 @@
 package org.elasticsearch.xpack.ml.support;
 
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.Build;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.DocWriteRequest;
+import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksAction;
+import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequest;
+import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
 import org.elasticsearch.action.admin.indices.recovery.RecoveryResponse;
 import org.elasticsearch.action.admin.indices.settings.put.UpdateSettingsRequest;
 import org.elasticsearch.action.bulk.BulkItemResponse;
@@ -39,6 +43,7 @@ import org.elasticsearch.script.MockScriptPlugin;
 import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.ScriptContext;
 import org.elasticsearch.script.ScriptEngine;
+import org.elasticsearch.tasks.TaskInfo;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.MockHttpTransport;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -75,6 +80,7 @@ import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.monitoring.MonitoringService;
+import org.elasticsearch.xpack.shutdown.ShutdownPlugin;
 import org.junit.After;
 import org.junit.Before;
 
@@ -90,7 +96,9 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.stream.Collectors;
 
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doAnswer;
@@ -123,6 +131,10 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
         settings.put(MonitoringService.ENABLED.getKey(), false);
         settings.put(MonitoringService.ELASTICSEARCH_COLLECTION_ENABLED.getKey(), false);
         settings.put(LifecycleSettings.LIFECYCLE_HISTORY_INDEX_ENABLED_SETTING.getKey(), false);
+        // TODO: put this setting unconditionally once the shutdown API is not protected by a feature flag
+        if (Build.CURRENT.isSnapshot()) {
+            settings.put(ShutdownPlugin.SHUTDOWN_FEATURE_ENABLED_FLAG, true);
+        }
         return settings.build();
     }
 
@@ -133,6 +145,7 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
             CommonAnalysisPlugin.class,
             IngestCommonPlugin.class,
             ReindexPlugin.class,
+            ShutdownPlugin.class,
             // To remove warnings about painless not being supported
             MockPainlessScriptEngine.TestPlugin.class,
             // ILM is required for .ml-state template index settings
@@ -206,6 +219,10 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
     }
 
     public static Job.Builder createScheduledJob(String jobId) {
+        return createScheduledJob(jobId, null);
+    }
+
+    public static Job.Builder createScheduledJob(String jobId, ByteSizeValue modelMemoryLimit) {
         DataDescription.Builder dataDescription = new DataDescription.Builder();
         dataDescription.setTimeFormat("yyyy-MM-dd HH:mm:ss");
 
@@ -215,7 +232,9 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
 
         Job.Builder builder = new Job.Builder();
         builder.setId(jobId);
-
+        if (modelMemoryLimit != null) {
+            builder.setAnalysisLimits(new AnalysisLimits(modelMemoryLimit.getMb(), null));
+        }
         builder.setAnalysisConfig(analysisConfig);
         builder.setDataDescription(dataDescription);
         return builder;
@@ -242,11 +261,12 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
     }
 
     @After
-    public void cleanupWorkaround() throws Exception {
+    public void cleanup() throws Exception {
         logger.info("[{}#{}]: Cleaning up datafeeds and jobs after test", getTestClass().getSimpleName(), getTestName());
         deleteAllDatafeeds(logger, client());
         deleteAllJobs(logger, client());
         deleteAllDataFrameAnalytics(client());
+        waitForPendingTasks(client());
         assertBusy(() -> {
             RecoveryResponse recoveryResponse = client().admin().indices().prepareRecoveries()
                     .setActiveOnly(true)
@@ -333,7 +353,7 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
         final QueryPage<DatafeedConfig> datafeeds =
             client.execute(GetDatafeedsAction.INSTANCE, new GetDatafeedsAction.Request(GetDatafeedsAction.ALL)).actionGet().getResponse();
         try {
-            logger.info("Closing all datafeeds (using _all)");
+            logger.info("Stopping all datafeeds (using _all)");
             StopDatafeedAction.Response stopResponse = client
                     .execute(StopDatafeedAction.INSTANCE, new StopDatafeedAction.Request("_all"))
                     .get();
@@ -348,8 +368,7 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
             } catch (ExecutionException e2) {
                 logger.warn("Force-stopping datafeed with _all failed.", e2);
             }
-            throw new RuntimeException(
-                    "Had to resort to force-stopping datafeed, something went wrong?", e1);
+            throw new RuntimeException("Had to resort to force-stopping datafeed, something went wrong?", e1);
         }
 
         for (final DatafeedConfig datafeed : datafeeds.results()) {
@@ -397,7 +416,7 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
         for (final Job job : jobs.results()) {
             assertBusy(() -> {
                 GetJobsStatsAction.Response statsResponse =
-                        client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId())).actionGet();
+                        client.execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId())).actionGet();
                 assertEquals(JobState.CLOSED, statsResponse.getResponse().results().get(0).getState());
             });
             AcknowledgedResponse response =
@@ -413,7 +432,7 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
 
         assertBusy(() -> {
             GetDataFrameAnalyticsStatsAction.Response statsResponse =
-                client().execute(GetDataFrameAnalyticsStatsAction.INSTANCE, new GetDataFrameAnalyticsStatsAction.Request("_all")).get();
+                client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, new GetDataFrameAnalyticsStatsAction.Request("_all")).get();
             assertTrue(statsResponse.getResponse().results().stream().allMatch(s -> s.getState().equals(DataFrameAnalyticsState.STOPPED)));
         });
         for (final DataFrameAnalyticsConfig config : analytics.results()) {
@@ -421,6 +440,19 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
         }
     }
 
+    public static void waitForPendingTasks(Client client) throws Exception {
+        ListTasksRequest request = new ListTasksRequest().setDetailed(true);
+
+        assertBusy(() -> {
+            ListTasksResponse response = client.execute(ListTasksAction.INSTANCE, request).get();
+            List<String> activeTasks = response.getTasks().stream()
+                .filter(t -> t.getAction().startsWith(ListTasksAction.NAME) == false)
+                .map(TaskInfo::toString)
+                .collect(Collectors.toList());
+            assertThat(activeTasks, empty());
+        });
+    }
+
     protected static <T> void blockingCall(Consumer<ActionListener<T>> function,
                                            AtomicReference<T> response,
                                            AtomicReference<Exception> error) throws InterruptedException {