[ML] Restore analytics state if available (#47128)

This commit restores the model state, if one is available, in data
frame analytics jobs.

In addition, it changes the start API so that a stopped job can be
restarted. As we now store the progress in the state index when the
task is stopped, we can use it to determine what state the job was
in when it was stopped.

Note that, in order to distinguish between a job that is running for
the first time and one that is restarting, we ensure reindexing
progress is reported as at least 1 for a running task.
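
For illustration, here is a minimal sketch (not part of the commit; it uses the determineStartingState helper and PhaseProgress class introduced in the diff below, and assumes the phase name strings match ProgressTracker's constants) of how stored progress maps to a starting state:

    List<PhaseProgress> progressOnStart = Arrays.asList(
        new PhaseProgress("reindexing", 100),
        new PhaseProgress("loading_data", 40),
        new PhaseProgress("analyzing", 0),
        new PhaseProgress("writing_results", 0));

    // The first phase below 100% is loading_data, so the restarted job
    // resumes analyzing instead of reindexing from scratch.
    assert DataFrameAnalyticsTask.determineStartingState("my_job", progressOnStart)
        == DataFrameAnalyticsTask.StartingState.RESUMING_ANALYZING;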
Dimitris Athanasiou, 6 years ago
parent commit f47da1d3fe
31 changed files with 709 additions and 153 deletions
  1. +30 -6    x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java
  2. +5 -0     x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java
  3. +5 -0     x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java
  4. +5 -0     x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java
  5. +1 -0     x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
  6. +9 -1     x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java
  7. +6 -0     x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java
  8. +8 -0     x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java
  9. +2 -1     x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java
  10. +82 -0   x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java
  11. +67 -0   x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java
  12. +7 -6    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  13. +115 -50 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
  14. +55 -38  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
  15. +56 -11  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java
  16. +7 -0    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java
  17. +4 -0    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java
  18. +5 -1    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java
  19. +74 -11  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java
  20. +1 -0    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java
  21. +10 -0   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java
  22. +7 -4    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java
  23. +6 -0    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java
  24. +3 -0    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java
  25. +2 -19   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamer.java
  26. +41 -0   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/StateToProcessWriterHelper.java
  27. +2 -1    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java
  28. +90 -0   x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java
  29. +1 -1    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java
  30. +1 -1    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java
  31. +2 -2    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

+ 30 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java

@@ -13,6 +13,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.support.master.MasterNodeRequest;
 import org.elasticsearch.client.ElasticsearchClient;
 import org.elasticsearch.cluster.metadata.MetaData;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -29,8 +30,11 @@ import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
 import java.util.Objects;
 
 public class StartDataFrameAnalyticsAction extends ActionType<AcknowledgedResponse> {
@@ -150,12 +154,15 @@ public class StartDataFrameAnalyticsAction extends ActionType<AcknowledgedRespon
 
         public static final Version VERSION_INTRODUCED = Version.V_7_3_0;
 
+        private static final ParseField PROGRESS_ON_START = new ParseField("progress_on_start");
+
         public static ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
-            MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0], (String) a[1]));
+            MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0], (String) a[1], (List<PhaseProgress>) a[2]));
 
         static {
             PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID);
             PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.VERSION);
+            PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS_ON_START);
         }
 
         public static TaskParams fromXContent(XContentParser parser) {
@@ -164,25 +171,36 @@ public class StartDataFrameAnalyticsAction extends ActionType<AcknowledgedRespon
 
         private final String id;
         private final Version version;
+        private final List<PhaseProgress> progressOnStart;
 
-        public TaskParams(String id, Version version) {
+        public TaskParams(String id, Version version, List<PhaseProgress> progressOnStart) {
             this.id = Objects.requireNonNull(id);
             this.version = Objects.requireNonNull(version);
+            this.progressOnStart = Collections.unmodifiableList(progressOnStart);
         }
 
-        private TaskParams(String id, String version) {
-            this(id, Version.fromString(version));
+        private TaskParams(String id, String version, @Nullable List<PhaseProgress> progressOnStart) {
+            this(id, Version.fromString(version), progressOnStart == null ? Collections.emptyList() : progressOnStart);
         }
 
         public TaskParams(StreamInput in) throws IOException {
             this.id = in.readString();
             this.version = Version.readVersion(in);
+            if (in.getVersion().onOrAfter(Version.V_7_5_0)) {
+                progressOnStart = in.readList(PhaseProgress::new);
+            } else {
+                progressOnStart = Collections.emptyList();
+            }
         }
 
         public String getId() {
             return id;
         }
 
+        public List<PhaseProgress> getProgressOnStart() {
+            return progressOnStart;
+        }
+
         @Override
         public String getWriteableName() {
             return MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME;
@@ -197,6 +215,9 @@ public class StartDataFrameAnalyticsAction extends ActionType<AcknowledgedRespon
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(id);
             Version.writeVersion(version, out);
+            if (out.getVersion().onOrAfter(Version.V_7_5_0)) {
+                out.writeList(progressOnStart);
+            }
         }
 
         @Override
@@ -204,13 +225,14 @@ public class StartDataFrameAnalyticsAction extends ActionType<AcknowledgedRespon
             builder.startObject();
             builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id);
             builder.field(DataFrameAnalyticsConfig.VERSION.getPreferredName(), version);
+            builder.field(PROGRESS_ON_START.getPreferredName(), progressOnStart);
             builder.endObject();
             return builder;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(id, version);
+            return Objects.hash(id, version, progressOnStart);
         }
 
         @Override
@@ -219,7 +241,9 @@ public class StartDataFrameAnalyticsAction extends ActionType<AcknowledgedRespon
             if (o == null || getClass() != o.getClass()) return false;
 
             TaskParams other = (TaskParams) o;
-            return Objects.equals(id, other.id) && Objects.equals(version, other.version);
+            return Objects.equals(id, other.id)
+                && Objects.equals(version, other.version)
+                && Objects.equals(progressOnStart, other.progressOnStart);
         }
     }
 

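As a usage sketch (not part of the commit; BytesStreamOutput and the stream version setters are standard Elasticsearch stream utilities, and the assertions are illustrative), the new progress_on_start field round-trips through the TaskParams wire format on 7.5.0+ streams:

    StartDataFrameAnalyticsAction.TaskParams params =
        new StartDataFrameAnalyticsAction.TaskParams("my_job", Version.CURRENT,
            Collections.singletonList(new PhaseProgress("reindexing", 42)));

    BytesStreamOutput out = new BytesStreamOutput();
    out.setVersion(Version.V_7_5_0);
    params.writeTo(out);

    StreamInput in = out.bytes().streamInput();
    in.setVersion(Version.V_7_5_0);
    StartDataFrameAnalyticsAction.TaskParams restored =
        new StartDataFrameAnalyticsAction.TaskParams(in);

    assert restored.equals(params);
    // On a pre-7.5.0 stream the field is not written and reads back
    // as an empty list.
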
+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

@@ -37,4 +37,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
      * @return {@code true} if this analysis persists state that can later be used to restore from a given point
      */
     boolean persistsState();
+
+    /**
+     * Returns the document id for the analysis state
+     */
+    String getStateDocId(String jobId);
 }

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

@@ -174,6 +174,11 @@ public class OutlierDetection implements DataFrameAnalysis {
         return false;
     }
 
+    @Override
+    public String getStateDocId(String jobId) {
+        throw new UnsupportedOperationException("Outlier detection does not support state");
+    }
+
     public enum Method {
         LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;
 

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -215,6 +215,11 @@ public class Regression implements DataFrameAnalysis {
         return true;
     }
 
+    @Override
+    public String getStateDocId(String jobId) {
+        return jobId + "_regression_state#1";
+    }
+
     @Override
     public int hashCode() {
         return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,

+ 1 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -66,6 +66,7 @@ public final class Messages {
     public static final String DATA_FRAME_ANALYTICS_AUDIT_REUSING_DEST_INDEX = "Using existing destination index [{0}]";
     public static final String DATA_FRAME_ANALYTICS_AUDIT_FINISHED_REINDEXING = "Finished reindexing to destination index [{0}]";
     public static final String DATA_FRAME_ANALYTICS_AUDIT_FINISHED_ANALYSIS = "Finished analysis";
+    public static final String DATA_FRAME_ANALYTICS_AUDIT_RESTORING_STATE = "Restoring from previous model state";
 
     public static final String FILTER_CANNOT_DELETE = "Cannot delete filter [{0}] currently used by jobs {1}";
     public static final String FILTER_CONTAINS_TOO_MANY_ITEMS = "Filter [{0}] contains too many items; up to [{1}] items are allowed";

+ 9 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java

@@ -10,8 +10,11 @@ import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 
 public class StartDataFrameAnalyticsActionTaskParamsTests extends AbstractSerializingTestCase<StartDataFrameAnalyticsAction.TaskParams> {
 
@@ -22,7 +25,12 @@ public class StartDataFrameAnalyticsActionTaskParamsTests extends AbstractSerial
 
     @Override
     protected StartDataFrameAnalyticsAction.TaskParams createTestInstance() {
-        return new StartDataFrameAnalyticsAction.TaskParams(randomAlphaOfLength(10), Version.CURRENT);
+        int phaseCount = randomIntBetween(0, 5);
+        List<PhaseProgress> progressOnStart = new ArrayList<>(phaseCount);
+        for (int i = 0; i < phaseCount; i++) {
+            progressOnStart.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)));
+        }
+        return new StartDataFrameAnalyticsAction.TaskParams(randomAlphaOfLength(10), Version.CURRENT, progressOnStart);
     }
 
     @Override

+ 6 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

@@ -56,4 +56,10 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
         assertThat((Double) params.get(OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName()),
             is(closeTo(0.42, 1E-9)));
     }
+
+    public void testGetStateDocId() {
+        OutlierDetection outlierDetection = createRandom();
+        assertThat(outlierDetection.persistsState(), is(false));
+        expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocId("foo"));
+    }
 }

+ 8 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
 import java.io.IOException;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 
 public class RegressionTests extends AbstractSerializingTestCase<Regression> {
 
@@ -124,4 +125,11 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
+
+    public void testGetStateDocId() {
+        Regression regression = createRandom();
+        assertThat(regression.persistsState(), is(true));
+        String randomId = randomAlphaOfLength(10);
+        assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
+    }
 }

+ 2 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

@@ -199,7 +199,8 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
         assertBusy(() -> assertTrue(indexExists(AuditorField.NOTIFICATIONS_INDEX)));
         assertBusy(() -> {
             String[] actualAuditMessages = fetchAllAuditMessages(configId);
-            assertThat(actualAuditMessages.length, equalTo(expectedAuditMessagePrefixes.length));
+            assertThat("Messages: " + Arrays.toString(actualAuditMessages), actualAuditMessages.length,
+                equalTo(expectedAuditMessagePrefixes.length));
             for (int i = 0; i < actualAuditMessages.length; i++) {
                 assertThat(actualAuditMessages[i], startsWith(expectedAuditMessagePrefixes[i]));
             }

+ 82 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -23,6 +23,7 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
@@ -258,6 +259,87 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertModelStatePersisted(jobId);
     }
 
+    public void testStopAndRestart() throws Exception {
+        String jobId = "regression_stop_and_restart";
+        String sourceIndex = jobId + "_source_index";
+
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
+        bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+
+        List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
+        List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
+
+        for (int i = 0; i < 350; i++) {
+            Double field = featureValues.get(i % 3);
+            Double value = dependentVariableValues.get(i % 3);
+
+            IndexRequest indexRequest = new IndexRequest(sourceIndex);
+            indexRequest.source("feature", field, "variable", value);
+            bulkRequestBuilder.add(indexRequest);
+        }
+        BulkResponse bulkResponse = bulkRequestBuilder.get();
+        if (bulkResponse.hasFailures()) {
+            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
+        }
+
+        String destIndex = sourceIndex + "_results";
+        DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
+            new Regression("variable"));
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        assertState(jobId, DataFrameAnalyticsState.STOPPED);
+        assertProgress(jobId, 0, 0, 0, 0);
+
+        startAnalytics(jobId);
+
+        // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
+        assertBusy(() -> {
+            DataFrameAnalyticsState state = getAnalyticsStats(jobId).get(0).getState();
+            assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
+                equalTo(DataFrameAnalyticsState.STOPPED))));
+        });
+        stopAnalytics(jobId);
+        waitUntilAnalyticsIsStopped(jobId);
+
+        // Now let's start it again
+        try {
+            startAnalytics(jobId);
+        } catch (Exception e) {
+            if (e.getMessage().equals("Cannot start because the job has already finished")) {
+                // That means the job had managed to complete
+            } else {
+                throw e;
+            }
+        }
+
+        waitUntilAnalyticsIsStopped(jobId);
+
+        SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
+        for (SearchHit hit : sourceData.getHits()) {
+            GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
+            assertThat(destDocGetResponse.isExists(), is(true));
+            Map<String, Object> sourceDoc = hit.getSourceAsMap();
+            Map<String, Object> destDoc = destDocGetResponse.getSource();
+            for (String field : sourceDoc.keySet()) {
+                assertThat(destDoc.containsKey(field), is(true));
+                assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
+            }
+            assertThat(destDoc.containsKey("ml"), is(true));
+
+            @SuppressWarnings("unchecked")
+            Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
+
+            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
+            assertThat(resultsObject.containsKey("is_training"), is(true));
+            assertThat(resultsObject.get("is_training"), is(true));
+        }
+
+        assertProgress(jobId, 100, 100, 100, 100);
+        assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
+        assertModelStatePersisted(jobId);
+    }
+
     private void assertModelStatePersisted(String jobId) {
         String docId = jobId + "_regression_state#1";
         SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())

+ 67 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java

@@ -31,6 +31,7 @@ import org.junit.After;
 import java.util.Map;
 
 import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -476,4 +477,70 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
             "Created analytics with analysis type [outlier_detection]",
             "Estimated memory usage for this analytics to be");
     }
+
+    public void testOutlierDetectionStopAndRestart() throws Exception {
+        String sourceIndex = "test-outlier-detection-stop-and-restart";
+
+        client().admin().indices().prepareCreate(sourceIndex)
+            .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical_1", "type=keyword")
+            .get();
+
+        BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
+        bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+
+        int docCount = randomIntBetween(1024, 2048);
+        for (int i = 0; i < docCount; i++) {
+            IndexRequest indexRequest = new IndexRequest(sourceIndex);
+            indexRequest.source("numeric_1", randomDouble(), "numeric_2", randomFloat(), "categorical_1", randomAlphaOfLength(10));
+            bulkRequestBuilder.add(indexRequest);
+        }
+        BulkResponse bulkResponse = bulkRequestBuilder.get();
+        if (bulkResponse.hasFailures()) {
+            fail("Failed to index data: " + bulkResponse.buildFailureMessage());
+        }
+
+        String id = "test_outlier_detection_stop_and_restart";
+        DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
+            id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        assertState(id, DataFrameAnalyticsState.STOPPED);
+        startAnalytics(id);
+
+        // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
+        assertBusy(() -> {
+            DataFrameAnalyticsState state = getAnalyticsStats(id).get(0).getState();
+            assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
+                equalTo(DataFrameAnalyticsState.STOPPED))));
+        });
+        stopAnalytics(id);
+        waitUntilAnalyticsIsStopped(id);
+
+        // Now let's start it again
+        try {
+            startAnalytics(id);
+        } catch (Exception e) {
+            if (e.getMessage().equals("Cannot start because the job has already finished")) {
+                // That means the job had managed to complete
+            } else {
+                throw e;
+            }
+        }
+
+        waitUntilAnalyticsIsStopped(id);
+
+        // Check we've got all docs
+        SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get();
+        assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount));
+
+        // Check they all have an outlier_score
+        searchResponse = client().prepareSearch(config.getDest().getIndex())
+            .setTrackTotalHits(true)
+            .setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get();
+        assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount));
+
+        assertProgress(id, 100, 100, 100, 100);
+        assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
+    }
 }

+ 7 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -195,11 +195,11 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
-import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcessManager;
-import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
-import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory;
+import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
+import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
+import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
 import org.elasticsearch.xpack.ml.job.JobManager;
 import org.elasticsearch.xpack.ml.job.JobManagerHolder;
@@ -535,8 +535,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
                     new BlackHoleAutodetectProcess(job.getId());
             // factor of 1.0 makes renormalization a no-op
             normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0);
-            analyticsProcessFactory = (jobId, analyticsProcessConfig, executorService, onProcessCrash) -> null;
-            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, executorService, onProcessCrash) -> null;
+            analyticsProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
+            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
         }
         NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory,
                 threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME));
@@ -561,7 +561,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
         this.datafeedManager.set(datafeedManager);
 
         // Data frame analytics components
-        AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory);
+        AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
+            dataFrameAnalyticsAuditor);
         MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
             new MemoryUsageEstimationProcessManager(
                 threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);

+ 115 - 50
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java

@@ -49,6 +49,7 @@ import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
+import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -57,6 +58,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
@@ -73,9 +75,10 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Predicate;
 
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
 import static org.elasticsearch.xpack.core.ml.MlTasks.AWAITING_UPGRADE;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE;
 
@@ -159,70 +162,75 @@ public class TransportStartDataFrameAnalyticsAction
                 }
             };
 
-        AtomicReference<DataFrameAnalyticsConfig> configHolder = new AtomicReference<>();
-
         // Start persistent task
-        ActionListener<Void> memoryRequirementRefreshListener = ActionListener.wrap(
-            aVoid -> {
+        ActionListener<StartContext> memoryUsageHandledListener = ActionListener.wrap(
+            startContext -> {
                 StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(
-                    request.getId(), configHolder.get().getVersion());
+                    request.getId(), startContext.config.getVersion(), startContext.progressOnStart);
                 persistentTasksService.sendStartRequest(MlTasks.dataFrameAnalyticsTaskId(request.getId()),
                     MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, taskParams, waitForAnalyticsToStart);
             },
             listener::onFailure
         );
 
+        // Perform memory usage estimation for this config
+        ActionListener<StartContext> startContextListener = ActionListener.wrap(
+            startContext -> {
+                estimateMemoryUsageAndUpdateMemoryTracker(startContext, memoryUsageHandledListener);
+            },
+            listener::onFailure
+        );
+
+        // Get start context
+        getStartContext(request.getId(), startContextListener);
+    }
+
+    private void estimateMemoryUsageAndUpdateMemoryTracker(StartContext startContext, ActionListener<StartContext> listener) {
+        final String jobId = startContext.config.getId();
+
         // Tell the job tracker to refresh the memory requirement for this job and all other jobs that have persistent tasks
         ActionListener<EstimateMemoryUsageAction.Response> estimateMemoryUsageListener = ActionListener.wrap(
             estimateMemoryUsageResponse -> {
                 auditor.info(
-                    request.getId(),
+                    jobId,
                     Messages.getMessage(
                         Messages.DATA_FRAME_ANALYTICS_AUDIT_ESTIMATED_MEMORY_USAGE,
                         estimateMemoryUsageResponse.getExpectedMemoryWithoutDisk()));
                 // Validate that model memory limit is sufficient to run the analysis
-                if (configHolder.get().getModelMemoryLimit()
+                if (startContext.config.getModelMemoryLimit()
                     .compareTo(estimateMemoryUsageResponse.getExpectedMemoryWithoutDisk()) < 0) {
                     ElasticsearchStatusException e =
                         ExceptionsHelper.badRequestException(
                             "Cannot start because the configured model memory limit [{}] is lower than the expected memory usage [{}]",
-                            configHolder.get().getModelMemoryLimit(), estimateMemoryUsageResponse.getExpectedMemoryWithoutDisk());
+                            startContext.config.getModelMemoryLimit(), estimateMemoryUsageResponse.getExpectedMemoryWithoutDisk());
                     listener.onFailure(e);
                     return;
                 }
                 // Refresh memory requirement for jobs
                 memoryTracker.addDataFrameAnalyticsJobMemoryAndRefreshAllOthers(
-                    request.getId(), configHolder.get().getModelMemoryLimit().getBytes(), memoryRequirementRefreshListener);
+                    jobId, startContext.config.getModelMemoryLimit().getBytes(), ActionListener.wrap(
+                        aVoid -> listener.onResponse(startContext), listener::onFailure));
             },
             listener::onFailure
         );
 
-        // Perform memory usage estimation for this config
-        ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap(
-            config -> {
-                configHolder.set(config);
-                PutDataFrameAnalyticsAction.Request estimateMemoryUsageRequest = new PutDataFrameAnalyticsAction.Request(config);
-                ClientHelper.executeAsyncWithOrigin(
-                    client,
-                    ClientHelper.ML_ORIGIN,
-                    EstimateMemoryUsageAction.INSTANCE,
-                    estimateMemoryUsageRequest,
-                    estimateMemoryUsageListener);
-            },
-            listener::onFailure
-        );
+        PutDataFrameAnalyticsAction.Request estimateMemoryUsageRequest = new PutDataFrameAnalyticsAction.Request(startContext.config);
+        ClientHelper.executeAsyncWithOrigin(
+            client,
+            ClientHelper.ML_ORIGIN,
+            EstimateMemoryUsageAction.INSTANCE,
+            estimateMemoryUsageRequest,
+            estimateMemoryUsageListener);
 
-        // Get config
-        getConfigAndValidate(request.getId(), configListener);
     }
 
-    private void getConfigAndValidate(String id, ActionListener<DataFrameAnalyticsConfig> finalListener) {
+    private void getStartContext(String id, ActionListener<StartContext> finalListener) {
 
-        // Step 5. Validate that there are analyzable data in the source index
-        ActionListener<DataFrameAnalyticsConfig> validateMappingsMergeListener = ActionListener.wrap(
-            config -> DataFrameDataExtractorFactory.createForSourceIndices(client,
+        // Step 6. Validate that there are analyzable data in the source index
+        ActionListener<StartContext> validateMappingsMergeListener = ActionListener.wrap(
+            startContext -> DataFrameDataExtractorFactory.createForSourceIndices(client,
                 "validate_source_index_has_rows-" + id,
-                config,
+                startContext.config,
                 ActionListener.wrap(
                     dataFrameDataExtractorFactory ->
                         dataFrameDataExtractorFactory
@@ -234,10 +242,10 @@ public class TransportStartDataFrameAnalyticsAction
                                             "Unable to start {} as there are no analyzable data in source indices [{}].",
                                             RestStatus.BAD_REQUEST,
                                             id,
-                                            Strings.arrayToCommaDelimitedString(config.getSource().getIndex())
+                                            Strings.arrayToCommaDelimitedString(startContext.config.getSource().getIndex())
                                         ));
                                     } else {
-                                        finalListener.onResponse(config);
+                                        finalListener.onResponse(startContext);
                                     }
                                 },
                                 finalListener::onFailure
@@ -248,49 +256,94 @@ public class TransportStartDataFrameAnalyticsAction
             finalListener::onFailure
         );
 
-        // Step 4. Validate mappings can be merged
-        ActionListener<DataFrameAnalyticsConfig> toValidateMappingsListener = ActionListener.wrap(
-            config -> MappingsMerger.mergeMappings(client, config.getHeaders(), config.getSource().getIndex(), ActionListener.wrap(
-                mappings -> validateMappingsMergeListener.onResponse(config), finalListener::onFailure)),
+        // Step 5. Validate mappings can be merged
+        ActionListener<StartContext> toValidateMappingsListener = ActionListener.wrap(
+            startContext -> MappingsMerger.mergeMappings(client, startContext.config.getHeaders(),
+                startContext.config.getSource().getIndex(), ActionListener.wrap(
+                mappings -> validateMappingsMergeListener.onResponse(startContext), finalListener::onFailure)),
             finalListener::onFailure
         );
 
-        // Step 3. Validate dest index is empty
-        ActionListener<DataFrameAnalyticsConfig> toValidateDestEmptyListener = ActionListener.wrap(
-            config -> checkDestIndexIsEmptyIfExists(config, toValidateMappingsListener),
+        // Step 4. Validate dest index is empty if task is starting for first time
+        ActionListener<StartContext> toValidateDestEmptyListener = ActionListener.wrap(
+            startContext -> {
+                DataFrameAnalyticsTask.StartingState startingState = DataFrameAnalyticsTask.determineStartingState(
+                    startContext.config.getId(), startContext.progressOnStart);
+                switch (startingState) {
+                    case FIRST_TIME:
+                        checkDestIndexIsEmptyIfExists(startContext, toValidateMappingsListener);
+                        break;
+                    case RESUMING_REINDEXING:
+                    case RESUMING_ANALYZING:
+                        toValidateMappingsListener.onResponse(startContext);
+                        break;
+                    case FINISHED:
+                        LOGGER.info("[{}] Job has already finished", startContext.config.getId());
+                        finalListener.onFailure(ExceptionsHelper.badRequestException(
+                            "Cannot start because the job has already finished"));
+                        break;
+                    default:
+                        finalListener.onFailure(ExceptionsHelper.serverError("Unexpected starting state " + startingState));
+                        break;
+                }
+            },
             finalListener::onFailure
         );
 
-        // Step 2. Validate source and dest; check data extraction is possible
-        ActionListener<DataFrameAnalyticsConfig> getConfigListener = ActionListener.wrap(
-            config -> {
-                new SourceDestValidator(clusterService.state(), indexNameExpressionResolver).check(config);
-                DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, config, toValidateDestEmptyListener);
+        // Step 3. Validate source and dest; check data extraction is possible
+        ActionListener<StartContext> startContextListener = ActionListener.wrap(
+            startContext -> {
+                new SourceDestValidator(clusterService.state(), indexNameExpressionResolver).check(startContext.config);
+                DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, startContext.config, ActionListener.wrap(
+                    config -> toValidateDestEmptyListener.onResponse(startContext), finalListener::onFailure));
             },
             finalListener::onFailure
         );
 
+        // Step 2. Get stats to recover progress
+        ActionListener<DataFrameAnalyticsConfig> getConfigListener = ActionListener.wrap(
+            config -> getProgress(config, ActionListener.wrap(
+                progress -> startContextListener.onResponse(new StartContext(config, progress)), finalListener::onFailure)),
+            finalListener::onFailure
+        );
+
         // Step 1. Get the config
         configProvider.get(id, getConfigListener);
     }
 
-    private void checkDestIndexIsEmptyIfExists(DataFrameAnalyticsConfig config, ActionListener<DataFrameAnalyticsConfig> listener) {
-        String destIndex = config.getDest().getIndex();
+    private void getProgress(DataFrameAnalyticsConfig config, ActionListener<List<PhaseProgress>> listener) {
+        GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(config.getId());
+        executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, ActionListener.wrap(
+            statsResponse -> {
+                List<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = statsResponse.getResponse().results();
+                if (stats.isEmpty()) {
+                    // The job has been deleted in between
+                    listener.onFailure(ExceptionsHelper.missingDataFrameAnalytics(config.getId()));
+                } else {
+                    listener.onResponse(stats.get(0).getProgress());
+                }
+            },
+            listener::onFailure
+        ));
+    }
+
+    private void checkDestIndexIsEmptyIfExists(StartContext startContext, ActionListener<StartContext> listener) {
+        String destIndex = startContext.config.getDest().getIndex();
         SearchRequest destEmptySearch = new SearchRequest(destIndex);
         destEmptySearch.source().size(0);
         destEmptySearch.allowPartialSearchResults(false);
-        ClientHelper.executeWithHeadersAsync(config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE,
+        ClientHelper.executeWithHeadersAsync(startContext.config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE,
             destEmptySearch, ActionListener.wrap(
                 searchResponse -> {
                     if (searchResponse.getHits().getTotalHits().value > 0) {
                         listener.onFailure(ExceptionsHelper.badRequestException("dest index [{}] must be empty", destIndex));
                     } else {
-                        listener.onResponse(config);
+                        listener.onResponse(startContext);
                     }
                 },
                 e -> {
                     if (e instanceof IndexNotFoundException) {
-                        listener.onResponse(config);
+                        listener.onResponse(startContext);
                     } else {
                         listener.onFailure(e);
                     }
@@ -331,6 +384,16 @@ public class TransportStartDataFrameAnalyticsAction
         });
     }
 
+    private static class StartContext {
+        private final DataFrameAnalyticsConfig config;
+        private final List<PhaseProgress> progressOnStart;
+
+        private StartContext(DataFrameAnalyticsConfig config, List<PhaseProgress> progressOnStart) {
+            this.config = config;
+            this.progressOnStart = progressOnStart;
+        }
+    }
+
     /**
      * Important: the methods of this class must NOT throw exceptions.  If they did then the callers
      * of endpoints waiting for a condition tested by this predicate would never get a response.
@@ -539,4 +602,6 @@ public class TransportStartDataFrameAnalyticsAction
             this.maxOpenJobs = maxOpenJobs;
         }
     }
+
+
 }

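The transport action above builds its validation pipeline back to front: each step's ActionListener wraps the next, and any failure short-circuits to the final listener. A condensed sketch of the pattern (validateRemainingSteps is an illustrative placeholder, not a real method; getProgress, StartContext, and configProvider come from the diff above):

    ActionListener<StartContext> step3 = ActionListener.wrap(
        startContext -> validateRemainingSteps(startContext, finalListener),
        finalListener::onFailure);

    // Step 2: recover the job's progress from the stats API and build the context
    ActionListener<DataFrameAnalyticsConfig> getConfigListener = ActionListener.wrap(
        config -> getProgress(config, ActionListener.wrap(
            progress -> step3.onResponse(new StartContext(config, progress)),
            finalListener::onFailure)),
        finalListener::onFailure);

    // Step 1: fetch the config
    configProvider.get(id, getConfigListener);
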
+ 55 - 38
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

@@ -34,7 +34,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
@@ -67,61 +66,33 @@ public class DataFrameAnalyticsManager {
     }
 
     public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState currentState, ClusterState clusterState) {
-        ActionListener<DataFrameAnalyticsConfig> reindexingStateListener = ActionListener.wrap(
-            config -> reindexDataframeAndStartAnalysis(task, config),
-            error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
-        );
-
         // With config in hand, determine action to take
         ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap(
             config -> {
-                DataFrameAnalyticsTaskState reindexingState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.REINDEXING,
-                    task.getAllocationId(), null);
                 switch(currentState) {
-                    // If we are STARTED, we are right at the beginning of our task, we should indicate that we are entering the
-                    // REINDEX state and start reindexing.
+                    // If we are STARTED, it means the job was started because the start API was called.
+                    // We should determine the job's starting state based on its previous progress.
                     case STARTED:
-                        task.updatePersistentTaskState(reindexingState, ActionListener.wrap(
-                            updatedTask -> reindexingStateListener.onResponse(config),
-                            error -> {
-                                if (error instanceof ResourceNotFoundException) {
-                                    // The task has been stopped
-                                } else {
-                                    reindexingStateListener.onFailure(error);
-                                }
-                            }));
+                        executeStartingJob(task, config);
                         break;
                     // The task has fully reindexed the documents and we should continue on with our analyses
                     case ANALYZING:
+                        LOGGER.debug("[{}] Reassigning job that was analyzing", config.getId());
                         startAnalytics(task, config, true);
                         break;
                     // If we are already at REINDEXING, we are not 100% sure if we reindexed ALL the docs.
                     // We will delete the destination index, recreate, reindex
                     case REINDEXING:
-                        ClientHelper.executeAsyncWithOrigin(client,
-                            ML_ORIGIN,
-                            DeleteIndexAction.INSTANCE,
-                            new DeleteIndexRequest(config.getDest().getIndex()),
-                            ActionListener.wrap(
-                                r-> reindexingStateListener.onResponse(config),
-                                e -> {
-                                    if (e instanceof IndexNotFoundException) {
-                                        reindexingStateListener.onResponse(config);
-                                    } else {
-                                        reindexingStateListener.onFailure(e);
-                                    }
-                                }
-                            ));
+                        LOGGER.debug("[{}] Reassigning job that was reindexing", config.getId());
+                        executeJobInMiddleOfReindexing(task, config);
                         break;
                     default:
-                        reindexingStateListener.onFailure(
-                            ExceptionsHelper.conflictStatusException(
-                                "Cannot execute analytics task [{}] as it is currently in state [{}]. " +
-                                "Must be one of [STARTED, REINDEXING, ANALYZING]", config.getId(), currentState));
+                        task.updateState(DataFrameAnalyticsState.FAILED, "Cannot execute analytics task [" + config.getId() +
+                            "] as it is in unknown state [" + currentState + "]. Must be one of [STARTED, REINDEXING, ANALYZING]");
                 }
 
             },
-            reindexingStateListener::onFailure
+            error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
         );
 
         // Retrieve configuration
@@ -134,6 +105,52 @@ public class DataFrameAnalyticsManager {
         AnomalyDetectorsIndex.createStateIndexAndAliasIfNecessary(client, clusterState, stateAliasListener);
     }
 
+    private void executeStartingJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) {
+        DataFrameAnalyticsTaskState reindexingState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.REINDEXING,
+            task.getAllocationId(), null);
+        DataFrameAnalyticsTask.StartingState startingState = DataFrameAnalyticsTask.determineStartingState(
+            config.getId(), task.getParams().getProgressOnStart());
+
+        LOGGER.debug("[{}] Starting job from state [{}]", config.getId(), startingState);
+        switch (startingState) {
+            case FIRST_TIME:
+                task.updatePersistentTaskState(reindexingState, ActionListener.wrap(
+                    updatedTask -> reindexDataframeAndStartAnalysis(task, config),
+                    error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
+                ));
+                break;
+            case RESUMING_REINDEXING:
+                task.updatePersistentTaskState(reindexingState, ActionListener.wrap(
+                    updatedTask -> executeJobInMiddleOfReindexing(task, config),
+                    error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
+                ));
+                break;
+            case RESUMING_ANALYZING:
+                startAnalytics(task, config, true);
+                break;
+            case FINISHED:
+            default:
+                task.updateState(DataFrameAnalyticsState.FAILED, "Unexpected starting state [" + startingState + "]");
+        }
+    }
+
+    private void executeJobInMiddleOfReindexing(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) {
+        ClientHelper.executeAsyncWithOrigin(client,
+            ML_ORIGIN,
+            DeleteIndexAction.INSTANCE,
+            new DeleteIndexRequest(config.getDest().getIndex()),
+            ActionListener.wrap(
+                r-> reindexDataframeAndStartAnalysis(task, config),
+                e -> {
+                    if (e instanceof IndexNotFoundException) {
+                        reindexDataframeAndStartAnalysis(task, config);
+                    } else {
+                        task.updateState(DataFrameAnalyticsState.FAILED, e.getMessage());
+                    }
+                }
+            ));
+    }
+
     private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) {
         if (task.isStopping()) {
             // The task was requested to stop before we started reindexing

+ 56 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

@@ -185,14 +185,23 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
     }
 
     public void updateReindexTaskProgress(ActionListener<Void> listener) {
+        getReindexTaskProgress(ActionListener.wrap(
+            // We set reindexing progress at least to 1 for a running process to be able to
+            // distinguish a job that is running for the first time against a job that is restarting.
+            reindexTaskProgress -> {
+                progressTracker.reindexingPercent.set(Math.max(1, reindexTaskProgress));
+                listener.onResponse(null);
+            },
+            listener::onFailure
+        ));
+    }
+
+    private void getReindexTaskProgress(ActionListener<Integer> listener) {
         TaskId reindexTaskId = getReindexTaskId();
         if (reindexTaskId == null) {
             // The task is not present which means either it has not started yet or it finished.
             // We keep track of whether the task has finished so we can use that to tell whether the progress is 100.
-            if (isReindexingFinished) {
-                progressTracker.reindexingPercent.set(100);
-            }
-            listener.onResponse(null);
+            listener.onResponse(isReindexingFinished ? 100 : 0);
             return;
         }
 
@@ -202,18 +211,14 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
             taskResponse -> {
                 TaskResult taskResult = taskResponse.getTask();
                 BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) taskResult.getTask().getStatus();
-                int progress = taskStatus.getTotal() == 0 ? 0 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal());
-                progressTracker.reindexingPercent.set(progress);
-                listener.onResponse(null);
+                int progress = (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal());
+                listener.onResponse(progress);
             },
             error -> {
                 if (error instanceof ResourceNotFoundException) {
                     // The task is not present which means either it has not started yet or it finished.
                     // We keep track of whether the task has finished so we can use that to tell whether the progress is 100.
-                    if (isReindexingFinished) {
-                        progressTracker.reindexingPercent.set(100);
-                    }
-                    listener.onResponse(null);
+                    listener.onResponse(isReindexingFinished ? 100 : 0);
                 } else {
                     listener.onFailure(error);
                 }
@@ -264,6 +269,46 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
         ));
     }
 
+    /**
+     * This captures the possible states a job can be in when it starts.
+     * {@code FIRST_TIME} means the job has never been started before.
+     * {@code RESUMING_REINDEXING} means the job was stopped while it was reindexing.
+     * {@code RESUMING_ANALYZING} means the job was stopped while it was analyzing.
+     * {@code FINISHED} means the job had finished.
+     */
+    public enum StartingState {
+        FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, FINISHED
+    }
+
+    public static StartingState determineStartingState(String jobId, List<PhaseProgress> progressOnStart) {
+        PhaseProgress lastIncompletePhase = null;
+        for (PhaseProgress phaseProgress : progressOnStart) {
+            if (phaseProgress.getProgressPercent() < 100) {
+                lastIncompletePhase = phaseProgress;
+                break;
+            }
+        }
+
+        if (lastIncompletePhase == null) {
+            return StartingState.FINISHED;
+        }
+
+        LOGGER.debug("[{}] Last incomplete progress [{}, {}]", jobId, lastIncompletePhase.getPhase(),
+            lastIncompletePhase.getProgressPercent());
+
+        switch (lastIncompletePhase.getPhase()) {
+            case ProgressTracker.REINDEXING:
+                return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
+            case ProgressTracker.LOADING_DATA:
+            case ProgressTracker.ANALYZING:
+            case ProgressTracker.WRITING_RESULTS:
+                return StartingState.RESUMING_ANALYZING;
+            default:
+                LOGGER.warn("[{}] Unexpected progress phase [{}]", jobId, lastIncompletePhase.getPhase());
+                return StartingState.FIRST_TIME;
+        }
+    }
+
     public static String progressDocId(String id) {
         return "data_frame_analytics-" + id + "-progress";
     }

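A small worked example of the progress arithmetic in getReindexTaskProgress above (the document counts are made up for illustration):

    long created = 87, total = 350;
    int progress = (int) (created * 100.0 / total);  // 24

    // updateReindexTaskProgress floors the stored value at 1, so a running
    // task never reports 0 and a restart is distinguishable from a job
    // that has not run at all.
    int stored = Math.max(1, progress);              // 24 here; 1 right at start
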
+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xpack.ml.process.NativeProcess;
 
 import java.io.IOException;
@@ -37,4 +38,10 @@ public interface AnalyticsProcess<ProcessResult> extends NativeProcess {
      * @return the process config
      */
     AnalyticsProcessConfig getConfig();
+
+    /**
+     * Restores the model state from a previously persisted state document.
+     * @param state the state to restore; must not be {@code null}
+     */
+    void restoreState(BytesReference state) throws IOException;
 }

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java

@@ -46,6 +46,10 @@ public class AnalyticsProcessConfig implements ToXContentObject {
         this.analysis = Objects.requireNonNull(analysis);
     }
 
+    public String jobId() {
+        return jobId;
+    }
+
     public long rows() {
         return rows;
     }

+ 5 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java

@@ -5,6 +5,8 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 
 import java.util.concurrent.ExecutorService;
@@ -17,10 +19,12 @@ public interface AnalyticsProcessFactory<ProcessResult> {
      *
      * @param config                 The data frame analytics config
      * @param analyticsProcessConfig The process configuration
+     * @param state                  The state document to restore from, if one is available
      * @param executorService        Executor service used to start the async tasks a job needs to operate the analytical process
      * @param onProcessCrash         Callback to execute if the process stops unexpectedly
      * @return The process
      */
     AnalyticsProcess<ProcessResult> createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                           ExecutorService executorService, Consumer<String> onProcessCrash);
+                                                           @Nullable BytesReference state, ExecutorService executorService,
+                                                           Consumer<String> onProcessCrash);
 }

+ 74 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

@@ -10,13 +10,22 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
@@ -25,6 +34,7 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFact
 import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
 import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 
 import java.io.IOException;
 import java.util.List;
@@ -36,6 +46,8 @@ import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
 
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+
 public class AnalyticsProcessManager {
 
     private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
@@ -44,17 +56,20 @@ public class AnalyticsProcessManager {
     private final ThreadPool threadPool;
     private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
+    private final DataFrameAnalyticsAuditor auditor;
 
     public AnalyticsProcessManager(Client client,
                                    ThreadPool threadPool,
-                                   AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory) {
+                                   AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
+                                   DataFrameAnalyticsAuditor auditor) {
         this.client = Objects.requireNonNull(client);
         this.threadPool = Objects.requireNonNull(threadPool);
         this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
+        this.auditor = Objects.requireNonNull(auditor);
     }
 
-    public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
-                       DataFrameDataExtractorFactory dataExtractorFactory, Consumer<Exception> finishHandler) {
+    public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
+                       Consumer<Exception> finishHandler) {
         threadPool.generic().execute(() -> {
             if (task.isStopping()) {
                 // The task was requested to stop before we created the process context
@@ -68,17 +83,38 @@ public class AnalyticsProcessManager {
                     + "] Could not create process as one already exists"));
                 return;
             }
-            if (processContext.startProcess(dataExtractorFactory, config, task)) {
+
+            BytesReference state = getModelState(config);
+
+            if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
                 ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
                 executorService.execute(() -> processResults(processContext));
                 executorService.execute(() -> processData(task, config, processContext.dataExtractor,
-                    processContext.process, processContext.resultProcessor, finishHandler));
+                    processContext.process, processContext.resultProcessor, finishHandler, state));
             } else {
                 finishHandler.accept(null);
             }
         });
     }
 
+    @Nullable
+    private BytesReference getModelState(DataFrameAnalyticsConfig config) {
+        if (config.getAnalysis().persistsState() == false) {
+            return null;
+        }
+
+        try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
+            SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
+            searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())));
+            SearchResponse searchResponse = client.search(searchRequest).actionGet();
+            SearchHit[] hits = searchResponse.getHits().getHits();
+            return hits.length == 0 ? null : hits[0].getSourceRef();
+        }
+    }
+
     private void processResults(ProcessContext processContext) {
         try {
             processContext.resultProcessor.process(processContext.process);
@@ -89,7 +125,7 @@ public class AnalyticsProcessManager {
 
     private void processData(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor,
                              AnalyticsProcess<AnalyticsResult> process, AnalyticsResultProcessor resultProcessor,
-                             Consumer<Exception> finishHandler) {
+                             Consumer<Exception> finishHandler, BytesReference state) {
 
         try {
             ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
@@ -98,6 +134,8 @@ public class AnalyticsProcessManager {
             process.writeEndOfDataMessage();
             process.flushStream();
 
+            restoreState(config, state, process, finishHandler);
+
             LOGGER.info("[{}] Waiting for result processor to complete", config.getId());
             resultProcessor.awaitForCompletion();
             processContext.setFailureReason(resultProcessor.getFailure());
@@ -105,7 +143,9 @@ public class AnalyticsProcessManager {
             refreshDest(config);
             LOGGER.info("[{}] Result processor has completed", config.getId());
         } catch (Exception e) {
-            String errorMsg = new ParameterizedMessage("[{}] Error while processing data", config.getId()).getFormattedMessage();
+            String errorMsg = new ParameterizedMessage("[{}] Error while processing data [{}]", config.getId(), e.getMessage())
+                .getFormattedMessage();
+            LOGGER.error(errorMsg, e);
             processContextByAllocation.get(task.getAllocationId()).setFailureReason(errorMsg);
         } finally {
             closeProcess(task);
@@ -172,10 +212,33 @@ public class AnalyticsProcessManager {
         process.writeRecord(headerRecord);
     }
 
+    private void restoreState(DataFrameAnalyticsConfig config, @Nullable BytesReference state, AnalyticsProcess<AnalyticsResult> process,
+                              Consumer<Exception> failureHandler) {
+        if (config.getAnalysis().persistsState() == false) {
+            LOGGER.debug("[{}] Analysis does not support state", config.getId());
+            return;
+        }
+
+        if (state == null) {
+            LOGGER.debug("[{}] No model state available to restore", config.getId());
+            return;
+        }
+
+        LOGGER.debug("[{}] Restoring from previous model state", config.getId());
+        auditor.info(config.getId(), Messages.DATA_FRAME_ANALYTICS_AUDIT_RESTORING_STATE);
+
+        try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
+            process.restoreState(state);
+        } catch (Exception e) {
+            LOGGER.error(new ParameterizedMessage("[{}] Failed to restore state", process.getConfig().jobId()), e);
+            failureHandler.accept(ExceptionsHelper.serverError("Failed to restore state", e));
+        }
+    }
+
     private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
-                                                            AnalyticsProcessConfig analyticsProcessConfig) {
+                                                            AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
         ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
-        AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig,
+        AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state,
             executorService, onProcessCrash(task));
         if (process.isProcessAlive() == false) {
             throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
@@ -275,7 +338,7 @@ public class AnalyticsProcessManager {
          * @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
          */
         private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
-                                                  DataFrameAnalyticsTask task) {
+                                                  DataFrameAnalyticsTask task, @Nullable BytesReference state) {
             if (processKilled) {
                 // The job was stopped before we started the process so no need to start it
                 return false;
@@ -290,7 +353,7 @@ public class AnalyticsProcessManager {
                 LOGGER.info("[{}] no data found to analyze. Will not start analytics native process.", config.getId());
                 return false;
             }
-            process = createProcess(task, config, analyticsProcessConfig);
+            process = createProcess(task, config, analyticsProcessConfig, state);
             DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
                 dataExtractorFactory.newExtractor(true));
             resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker());

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java

@@ -77,6 +77,7 @@ public class MemoryUsageEstimationProcessManager {
             processFactory.createAnalyticsProcess(
                 config,
                 processConfig,
+                null, // memory usage estimation never restores model state
                 executorServiceForProcess,
                 // The handler passed here will never be called as AbstractNativeProcess.detectCrash method returns early when
                 // (processInStream == null) which is the case for MemoryUsageEstimationProcess.

+ 10 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java

@@ -5,9 +5,11 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
+import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -57,4 +59,12 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<Analy
     public AnalyticsProcessConfig getConfig() {
         return config;
     }
+
+    @Override
+    public void restoreState(BytesReference state) throws IOException {
+        Objects.requireNonNull(state);
+        try (OutputStream restoreStream = processRestoreStream()) {
+            StateToProcessWriterHelper.writeStateToStream(state, restoreStream);
+        }
+    }
 }

+ 7 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java

@@ -9,6 +9,8 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -57,11 +59,12 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
 
     @Override
     public NativeAnalyticsProcess createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                         ExecutorService executorService, Consumer<String> onProcessCrash) {
+                                                         @Nullable BytesReference state, ExecutorService executorService,
+                                                         Consumer<String> onProcessCrash) {
         String jobId = config.getId();
         List<Path> filesToDelete = new ArrayList<>();
         ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId,
-                true, false, true, true, false, config.getAnalysis().persistsState());
+                true, false, true, true, state != null, config.getAnalysis().persistsState());
 
         // The extra 2 are for the checksum and the control field
         int numberOfFields = analyticsProcessConfig.cols() + 2;
@@ -69,8 +72,8 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
         createNativeProcess(jobId, analyticsProcessConfig, filesToDelete, processPipes);
 
         NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, nativeController, processPipes.getLogStream().get(),
-                processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, numberOfFields,
-                filesToDelete, onProcessCrash, analyticsProcessConfig);
+                processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(),
+                processPipes.getRestoreStream().orElse(null), numberOfFields, filesToDelete, onProcessCrash, analyticsProcessConfig);
 
         try {
             startProcess(config, executorService, processPipes, analyticsProcess);

+ 6 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.process.NativeController;
 
@@ -30,4 +31,9 @@ public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsP
     public AnalyticsProcessConfig getConfig() {
         throw new UnsupportedOperationException();
     }
+
+    @Override
+    public void restoreState(BytesReference state) {
+        throw new UnsupportedOperationException();
+    }
 }

+ 3 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java

@@ -8,6 +8,8 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -55,6 +57,7 @@ public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProce
     public NativeMemoryUsageEstimationProcess createAnalyticsProcess(
             DataFrameAnalyticsConfig config,
             AnalyticsProcessConfig analyticsProcessConfig,
+            @Nullable BytesReference state,
             ExecutorService executorService,
             Consumer<String> onProcessCrash) {
         List<Path> filesToDelete = new ArrayList<>();

+ 2 - 19
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamer.java

@@ -7,8 +7,6 @@ package org.elasticsearch.xpack.ml.job.persistence;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.BytesRefIterator;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -17,6 +15,7 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.CategorizerState;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
+import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper;
 
 import java.io.IOException;
 import java.io.OutputStream;
@@ -115,22 +114,6 @@ public class StateStreamer {
             return;
         }
 
-        // The source bytes are already UTF-8.  The C++ process wants UTF-8, so we
-        // can avoid converting to a Java String only to convert back again.
-        BytesRefIterator iterator = source.iterator();
-        for (BytesRef ref = iterator.next(); ref != null; ref = iterator.next()) {
-            // There's a complication that the source can already have trailing 0 bytes
-            int length = ref.bytes.length;
-            while (length > 0 && ref.bytes[length - 1] == 0) {
-                --length;
-            }
-            if (length > 0) {
-                stream.write(ref.bytes, 0, length);
-            }
-        }
-        // This is dictated by RapidJSON on the C++ side; it treats a '\0' as end-of-file
-        // even when it's not really end-of-file, and this is what we need because we're
-        // sending multiple JSON documents via the same named pipe.
-        stream.write(0);
+        StateToProcessWriterHelper.writeStateToStream(source, stream);
     }
 }

+ 41 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/StateToProcessWriterHelper.java

@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.process;
+
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefIterator;
+import org.elasticsearch.common.bytes.BytesReference;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * A helper class for writing state to a native process
+ */
+public final class StateToProcessWriterHelper {
+
+    private StateToProcessWriterHelper() {}
+
+    public static void writeStateToStream(BytesReference source, OutputStream stream) throws IOException {
+        // The source bytes are already UTF-8.  The C++ process wants UTF-8, so we
+        // can avoid converting to a Java String only to convert back again.
+        BytesRefIterator iterator = source.iterator();
+        for (BytesRef ref = iterator.next(); ref != null; ref = iterator.next()) {
+            // There's a complication that the source can already have trailing 0 bytes
+            int length = ref.bytes.length;
+            while (length > 0 && ref.bytes[length - 1] == 0) {
+                --length;
+            }
+            if (length > 0) {
+                stream.write(ref.bytes, 0, length);
+            }
+        }
+        // This is dictated by RapidJSON on the C++ side; it treats a '\0' as end-of-file
+        // even when it's not really end-of-file, and this is what we need because we're
+        // sending multiple JSON documents via the same named pipe.
+        stream.write(0);
+    }
+}
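A hedged usage sketch of the helper's contract (trailing zero padding is trimmed and exactly one '\0' terminator is appended); the input bytes here are invented for illustration:

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;

    import org.elasticsearch.common.bytes.BytesArray;
    import org.elasticsearch.xpack.ml.process.StateToProcessWriterHelper;

    public class StateWriterExample {
        public static void main(String[] args) throws IOException {
            // A stored state doc may carry trailing zero padding: "{}" followed by two 0 bytes.
            byte[] padded = new byte[] { '{', '}', 0, 0 };
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            StateToProcessWriterHelper.writeStateToStream(new BytesArray(padded), out);
            // The padding is trimmed and a single '\0' terminator is written,
            // so the stream now holds: '{', '}', 0
            System.out.println(out.size()); // 3
        }
    }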

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
@@ -75,7 +76,7 @@ public class TransportStopDataFrameAnalyticsActionTests extends ESTestCase {
     private static void addAnalyticsTask(PersistentTasksCustomMetaData.Builder builder, String analyticsId, String nodeId,
                                          DataFrameAnalyticsState state) {
         builder.addTask(MlTasks.dataFrameAnalyticsTaskId(analyticsId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
-            new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT),
+            new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, Collections.emptyList()),
             new PersistentTasksCustomMetaData.Assignment(nodeId, "test assignment"));
 
         builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(analyticsId),

+ 90 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java

@@ -0,0 +1,90 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.dataframe;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
+import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class DataFrameAnalyticsTaskTests extends ESTestCase {
+
+    public void testDetermineStartingState_GivenZeroProgress() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 0),
+            new PhaseProgress("loading_data", 0),
+            new PhaseProgress("analyzing", 0),
+            new PhaseProgress("writing_results", 0));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.FIRST_TIME));
+    }
+
+    public void testDetermineStartingState_GivenReindexingIsIncomplete() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 99),
+            new PhaseProgress("loading_data", 0),
+            new PhaseProgress("analyzing", 0),
+            new PhaseProgress("writing_results", 0));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.RESUMING_REINDEXING));
+    }
+
+    public void testDetermineStartingState_GivenLoadingDataIsIncomplete() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 1),
+            new PhaseProgress("analyzing", 0),
+            new PhaseProgress("writing_results", 0));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.RESUMING_ANALYZING));
+    }
+
+    public void testDetermineStartingState_GivenAnalyzingIsIncomplete() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 100),
+            new PhaseProgress("analyzing", 99),
+            new PhaseProgress("writing_results", 0));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.RESUMING_ANALYZING));
+    }
+
+    public void testDetermineStartingState_GivenWritingResultsIsIncomplete() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 100),
+            new PhaseProgress("analyzing", 100),
+            new PhaseProgress("writing_results", 1));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.RESUMING_ANALYZING));
+    }
+
+    public void testDetermineStartingState_GivenFinished() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 100),
+            new PhaseProgress("analyzing", 100),
+            new PhaseProgress("writing_results", 100));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.FINISHED));
+    }
+
+    public void testDetermineStartingState_GivenEmptyProgress() {
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", Collections.emptyList());
+        assertThat(startingState, equalTo(StartingState.FINISHED));
+    }
+}

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java

@@ -67,7 +67,7 @@ public class MemoryUsageEstimationProcessManagerTests extends ESTestCase {
         process = mock(AnalyticsProcess.class);
         when(process.readAnalyticsResults()).thenReturn(List.of(PROCESS_RESULT).iterator());
         processFactory = mock(AnalyticsProcessFactory.class);
-        when(processFactory.createAnalyticsProcess(any(), any(), any(), any())).thenReturn(process);
+        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
         dataExtractor = mock(DataFrameDataExtractor.class);
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
         dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java

@@ -562,7 +562,7 @@ public class JobNodeSelectorTests extends ESTestCase {
     static void addDataFrameAnalyticsJobTask(String id, String nodeId, DataFrameAnalyticsState state,
                                              PersistentTasksCustomMetaData.Builder builder, boolean isStale) {
         builder.addTask(MlTasks.dataFrameAnalyticsTaskId(id), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
-            new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT),
+            new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, Collections.emptyList()),
             new PersistentTasksCustomMetaData.Assignment(nodeId, "test assignment"));
         if (state != null) {
             builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(id),

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

@@ -263,7 +263,7 @@ public class MlMemoryTrackerTests extends ESTestCase {
     private
     PersistentTasksCustomMetaData.PersistentTask<StartDataFrameAnalyticsAction.TaskParams> makeTestDataFrameAnalyticsTask(String id) {
         return new PersistentTasksCustomMetaData.PersistentTask<>(MlTasks.dataFrameAnalyticsTaskId(id),
-            MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT), 0,
-            PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT);
+            MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT,
+            Collections.emptyList()), 0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT);
     }
 }