Browse Source

[ML] Improve resuming a DFA job stopped during inference (#67623)

If a DFA job is stopped while in the inference phase, after
resuming we should start inference immediately. However, this
is currently not the case. Inference is tied into `AnalyticsProcessManager`
and thus we start a process, load data, restore state, etc., before
we get to start inference.

This commit gets rid of this unnecessary delay by factoring inference
out as an independent step and ensuring we can resume straight from
that phase upon restarting a job.
Dimitris Athanasiou 4 years ago
parent
commit
f449b8f19c
17 changed files with 479 additions and 91 deletions
  1. 3 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  2. 1 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
  3. 55 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
  4. 4 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java
  5. 5 5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java
  6. 1 71
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java
  7. 29 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java
  8. 17 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java
  9. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java
  10. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java
  11. 116 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java
  12. 127 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java
  13. 12 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java
  14. 1 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java
  15. 9 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java
  16. 71 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java
  17. 26 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java

+ 3 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -764,7 +764,6 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
             analyticsProcessFactory,
             dataFrameAnalyticsAuditor,
             trainedModelProvider,
-            modelLoadingService,
             resultsPersisterService,
             EsExecutors.allocatedProcessors(settings));
         MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
@@ -773,8 +772,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client, xContentRegistry,
             dataFrameAnalyticsAuditor);
         assert client instanceof NodeClient;
-        DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client, clusterService,
-            dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor, indexNameExpressionResolver);
+        DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(settings, (NodeClient) client, threadPool,
+            clusterService, dataFrameAnalyticsConfigProvider, analyticsProcessManager, dataFrameAnalyticsAuditor,
+            indexNameExpressionResolver, resultsPersisterService, modelLoadingService);
         this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager);
 
         // Components shared by anomaly detection and data frame analytics

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java

@@ -273,6 +273,7 @@ public class TransportStartDataFrameAnalyticsAction
                         break;
                     case RESUMING_REINDEXING:
                     case RESUMING_ANALYZING:
+                    case RESUMING_INFERENCE:
                         toValidateMappingsListener.onResponse(startContext);
                         break;
                     case FINISHED:

+ 55 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

@@ -19,20 +19,30 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.MappingMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexNotFoundException;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
+import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
+import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
 import org.elasticsearch.xpack.ml.dataframe.steps.AnalysisStep;
 import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep;
+import org.elasticsearch.xpack.ml.dataframe.steps.FinalStep;
+import org.elasticsearch.xpack.ml.dataframe.steps.InferenceStep;
 import org.elasticsearch.xpack.ml.dataframe.steps.ReindexingStep;
 import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse;
+import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -43,27 +53,36 @@ public class DataFrameAnalyticsManager {
 
     private static final Logger LOGGER = LogManager.getLogger(DataFrameAnalyticsManager.class);
 
+    private final Settings settings;
     /**
      * We need a {@link NodeClient} to get the reindexing task and be able to report progress
      */
     private final NodeClient client;
+    private final ThreadPool threadPool;
     private final ClusterService clusterService;
     private final DataFrameAnalyticsConfigProvider configProvider;
     private final AnalyticsProcessManager processManager;
     private final DataFrameAnalyticsAuditor auditor;
     private final IndexNameExpressionResolver expressionResolver;
+    private final ResultsPersisterService resultsPersisterService;
+    private final ModelLoadingService modelLoadingService;
     /** Indicates whether the node is shutting down. */
     private final AtomicBoolean nodeShuttingDown = new AtomicBoolean();
 
-    public DataFrameAnalyticsManager(NodeClient client, ClusterService clusterService, DataFrameAnalyticsConfigProvider configProvider,
-                                     AnalyticsProcessManager processManager, DataFrameAnalyticsAuditor auditor,
-                                     IndexNameExpressionResolver expressionResolver) {
+    public DataFrameAnalyticsManager(Settings settings, NodeClient client, ThreadPool threadPool, ClusterService clusterService,
+                                     DataFrameAnalyticsConfigProvider configProvider, AnalyticsProcessManager processManager,
+                                     DataFrameAnalyticsAuditor auditor, IndexNameExpressionResolver expressionResolver,
+                                     ResultsPersisterService resultsPersisterService, ModelLoadingService modelLoadingService) {
+        this.settings = Objects.requireNonNull(settings);
         this.client = Objects.requireNonNull(client);
+        this.threadPool = Objects.requireNonNull(threadPool);
         this.clusterService = Objects.requireNonNull(clusterService);
         this.configProvider = Objects.requireNonNull(configProvider);
         this.processManager = Objects.requireNonNull(processManager);
         this.auditor = Objects.requireNonNull(auditor);
         this.expressionResolver = Objects.requireNonNull(expressionResolver);
+        this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
+        this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
     }
 
     public void execute(DataFrameAnalyticsTask task, ClusterState clusterState) {
@@ -141,6 +160,12 @@ public class DataFrameAnalyticsManager {
             case RESUMING_ANALYZING:
                 executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
                 break;
+            case RESUMING_INFERENCE:
+                buildInferenceStep(task, config, ActionListener.wrap(
+                    inferenceStep -> executeStep(task, config, inferenceStep),
+                    task::setFailed
+                ));
+                break;
             case FINISHED:
             default:
                 task.setFailed(ExceptionsHelper.serverError("Unexpected starting state [" + startingState + "]"));
@@ -162,7 +187,15 @@ public class DataFrameAnalyticsManager {
                         executeStep(task, config, new AnalysisStep(client, task, auditor, config, processManager));
                         break;
                     case ANALYSIS:
-                        // This is the last step
+                        buildInferenceStep(task, config, ActionListener.wrap(
+                            inferenceStep -> executeStep(task, config, inferenceStep),
+                            task::setFailed
+                        ));
+                        break;
+                    case INFERENCE:
+                        executeStep(task, config, new FinalStep(client, task, auditor, config));
+                        break;
+                    case FINAL:
                         LOGGER.info("[{}] Marking task completed", config.getId());
                         task.markAsCompleted();
                         break;
@@ -199,6 +232,24 @@ public class DataFrameAnalyticsManager {
             ));
     }
 
+    private void buildInferenceStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, ActionListener<InferenceStep> listener) {
+        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
+
+        ActionListener<ExtractedFieldsDetector> extractedFieldsDetectorListener = ActionListener.wrap(
+            extractedFieldsDetector -> {
+                ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
+                InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService,
+                    resultsPersisterService, task.getParentTaskId(), config, extractedFields, task.getStatsHolder().getProgressTracker(),
+                    task.getStatsHolder().getDataCountsTracker());
+                InferenceStep inferenceStep = new InferenceStep(client, task, auditor, config, threadPool, inferenceRunner);
+                listener.onResponse(inferenceStep);
+            },
+            listener::onFailure
+        );
+
+        new ExtractedFieldsDetectorFactory(parentTaskClient).createFromDest(config, extractedFieldsDetectorListener);
+    }
+
     public boolean isNodeShuttingDown() {
         return nodeShuttingDown.get();
     }

+ 4 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

@@ -287,7 +287,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
      * {@code FINISHED} means the job had finished.
      */
     public enum StartingState {
-        FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, FINISHED
+        FIRST_TIME, RESUMING_REINDEXING, RESUMING_ANALYZING, RESUMING_INFERENCE, FINISHED
     }
 
     public StartingState determineStartingState() {
@@ -313,6 +313,9 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
         if (ProgressTracker.REINDEXING.equals(lastIncompletePhase.getPhase())) {
             return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
         }
+        if (ProgressTracker.INFERENCE.equals(lastIncompletePhase.getPhase())) {
+            return StartingState.RESUMING_INFERENCE;
+        }
         return StartingState.RESUMING_ANALYZING;
     }
 }

+ 5 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java

@@ -99,7 +99,7 @@ public class DataFrameDataExtractor {
     }
 
     public void cancel() {
-        LOGGER.debug("[{}] Data extractor was cancelled", context.jobId);
+        LOGGER.debug(() -> new ParameterizedMessage("[{}] Data extractor was cancelled", context.jobId));
         isCancelled = true;
     }
 
@@ -127,7 +127,7 @@ public class DataFrameDataExtractor {
             // We've set allow_partial_search_results to false which means if something
             // goes wrong the request will throw.
             SearchResponse searchResponse = request.get();
-            LOGGER.debug("[{}] Search response was obtained", context.jobId);
+            LOGGER.trace(() -> new ParameterizedMessage("[{}] Search response was obtained", context.jobId));
 
             List<Row> rows = processSearchResponse(searchResponse);
 
@@ -153,7 +153,7 @@ public class DataFrameDataExtractor {
         long from = lastSortKey + 1;
         long to = from + context.scrollSize;
 
-        LOGGER.debug(() -> new ParameterizedMessage(
+        LOGGER.trace(() -> new ParameterizedMessage(
             "[{}] Searching docs with [{}] in [{}, {})", context.jobId, DestinationIndex.INCREMENTAL_ID, from, to));
 
         SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
@@ -283,7 +283,7 @@ public class DataFrameDataExtractor {
         }
         boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
         Row row = new Row(extractedValues, hit, isTraining);
-        LOGGER.debug(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
+        LOGGER.trace(() -> new ParameterizedMessage("[{}] Extracted row: sort key = [{}], is_training = [{}], values = {}",
             context.jobId, row.getSortKey(), isTraining, Arrays.toString(row.values)));
         return row;
     }
@@ -306,7 +306,7 @@ public class DataFrameDataExtractor {
         SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
         SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
         long rows = searchResponse.getHits().getTotalHits().value;
-        LOGGER.debug("[{}] Data summary rows [{}]", context.jobId, rows);
+        LOGGER.debug(() -> new ParameterizedMessage("[{}] Data summary rows [{}]", context.jobId, rows));
         return new DataSummary(rows, organicFeatures.length + processedFeatures.length);
     }
 

+ 1 - 71
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

@@ -10,21 +10,14 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
-import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
 import org.elasticsearch.action.search.SearchResponse;
-import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.xpack.core.ClientHelper;
-import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
-import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -32,20 +25,17 @@ import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
-import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;
 import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
 import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
 import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
-import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
 import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -69,7 +59,6 @@ public class AnalyticsProcessManager {
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
     private final DataFrameAnalyticsAuditor auditor;
     private final TrainedModelProvider trainedModelProvider;
-    private final ModelLoadingService modelLoadingService;
     private final ResultsPersisterService resultsPersisterService;
     private final int numAllocatedProcessors;
 
@@ -79,7 +68,6 @@ public class AnalyticsProcessManager {
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
                                    TrainedModelProvider trainedModelProvider,
-                                   ModelLoadingService modelLoadingService,
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
         this(
@@ -90,7 +78,6 @@ public class AnalyticsProcessManager {
             analyticsProcessFactory,
             auditor,
             trainedModelProvider,
-            modelLoadingService,
             resultsPersisterService,
             numAllocatedProcessors);
     }
@@ -103,7 +90,6 @@ public class AnalyticsProcessManager {
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
                                    TrainedModelProvider trainedModelProvider,
-                                   ModelLoadingService modelLoadingService,
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
         this.settings = Objects.requireNonNull(settings);
@@ -113,7 +99,6 @@ public class AnalyticsProcessManager {
         this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
         this.auditor = Objects.requireNonNull(auditor);
         this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
-        this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
         this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
         this.numAllocatedProcessors = numAllocatedProcessors;
     }
@@ -183,7 +168,6 @@ public class AnalyticsProcessManager {
         LOGGER.info("[{}] Started loading data", processContext.config.getId());
         auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_LOADING_DATA));
 
-        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
         DataFrameAnalyticsConfig config = processContext.config;
         DataFrameDataExtractor dataExtractor = processContext.dataExtractor.get();
         AnalyticsProcess<AnalyticsResult> process = processContext.process.get();
@@ -203,14 +187,6 @@ public class AnalyticsProcessManager {
             resultProcessor.awaitForCompletion();
             processContext.setFailureReason(resultProcessor.getFailure());
             LOGGER.info("[{}] Result processor has completed", config.getId());
-
-            runInference(parentTaskClient, task, processContext, dataExtractor.getExtractedFields());
-
-            processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
-                DataCounts::documentId);
-
-            refreshDest(parentTaskClient, config);
-            refreshIndices(parentTaskClient, config.getId());
         } catch (Exception e) {
             if (task.isStopping()) {
                 // Errors during task stopping are expected but we still want to log them just in case.
@@ -338,43 +314,6 @@ public class AnalyticsProcessManager {
         };
     }
 
-    private void runInference(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsTask task, ProcessContext processContext,
-                              ExtractedFields extractedFields) {
-        if (task.isStopping() || processContext.failureReason.get() != null) {
-            // If the task is stopping or there has been an error thus far let's not run inference at all
-            return;
-        }
-
-        if (processContext.config.getAnalysis().supportsInference()) {
-            refreshDest(parentTaskClient, processContext.config);
-            InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, resultsPersisterService,
-                task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(),
-                task.getStatsHolder().getDataCountsTracker());
-            processContext.setInferenceRunner(inferenceRunner);
-            inferenceRunner.run(processContext.resultProcessor.get().getLatestModelId());
-        }
-    }
-
-    private void refreshDest(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsConfig config) {
-        ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, parentTaskClient,
-            () -> parentTaskClient.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet());
-    }
-
-    private void refreshIndices(ParentTaskAssigningClient parentTaskClient, String jobId) {
-        RefreshRequest refreshRequest = new RefreshRequest(
-            AnomalyDetectorsIndex.jobStateIndexPattern(),
-            MlStatsIndex.indexPattern()
-        );
-        refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
-
-        LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
-            jobId, Arrays.toString(refreshRequest.indices())));
-
-        try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            parentTaskClient.admin().indices().refresh(refreshRequest).actionGet();
-        }
-    }
-
     private void closeProcess(DataFrameAnalyticsTask task) {
         String configId = task.getParams().getId();
         LOGGER.info("[{}] Closing process", configId);
@@ -415,13 +354,10 @@ public class AnalyticsProcessManager {
         private final SetOnce<AnalyticsProcess<AnalyticsResult>> process = new SetOnce<>();
         private final SetOnce<DataFrameDataExtractor> dataExtractor = new SetOnce<>();
         private final SetOnce<AnalyticsResultProcessor> resultProcessor = new SetOnce<>();
-        private final SetOnce<InferenceRunner> inferenceRunner = new SetOnce<>();
         private final SetOnce<String> failureReason = new SetOnce<>();
-        private final StatsPersister statsPersister;
 
         ProcessContext(DataFrameAnalyticsConfig config) {
             this.config = Objects.requireNonNull(config);
-            this.statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor);
         }
 
         String getFailureReason() {
@@ -436,10 +372,6 @@ public class AnalyticsProcessManager {
             this.failureReason.trySet(failureReason);
         }
 
-        void setInferenceRunner(InferenceRunner inferenceRunner) {
-            this.inferenceRunner.set(inferenceRunner);
-        }
-
         synchronized void stop() {
             LOGGER.debug("[{}] Stopping process", config.getId());
             if (dataExtractor.get() != null) {
@@ -448,9 +380,6 @@ public class AnalyticsProcessManager {
             if (resultProcessor.get() != null) {
                 resultProcessor.get().cancel();
             }
-            if (inferenceRunner.get() != null) {
-                inferenceRunner.get().cancel();
-            }
             if (process.get() != null) {
                 try {
                     process.get().kill(true);
@@ -507,6 +436,7 @@ public class AnalyticsProcessManager {
             DataFrameRowsJoiner dataFrameRowsJoiner =
                 new DataFrameRowsJoiner(config.getId(), settings, task.getParentTaskId(),
                         dataExtractorFactory.newExtractor(true), resultsPersisterService);
+            StatsPersister statsPersister = new StatsPersister(config.getId(), resultsPersisterService, auditor);
             return new AnalyticsResultProcessor(
                 config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister,
                 dataExtractor.get().getExtractedFields());

+ 29 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java

@@ -54,6 +54,9 @@ public class ProgressTracker {
         assert progressPercentPerPhase.containsKey(REINDEXING);
         assert progressPercentPerPhase.containsKey(LOADING_DATA);
         assert progressPercentPerPhase.containsKey(WRITING_RESULTS);
+        // If there is inference it should be the last phase otherwise there
+        // are assumptions that do not hold.
+        assert progressPercentPerPhase.containsKey(INFERENCE) == false || INFERENCE.equals(phasesInOrder[phasesInOrder.length - 1]);
     }
 
     public void updateReindexingProgress(int progressPercent) {
@@ -96,6 +99,32 @@ public class ProgressTracker {
         progressPercentPerPhase.computeIfPresent(phase, (k, v) -> Math.max(v, progress));
     }
 
+    /**
+     * Resets progress to reflect all phases are complete except for inference
+     * which is set to zero.
+     */
+    public void resetForInference() {
+        for (Map.Entry<String, Integer> phaseProgress : progressPercentPerPhase.entrySet()) {
+            if (phaseProgress.getKey().equals(INFERENCE)) {
+                progressPercentPerPhase.put(phaseProgress.getKey(), 0);
+            } else {
+                progressPercentPerPhase.put(phaseProgress.getKey(), 100);
+            }
+        }
+    }
+
+    /**
+     * Returns whether all phases before inference are complete
+     */
+    public boolean areAllPhasesExceptInferenceComplete() {
+        for (Map.Entry<String, Integer> phaseProgress : progressPercentPerPhase.entrySet()) {
+            if (phaseProgress.getKey().equals(INFERENCE) == false && phaseProgress.getValue() < 100) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     public List<PhaseProgress> report() {
         return Arrays.stream(phasesInOrder)
             .map(phase -> new PhaseProgress(phase, progressPercentPerPhase.get(phase)))

+ 17 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java

@@ -34,15 +34,30 @@ public class StatsHolder {
         progressTracker = new ProgressTracker(progress);
     }
 
+    /**
+     * Updates the progress tracker with potentially new in-between phases
+     * that were introduced in a later version while making sure progress indicators
+     * are correct.
+     * @param analysisPhases the full set of phases of the analysis in current version
+     * @param hasInferencePhase whether the analysis supports inference
+     */
     public void adjustProgressTracker(List<String> analysisPhases, boolean hasInferencePhase) {
         int reindexingProgressPercent = progressTracker.getReindexingProgressPercent();
+        boolean areAllPhasesBeforeInferenceComplete = progressTracker.areAllPhasesExceptInferenceComplete();
         progressTracker = ProgressTracker.fromZeroes(analysisPhases, hasInferencePhase);
 
         // If reindexing progress was more than 0 and less than 100 (ie not complete) we reset it to 1
         // as we will have to do reindexing from scratch and at the same time we want
         // to differentiate from a job that has never started before.
-        progressTracker.updateReindexingProgress(
-            (reindexingProgressPercent > 0 && reindexingProgressPercent < 100) ? 1 : reindexingProgressPercent);
+        if (reindexingProgressPercent > 0 && reindexingProgressPercent < 100) {
+            progressTracker.updateReindexingProgress(1);
+        } else {
+            progressTracker.updateReindexingProgress(reindexingProgressPercent);
+        }
+
+        if (hasInferencePhase && areAllPhasesBeforeInferenceComplete) {
+            progressTracker.resetForInference();
+        }
     }
 
     public void resetProgressTracker(List<String> analysisPhases, boolean hasInferencePhase) {

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java

@@ -58,7 +58,7 @@ abstract class AbstractDataFrameAnalyticsStep implements DataFrameAnalyticsStep
     public final void execute(ActionListener<StepResponse> listener) {
         logger.debug(() -> new ParameterizedMessage("[{}] Executing step [{}]", config.getId(), name()));
         if (task.isStopping()) {
-            logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}]", config.getId(), name()));
+            logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}] step", config.getId(), name()));
             listener.onResponse(new StepResponse(true));
             return;
         }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/DataFrameAnalyticsStep.java

@@ -14,7 +14,7 @@ import java.util.Locale;
 public interface DataFrameAnalyticsStep {
 
     enum Name {
-        REINDEXING, ANALYSIS;
+        REINDEXING, ANALYSIS, INFERENCE, FINAL;
 
         @Override
         public String toString() {

+ 116 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java

@@ -0,0 +1,116 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.steps;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
+import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.client.ParentTaskAssigningClient;
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.xpack.core.ml.MlStatsIndex;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
+import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
+import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+
+/**
+ * The final step of a data frame analytics job.
+ * Allows the job to perform finalizing tasks like refresh indices,
+ * persist stats, etc.
+ */
+public class FinalStep extends AbstractDataFrameAnalyticsStep {
+
+    private static final Logger LOGGER = LogManager.getLogger(FinalStep.class);
+
+    public FinalStep(NodeClient client, DataFrameAnalyticsTask task, DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsConfig config) {
+        super(client, task, auditor, config);
+    }
+
+    @Override
+    public Name name() {
+        return Name.FINAL;
+    }
+
+    @Override
+    protected void doExecute(ActionListener<StepResponse> listener) {
+
+        ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
+            refreshResponse -> listener.onResponse(new StepResponse(true)),
+            listener::onFailure
+        );
+
+        ActionListener<IndexResponse> dataCountsIndexedListener = ActionListener.wrap(
+            indexResponse -> refreshIndices(refreshListener),
+            listener::onFailure
+        );
+
+        indexDataCounts(dataCountsIndexedListener);
+    }
+
+    private void indexDataCounts(ActionListener<IndexResponse> listener) {
+        DataCounts dataCounts = task.getStatsHolder().getDataCountsTracker().report(config.getId());
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            dataCounts.toXContent(builder, new ToXContent.MapParams(
+                Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
+            IndexRequest indexRequest = new IndexRequest(MlStatsIndex.writeAlias())
+                .id(DataCounts.documentId(config.getId()))
+                .setRequireAlias(true)
+                .source(builder);
+            parentTaskClient().index(indexRequest, listener);
+        } catch (IOException e) {
+            listener.onFailure(ExceptionsHelper.serverError("[{}] Error persisting final data counts", e, config.getId()));
+        }
+    }
+
+    private void refreshIndices(ActionListener<RefreshResponse> listener) {
+        RefreshRequest refreshRequest = new RefreshRequest(
+            AnomalyDetectorsIndex.jobStateIndexPattern(),
+            MlStatsIndex.indexPattern(),
+            config.getDest().getIndex()
+        );
+        refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
+
+        LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}", config.getId(),
+            Arrays.toString(refreshRequest.indices())));
+
+        ParentTaskAssigningClient parentTaskClient = parentTaskClient();
+        try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
+            parentTaskClient.admin().indices().refresh(refreshRequest, listener);
+        }
+    }
+
+    @Override
+    public void cancel(String reason, TimeValue timeout) {
+        // Not cancellable
+    }
+
+    @Override
+    public void updateProgress(ActionListener<Void> listener) {
+        // No progress to update
+        listener.onResponse(null);
+    }
+}

+ 127 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java

@@ -0,0 +1,127 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.steps;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
+import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
+public class InferenceStep extends AbstractDataFrameAnalyticsStep {
+
+    private static final Logger LOGGER = LogManager.getLogger(InferenceStep.class);
+
+    private final ThreadPool threadPool;
+    private final InferenceRunner inferenceRunner;
+
+    public InferenceStep(NodeClient client, DataFrameAnalyticsTask task, DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsConfig config,
+                         ThreadPool threadPool, InferenceRunner inferenceRunner) {
+        super(client, task, auditor, config);
+        this.threadPool = Objects.requireNonNull(threadPool);
+        this.inferenceRunner = Objects.requireNonNull(inferenceRunner);
+    }
+
+    @Override
+    public Name name() {
+        return Name.INFERENCE;
+    }
+
+    @Override
+    protected void doExecute(ActionListener<StepResponse> listener) {
+        if (config.getAnalysis().supportsInference() == false) {
+            LOGGER.debug(() -> new ParameterizedMessage(
+                "[{}] Inference step completed immediately as analysis does not support inference", config.getId()));
+            listener.onResponse(new StepResponse(false));
+            return;
+        }
+
+        ActionListener<String> modelIdListener = ActionListener.wrap(
+            modelId -> runInference(modelId, listener),
+            listener::onFailure
+        );
+
+        ActionListener<RefreshResponse> refreshDestListener = ActionListener.wrap(
+            refreshResponse -> getModelId(modelIdListener),
+            listener::onFailure
+        );
+
+        refreshDestAsync(refreshDestListener);
+    }
+
+    private void runInference(String modelId, ActionListener<StepResponse> listener) {
+        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
+            try {
+                inferenceRunner.run(modelId);
+                listener.onResponse(new StepResponse(isTaskStopping()));
+            } catch (Exception e) {
+                if (task.isStopping()) {
+                    listener.onResponse(new StepResponse(false));
+                } else {
+                    listener.onFailure(e);
+                }
+            }
+        });
+    }
+
+    private void getModelId(ActionListener<String> listener) {
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.size(1);
+        searchSourceBuilder.fetchSource(false);
+        searchSourceBuilder.query(QueryBuilders.boolQuery()
+            .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), config.getId()))
+        );
+        searchSourceBuilder.sort(TrainedModelConfig.CREATE_TIME.getPreferredName(), SortOrder.DESC);
+        SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN);
+        searchRequest.source(searchSourceBuilder);
+
+        executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
+            searchResponse -> {
+                SearchHit[] hits = searchResponse.getHits().getHits();
+                if (hits.length == 0) {
+                    listener.onFailure(new ResourceNotFoundException("No model could be found to perform inference"));
+                } else {
+                    listener.onResponse(hits[0].getId());
+                }
+            },
+            listener::onFailure
+        ));
+    }
+
+    @Override
+    public void cancel(String reason, TimeValue timeout) {
+        inferenceRunner.cancel();
+    }
+
+    @Override
+    public void updateProgress(ActionListener<Void> listener) {
+        // Inference runner updates progress directly
+        listener.onResponse(null);
+    }
+}

+ 12 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java

@@ -118,6 +118,18 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
         assertThat(startingState, equalTo(StartingState.RESUMING_ANALYZING));
     }
 
+    public void testDetermineStartingState_GivenInferenceIsIncomplete() {
+        List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 100),
+            new PhaseProgress("analyzing", 100),
+            new PhaseProgress("writing_results", 100),
+            new PhaseProgress("inference", 40));
+
+        StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", progress);
+
+        assertThat(startingState, equalTo(StartingState.RESUMING_INFERENCE));
+    }
+
     public void testDetermineStartingState_GivenFinished() {
         List<PhaseProgress> progress = Arrays.asList(new PhaseProgress("reindexing", 100),
             new PhaseProgress("loading_data", 100),

+ 1 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java

@@ -113,9 +113,8 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
 
         resultsPersisterService = mock(ResultsPersisterService.class);
-        modelLoadingService = mock(ModelLoadingService.class);
         processManager = new AnalyticsProcessManager(Settings.EMPTY, client, executorServiceForJob, executorServiceForProcess,
-            processFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
+            processFactory, auditor, trainedModelProvider, resultsPersisterService, 1);
     }
 
     public void testRunJob_TaskIsStopping() {

+ 9 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameAnalyticsManagerTests.java

@@ -8,10 +8,14 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
+import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
@@ -21,12 +25,16 @@ public class DataFrameAnalyticsManagerTests extends ESTestCase {
     public void testNodeShuttingDown() {
         DataFrameAnalyticsManager manager =
             new DataFrameAnalyticsManager(
+                Settings.EMPTY,
                 mock(NodeClient.class),
+                mock(ThreadPool.class),
                 mock(ClusterService.class),
                 mock(DataFrameAnalyticsConfigProvider.class),
                 mock(AnalyticsProcessManager.class),
                 mock(DataFrameAnalyticsAuditor.class),
-                mock(IndexNameExpressionResolver.class));
+                mock(IndexNameExpressionResolver.class),
+                mock(ResultsPersisterService.class),
+                mock(ModelLoadingService.class));
         assertThat(manager.isNodeShuttingDown(), is(false));
 
         manager.markNodeAsShuttingDown();

+ 71 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java

@@ -11,7 +11,9 @@ import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.contains;
@@ -148,6 +150,75 @@ public class ProgressTrackerTests extends ESTestCase {
         assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
     }
 
+    public void testResetForInference_GivenInference() {
+        ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b"), true);
+        progressTracker.updateReindexingProgress(10);
+        progressTracker.updateLoadingDataProgress(20);
+        progressTracker.updatePhase(new PhaseProgress("a", 30));
+        progressTracker.updatePhase(new PhaseProgress("b", 40));
+        progressTracker.updateWritingResultsProgress(50);
+        progressTracker.updateInferenceProgress(60);
+
+        progressTracker.resetForInference();
+
+        List<PhaseProgress> progress = progressTracker.report();
+        assertThat(progress, contains(
+            new PhaseProgress(ProgressTracker.REINDEXING, 100),
+            new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
+            new PhaseProgress("a", 100),
+            new PhaseProgress("b", 100),
+            new PhaseProgress(ProgressTracker.WRITING_RESULTS, 100),
+            new PhaseProgress(ProgressTracker.INFERENCE, 0)
+        ));
+    }
+
+    public void testResetForInference_GivenNoInference() {
+        ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b"), false);
+        progressTracker.updateReindexingProgress(10);
+        progressTracker.updateLoadingDataProgress(20);
+        progressTracker.updatePhase(new PhaseProgress("a", 30));
+        progressTracker.updatePhase(new PhaseProgress("b", 40));
+        progressTracker.updateWritingResultsProgress(50);
+
+        progressTracker.resetForInference();
+
+        List<PhaseProgress> progress = progressTracker.report();
+        assertThat(progress, contains(
+            new PhaseProgress(ProgressTracker.REINDEXING, 100),
+            new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
+            new PhaseProgress("a", 100),
+            new PhaseProgress("b", 100),
+            new PhaseProgress(ProgressTracker.WRITING_RESULTS, 100)
+        ));
+    }
+
+    public void testAreAllPhasesExceptInferenceComplete_GivenComplete() {
+        ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("a"), true);
+        progressTracker.updateReindexingProgress(100);
+        progressTracker.updateLoadingDataProgress(100);
+        progressTracker.updatePhase(new PhaseProgress("a", 100));
+        progressTracker.updateWritingResultsProgress(100);
+        progressTracker.updateInferenceProgress(50);
+
+        assertThat(progressTracker.areAllPhasesExceptInferenceComplete(), is(true));
+    }
+
    public void testAreAllPhasesExceptInferenceComplete_GivenNotComplete() {
        // LinkedHashMap keeps a deterministic phase order for building the tracker.
        Map<String, Integer> phasePerProgress = new LinkedHashMap<>();
        phasePerProgress.put(ProgressTracker.REINDEXING, 100);
        phasePerProgress.put(ProgressTracker.LOADING_DATA, 100);
        phasePerProgress.put("a", 100);
        phasePerProgress.put(ProgressTracker.WRITING_RESULTS, 100);
        // Pick the phase to leave incomplete BEFORE inference is added, so the
        // incomplete phase can never be the inference phase itself.
        String nonCompletePhase = randomFrom(phasePerProgress.keySet());
        phasePerProgress.put(ProgressTracker.INFERENCE, 50);
        // Overwriting an existing key keeps its original position in a LinkedHashMap.
        phasePerProgress.put(nonCompletePhase, randomIntBetween(0, 99));

        ProgressTracker progressTracker = new ProgressTracker(phasePerProgress.entrySet().stream()
            .map(entry -> new PhaseProgress(entry.getKey(), entry.getValue())).collect(Collectors.toList()));

        assertThat(progressTracker.areAllPhasesExceptInferenceComplete(), is(false));
    }
+
     private static int getProgressForPhase(ProgressTracker progressTracker, String phase) {
         return progressTracker.report().stream().filter(p -> p.getPhase().equals(phase)).findFirst().get().getProgressPercent();
     }

+ 26 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java

@@ -123,6 +123,32 @@ public class StatsHolderTests extends ESTestCase {
         assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0));
     }
 
+    public void testAdjustProgressTracker_GivenAllPhasesCompleteExceptInference() {
+        List<PhaseProgress> phases = Collections.unmodifiableList(
+            Arrays.asList(
+                new PhaseProgress("reindexing", 100),
+                new PhaseProgress("loading_data", 100),
+                new PhaseProgress("a", 100),
+                new PhaseProgress("writing_results", 100),
+                new PhaseProgress("inference", 20)
+            )
+        );
+        StatsHolder statsHolder = new StatsHolder(phases);
+
+        statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), true);
+
+        List<PhaseProgress> phaseProgresses = statsHolder.getProgressTracker().report();
+
+        assertThat(phaseProgresses, contains(
+            new PhaseProgress("reindexing", 100),
+            new PhaseProgress("loading_data", 100),
+            new PhaseProgress("a", 100),
+            new PhaseProgress("b", 100),
+            new PhaseProgress("writing_results", 100),
+            new PhaseProgress("inference", 0)
+        ));
+    }
+
     public void testResetProgressTracker() {
         List<PhaseProgress> phases = Collections.unmodifiableList(
             Arrays.asList(