Prechádzať zdrojové kódy

[ML] Restore data counts on resuming data frame analytics (#67937)

Now that data frame analytics jobs can be resumed straight into
the inference phase, we need to ensure data counts are persisted
at the end of the analysis step and restored when the job is
started again.

This commit removes the need for storing the progress on start
as a task parameter. Instead, when the task gets assigned we now
restore all stats by making a call to the get stats API. Additionally,
we now ensure that an allocated task that hasn't had its `StatsHolder`
restored yet is treated as a stopped task from the get stats API, which
means we will report the stored stats.

Relates #67623
Dimitris Athanasiou 4 rokov pred
rodič
commit
4af3a18873
25 zmenil súbory, kde vykonal 167 pridanie a 98 odobranie
  1. 12 22
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java
  2. 1 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java
  3. 0 9
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java
  4. 1 1
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java
  5. 12 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java
  6. 10 17
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
  7. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
  8. 7 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java
  9. 21 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java
  10. 7 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java
  11. 5 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java
  12. 2 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AnalysisStep.java
  13. 7 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java
  14. 2 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java
  15. 5 5
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java
  16. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java
  17. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java
  18. 6 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java
  19. 2 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java
  20. 11 5
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java
  21. 6 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java
  22. 31 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTrackerTests.java
  23. 11 6
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java
  24. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java
  25. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

+ 12 - 22
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java

@@ -9,7 +9,6 @@ import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.MasterNodeRequest;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -30,7 +29,6 @@ import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 
@@ -147,17 +145,13 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
         public static final Version VERSION_INTRODUCED = Version.V_7_3_0;
         public static final Version VERSION_DESTINATION_INDEX_MAPPINGS_CHANGED = Version.V_7_10_0;
 
-        private static final ParseField PROGRESS_ON_START = new ParseField("progress_on_start");
-
-        @SuppressWarnings("unchecked")
         public static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
             MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true,
-            a -> new TaskParams((String) a[0], (String) a[1], (List<PhaseProgress>) a[2], (Boolean) a[3]));
+            a -> new TaskParams((String) a[0], (String) a[1], (Boolean) a[2]));
 
         static {
             PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID);
             PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.VERSION);
-            PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS_ON_START);
             PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DataFrameAnalyticsConfig.ALLOW_LAZY_START);
         }
 
@@ -167,25 +161,24 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
 
         private final String id;
         private final Version version;
-        private final List<PhaseProgress> progressOnStart;
         private final boolean allowLazyStart;
 
-        public TaskParams(String id, Version version, List<PhaseProgress> progressOnStart, boolean allowLazyStart) {
+        public TaskParams(String id, Version version, boolean allowLazyStart) {
             this.id = Objects.requireNonNull(id);
             this.version = Objects.requireNonNull(version);
-            this.progressOnStart = Collections.unmodifiableList(progressOnStart);
             this.allowLazyStart = allowLazyStart;
         }
 
-        private TaskParams(String id, String version, @Nullable List<PhaseProgress> progressOnStart, Boolean allowLazyStart) {
-            this(id, Version.fromString(version), progressOnStart == null ? Collections.emptyList() : progressOnStart,
-                allowLazyStart != null && allowLazyStart);
+        private TaskParams(String id, String version, Boolean allowLazyStart) {
+            this(id, Version.fromString(version), allowLazyStart != null && allowLazyStart);
         }
 
         public TaskParams(StreamInput in) throws IOException {
             this.id = in.readString();
             this.version = Version.readVersion(in);
-            this.progressOnStart = in.readList(PhaseProgress::new);
+            if (in.getVersion().before(Version.V_8_0_0)) {
+                in.readList(PhaseProgress::new);
+            }
             this.allowLazyStart = in.readBoolean();
         }
 
@@ -197,10 +190,6 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
             return version;
         }
 
-        public List<PhaseProgress> getProgressOnStart() {
-            return progressOnStart;
-        }
-
         public boolean isAllowLazyStart() {
             return allowLazyStart;
         }
@@ -219,7 +208,10 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(id);
             Version.writeVersion(version, out);
-            out.writeList(progressOnStart);
+            if (out.getVersion().before(Version.V_8_0_0)) {
+                // Previous versions expect a list of phase progress objects.
+                out.writeList(Collections.emptyList());
+            }
             out.writeBoolean(allowLazyStart);
         }
 
@@ -228,7 +220,6 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
             builder.startObject();
             builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id);
             builder.field(DataFrameAnalyticsConfig.VERSION.getPreferredName(), version);
-            builder.field(PROGRESS_ON_START.getPreferredName(), progressOnStart);
             builder.field(DataFrameAnalyticsConfig.ALLOW_LAZY_START.getPreferredName(), allowLazyStart);
             builder.endObject();
             return builder;
@@ -236,7 +227,7 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
 
         @Override
         public int hashCode() {
-            return Objects.hash(id, version, progressOnStart, allowLazyStart);
+            return Objects.hash(id, version, allowLazyStart);
         }
 
         @Override
@@ -247,7 +238,6 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
             TaskParams other = (TaskParams) o;
             return Objects.equals(id, other.id)
                 && Objects.equals(version, other.version)
-                && Objects.equals(progressOnStart, other.progressOnStart)
                 && Objects.equals(allowLazyStart, other.allowLazyStart);
         }
     }

+ 1 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

@@ -22,7 +22,6 @@ import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 
 import java.net.InetAddress;
-import java.util.Collections;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
@@ -248,7 +247,7 @@ public class MlTasksTests extends ESTestCase {
                                                                                                 boolean isStale) {
         PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
         builder.addTask(MlTasks.dataFrameAnalyticsTaskId(jobId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
-            new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), false),
+            new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, false),
             new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment"));
         if (state != null) {
             builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(jobId),

+ 0 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java

@@ -9,11 +9,8 @@ package org.elasticsearch.xpack.core.ml.action;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
 
 import static org.elasticsearch.test.VersionUtils.randomVersion;
 
@@ -26,15 +23,9 @@ public class StartDataFrameAnalyticsActionTaskParamsTests extends AbstractSerial
 
     @Override
     protected StartDataFrameAnalyticsAction.TaskParams createTestInstance() {
-        int phaseCount = randomIntBetween(0, 5);
-        List<PhaseProgress> progressOnStart = new ArrayList<>(phaseCount);
-        for (int i = 0; i < phaseCount; i++) {
-            progressOnStart.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)));
-        }
         return new StartDataFrameAnalyticsAction.TaskParams(
             randomAlphaOfLength(10),
             randomVersion(random()),
-            progressOnStart,
             randomBoolean());
     }
 

+ 1 - 1
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsConfigProviderIT.java

@@ -354,7 +354,7 @@ public class DataFrameAnalyticsConfigProviderIT extends MlSingleNodeTestCase {
         builder.addTask(
             MlTasks.dataFrameAnalyticsTaskId(analyticsId),
             MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
-            new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, emptyList(), false),
+            new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, false),
             new PersistentTasksCustomMetadata.Assignment("node", "test assignment"));
         builder.updateTaskState(
             MlTasks.dataFrameAnalyticsTaskId(analyticsId),

+ 12 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java

@@ -55,6 +55,7 @@ import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
 import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
 import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
+import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
 import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils;
 
 import java.util.ArrayList;
@@ -107,12 +108,19 @@ public class TransportGetDataFrameAnalyticsStatsAction
 
         ActionListener<Void> updateProgressListener = ActionListener.wrap(
             aVoid -> {
+                StatsHolder statsHolder = task.getStatsHolder();
+                if (statsHolder == null) {
+                    // The task has just been assigned and has not been initialized with its stats holder yet.
+                    // We return empty result here so that we treat it as a stopped task and return its stored stats.
+                    listener.onResponse(new QueryPage<>(Collections.emptyList(), 0, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
+                    return;
+                }
                 Stats stats = buildStats(
                     task.getParams().getId(),
-                    task.getStatsHolder().getProgressTracker().report(),
-                    task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()),
-                    task.getStatsHolder().getMemoryUsage(),
-                    task.getStatsHolder().getAnalysisStats()
+                    statsHolder.getProgressTracker().report(),
+                    statsHolder.getDataCountsTracker().report(),
+                    statsHolder.getMemoryUsage(),
+                    statsHolder.getAnalysisStats()
                 );
                 listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
                     GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));

+ 10 - 17
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java

@@ -79,6 +79,7 @@ import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
 import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
+import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.job.JobNodeSelector;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -185,7 +186,6 @@ public class TransportStartDataFrameAnalyticsAction
                     new TaskParams(
                         request.getId(),
                         startContext.config.getVersion(),
-                        startContext.progressOnStart,
                         startContext.config.isAllowLazyStart());
                 persistentTasksService.sendStartRequest(
                     MlTasks.dataFrameAnalyticsTaskId(request.getId()),
@@ -484,13 +484,11 @@ public class TransportStartDataFrameAnalyticsAction
 
     private static class StartContext {
         private final DataFrameAnalyticsConfig config;
-        private final List<PhaseProgress> progressOnStart;
         private final DataFrameAnalyticsTask.StartingState startingState;
         private volatile ExtractedFields extractedFields;
 
         private StartContext(DataFrameAnalyticsConfig config, List<PhaseProgress> progressOnStart) {
             this.config = config;
-            this.progressOnStart = progressOnStart;
             this.startingState = DataFrameAnalyticsTask.determineStartingState(config.getId(), progressOnStart);
         }
     }
@@ -671,26 +669,21 @@ public class TransportStartDataFrameAnalyticsAction
                 return;
             }
 
-            ActionListener<StoredProgress> progressListener = ActionListener.wrap(
-                storedProgress -> {
-                    if (storedProgress != null) {
-                        dfaTask.getStatsHolder().setProgressTracker(storedProgress.get());
-                    }
+            // Execute task
+            ActionListener<GetDataFrameAnalyticsStatsAction.Response> statsListener = ActionListener.wrap(
+                statsResponse -> {
+                    GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
+                    dfaTask.setStatsHolder(
+                        new StatsHolder(stats.getProgress(), stats.getMemoryUsage(), stats.getAnalysisStats(), stats.getDataCounts()));
                     executeTask(dfaTask);
                 },
                 dfaTask::setFailed
             );
 
+            // Get stats to initialize in memory stats tracking
             ActionListener<Boolean> templateCheckListener = ActionListener.wrap(
-                ok -> {
-                    if (analyticsState != DataFrameAnalyticsState.STOPPED) {
-                        // If the state is not stopped it means the task is reassigning and
-                        // we need to update the progress from the last stored progress doc.
-                        searchProgressFromIndex(params.getId(), progressListener);
-                    } else {
-                        progressListener.onResponse(null);
-                    }
-                },
+                ok -> executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE,
+                    new GetDataFrameAnalyticsStatsAction.Request(params.getId()), statsListener),
                 error -> {
                     Throwable cause = ExceptionsHelper.unwrapCause(error);
                     logger.error(

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

@@ -178,8 +178,8 @@ public class DataFrameAnalyticsManager {
         ActionListener<StepResponse> stepListener = ActionListener.wrap(
             stepResponse -> {
                 if (stepResponse.isTaskComplete()) {
-                    LOGGER.info("[{}] Marking task completed", config.getId());
-                    task.markAsCompleted();
+                    // We always want to perform the final step as it tidies things up
+                    executeStep(task, config, new FinalStep(client, task, auditor, config));
                     return;
                 }
                 switch (step.name()) {

+ 7 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

@@ -18,6 +18,7 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.ParentTaskAssigningClient;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.json.JsonXContent;
@@ -60,7 +61,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
     private final StartDataFrameAnalyticsAction.TaskParams taskParams;
     private volatile boolean isStopping;
     private volatile boolean isMarkAsCompletedCalled;
-    private final StatsHolder statsHolder;
+    private volatile StatsHolder statsHolder;
     private volatile DataFrameAnalyticsStep currentStep;
 
     public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
@@ -71,7 +72,6 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
         this.analyticsManager = Objects.requireNonNull(analyticsManager);
         this.auditor = Objects.requireNonNull(auditor);
         this.taskParams = Objects.requireNonNull(taskParams);
-        this.statsHolder = new StatsHolder(taskParams.getProgressOnStart());
     }
 
     public void setStep(DataFrameAnalyticsStep step) {
@@ -86,6 +86,11 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
         return isStopping;
     }
 
+    public void setStatsHolder(StatsHolder statsHolder) {
+        this.statsHolder = Objects.requireNonNull(statsHolder);
+    }
+
+    @Nullable
     public StatsHolder getStatsHolder() {
         return statsHolder;
     }

+ 21 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTracker.java

@@ -8,12 +8,22 @@ package org.elasticsearch.xpack.ml.dataframe.stats;
 
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 
+import java.util.Objects;
+
 public class DataCountsTracker {
 
+    private final String jobId;
     private volatile long trainingDocsCount;
     private volatile long testDocsCount;
     private volatile long skippedDocsCount;
 
+    public DataCountsTracker(DataCounts dataCounts) {
+        this.jobId = Objects.requireNonNull(dataCounts.getJobId());
+        this.trainingDocsCount = dataCounts.getTrainingDocsCount();
+        this.testDocsCount = dataCounts.getTestDocsCount();
+        this.skippedDocsCount = dataCounts.getSkippedDocsCount();
+    }
+
     public void incrementTrainingDocsCount() {
         trainingDocsCount++;
     }
@@ -26,7 +36,7 @@ public class DataCountsTracker {
         skippedDocsCount++;
     }
 
-    public DataCounts report(String jobId) {
+    public DataCounts report() {
         return new DataCounts(
             jobId,
             trainingDocsCount,
@@ -34,4 +44,14 @@ public class DataCountsTracker {
             skippedDocsCount
         );
     }
+
+    public void reset() {
+        trainingDocsCount = 0;
+        testDocsCount = 0;
+        skippedDocsCount = 0;
+    }
+
+    public void resetTestDocsCount() {
+        testDocsCount = 0;
+    }
 }

+ 7 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java

@@ -5,7 +5,9 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.stats;
 
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
@@ -23,15 +25,12 @@ public class StatsHolder {
     private final AtomicReference<AnalysisStats> analysisStatsHolder;
     private final DataCountsTracker dataCountsTracker;
 
-    public StatsHolder(List<PhaseProgress> progressOnStart) {
-        progressTracker = new ProgressTracker(progressOnStart);
-        memoryUsageHolder = new AtomicReference<>();
-        analysisStatsHolder = new AtomicReference<>();
-        dataCountsTracker = new DataCountsTracker();
-    }
-
-    public void setProgressTracker(List<PhaseProgress> progress) {
+    public StatsHolder(List<PhaseProgress> progress, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats,
+                       DataCounts dataCounts) {
         progressTracker = new ProgressTracker(progress);
+        memoryUsageHolder = new AtomicReference<>(memoryUsage);
+        analysisStatsHolder = new AtomicReference<>(analysisStats);
+        dataCountsTracker = new DataCountsTracker(dataCounts);
     }
 
     /**

+ 5 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AbstractDataFrameAnalyticsStep.java

@@ -57,7 +57,7 @@ abstract class AbstractDataFrameAnalyticsStep implements DataFrameAnalyticsStep
     @Override
     public final void execute(ActionListener<StepResponse> listener) {
         logger.debug(() -> new ParameterizedMessage("[{}] Executing step [{}]", config.getId(), name()));
-        if (task.isStopping()) {
+        if (task.isStopping() && shouldSkipIfTaskIsStopping()) {
             logger.debug(() -> new ParameterizedMessage("[{}] task is stopping before starting [{}] step", config.getId(), name()));
             listener.onResponse(new StepResponse(true));
             return;
@@ -76,4 +76,8 @@ abstract class AbstractDataFrameAnalyticsStep implements DataFrameAnalyticsStep
         executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, parentTaskClient, RefreshAction.INSTANCE,
             new RefreshRequest(config.getDest().getIndex()), refreshListener);
     }
+
+    protected boolean shouldSkipIfTaskIsStopping() {
+        return true;
+    }
 }

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/AnalysisStep.java

@@ -48,6 +48,8 @@ public class AnalysisStep extends AbstractDataFrameAnalyticsStep {
 
     @Override
     protected void doExecute(ActionListener<StepResponse> listener) {
+        task.getStatsHolder().getDataCountsTracker().reset();
+
         final ParentTaskAssigningClient parentTaskClient = parentTaskClient();
         // Update state to ANALYZING and start process
         ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(

+ 7 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/FinalStep.java

@@ -60,7 +60,7 @@ public class FinalStep extends AbstractDataFrameAnalyticsStep {
     protected void doExecute(ActionListener<StepResponse> listener) {
 
         ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
-            refreshResponse -> listener.onResponse(new StepResponse(true)),
+            refreshResponse -> listener.onResponse(new StepResponse(false)),
             listener::onFailure
         );
 
@@ -73,7 +73,7 @@ public class FinalStep extends AbstractDataFrameAnalyticsStep {
     }
 
     private void indexDataCounts(ActionListener<IndexResponse> listener) {
-        DataCounts dataCounts = task.getStatsHolder().getDataCountsTracker().report(config.getId());
+        DataCounts dataCounts = task.getStatsHolder().getDataCountsTracker().report();
         try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
             dataCounts.toXContent(builder, new ToXContent.MapParams(
                 Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
@@ -111,4 +111,9 @@ public class FinalStep extends AbstractDataFrameAnalyticsStep {
         // No progress to update
         listener.onResponse(null);
     }
+
+    @Override
+    protected boolean shouldSkipIfTaskIsStopping() {
+        return false;
+    }
 }

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/steps/InferenceStep.java

@@ -62,6 +62,8 @@ public class InferenceStep extends AbstractDataFrameAnalyticsStep {
             return;
         }
 
+        task.getStatsHolder().getDataCountsTracker().resetTestDocsCount();
+
         ActionListener<String> modelIdListener = ActionListener.wrap(
             modelId -> runInference(modelId, listener),
             listener::onFailure

+ 5 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java

@@ -50,7 +50,7 @@ public class TransportStartDataFrameAnalyticsActionTests extends ESTestCase {
     // Cannot assign the node because upgrade mode is enabled
     public void testGetAssignment_UpgradeModeIsEnabled() {
         TaskExecutor executor = createTaskExecutor();
-        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false);
+        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false);
         ClusterState clusterState =
             ClusterState.builder(new ClusterName("_name"))
                 .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build()))
@@ -64,7 +64,7 @@ public class TransportStartDataFrameAnalyticsActionTests extends ESTestCase {
     // Cannot assign the node because there are no existing nodes in the cluster state
     public void testGetAssignment_NoNodes() {
         TaskExecutor executor = createTaskExecutor();
-        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false);
+        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false);
         ClusterState clusterState =
             ClusterState.builder(new ClusterName("_name"))
                 .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build()))
@@ -78,7 +78,7 @@ public class TransportStartDataFrameAnalyticsActionTests extends ESTestCase {
     // Cannot assign the node because none of the existing nodes is an ML node
     public void testGetAssignment_NoMlNodes() {
         TaskExecutor executor = createTaskExecutor();
-        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false);
+        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false);
         ClusterState clusterState =
             ClusterState.builder(new ClusterName("_name"))
                 .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build()))
@@ -104,7 +104,7 @@ public class TransportStartDataFrameAnalyticsActionTests extends ESTestCase {
     //  - _node_name2 is too old (version 7.9.2)
     public void testGetAssignment_MlNodesAreTooOld() {
         TaskExecutor executor = createTaskExecutor();
-        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, Collections.emptyList(), false);
+        TaskParams params = new TaskParams(JOB_ID, Version.CURRENT, false);
         ClusterState clusterState =
             ClusterState.builder(new ClusterName("_name"))
                 .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build()))
@@ -131,7 +131,7 @@ public class TransportStartDataFrameAnalyticsActionTests extends ESTestCase {
     // In such a case destination index will be created from scratch so that its mappings are up-to-date.
     public void testGetAssignment_MlNodeIsNewerThanTheMlJobButTheAssignmentSuceeds() {
         TaskExecutor executor = createTaskExecutor();
-        TaskParams params = new TaskParams(JOB_ID, Version.V_7_9_0, Collections.emptyList(), false);
+        TaskParams params = new TaskParams(JOB_ID, Version.V_7_9_0, false);
         ClusterState clusterState =
             ClusterState.builder(new ClusterName("_name"))
                 .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().build()))

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsActionTests.java

@@ -62,7 +62,7 @@ public class TransportStopDataFrameAnalyticsActionTests extends ESTestCase {
     private static void addAnalyticsTask(PersistentTasksCustomMetadata.Builder builder, String analyticsId, String nodeId,
                                          DataFrameAnalyticsState state, boolean allowLazyStart) {
         builder.addTask(MlTasks.dataFrameAnalyticsTaskId(analyticsId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
-            new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, Collections.emptyList(), allowLazyStart),
+            new StartDataFrameAnalyticsAction.TaskParams(analyticsId, Version.CURRENT, allowLazyStart),
             new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment"));
 
         if (state != null) {

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java

@@ -489,7 +489,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
         builder.addTask(
             MlTasks.dataFrameAnalyticsTaskId(jobId),
             MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
-            new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), true),
+            new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, true),
             nodeId == null ? AWAITING_LAZY_ASSIGNMENT : new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment")
         );
         if (jobState != null) {

+ 6 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTaskTests.java

@@ -31,10 +31,12 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
 import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
+import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
 import org.elasticsearch.xpack.ml.dataframe.steps.DataFrameAnalyticsStep;
 import org.elasticsearch.xpack.ml.dataframe.steps.StepResponse;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
@@ -163,7 +165,7 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
             new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0));
 
         StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(
-            "task_id", Version.CURRENT, progress, false);
+            "task_id", Version.CURRENT, false);
 
         SearchResponse searchResponse = mock(SearchResponse.class);
         when(searchResponse.getHits()).thenReturn(searchHits);
@@ -180,6 +182,7 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
             new DataFrameAnalyticsTask(
                 123, "type", "action", null, Map.of(), client, analyticsManager, auditor, taskParams);
         task.init(persistentTasksService, taskManager, "task-id", 42);
+        task.setStatsHolder(new StatsHolder(progress, null, null, new DataCounts("test_job")));
 
         task.persistProgress(client, "task_id", runnable);
 
@@ -243,7 +246,6 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
             new StartDataFrameAnalyticsAction.TaskParams(
                 "job-id",
                 Version.CURRENT,
-                progress,
                 false);
 
         SearchResponse searchResponse = mock(SearchResponse.class);
@@ -257,6 +259,7 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
             new DataFrameAnalyticsTask(
                 123, "type", "action", null, Map.of(), client, analyticsManager, auditor, taskParams);
         task.init(persistentTasksService, taskManager, "task-id", 42);
+        task.setStatsHolder(new StatsHolder(progress, null, null, new DataCounts("test_job")));
         task.setStep(new StubReindexingStep(task.getStatsHolder().getProgressTracker()));
         Exception exception = new Exception("some exception");
 
@@ -301,7 +304,7 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
         };
     }
 
-    private class StubReindexingStep implements DataFrameAnalyticsStep {
+    private static class StubReindexingStep implements DataFrameAnalyticsStep {
 
         private final ProgressTracker progressTracker;
 

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.RegressionTests;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@@ -178,6 +179,6 @@ public class InferenceRunnerTests extends ESTestCase {
 
     private InferenceRunner createInferenceRunner(ExtractedFields extractedFields) {
         return new InferenceRunner(Settings.EMPTY, client, modelLoadingService,  resultsPersisterService, parentTaskId, config,
-            extractedFields, progressTracker, new DataCountsTracker());
+            extractedFields, progressTracker, new DataCountsTracker(new DataCounts(config.getId())));
     }
 }

+ 11 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
@@ -97,8 +98,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
 
         task = mock(DataFrameAnalyticsTask.class);
         when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
-        when(task.getStatsHolder()).thenReturn(new StatsHolder(
-            ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report()));
+        when(task.getStatsHolder()).thenReturn(newStatsHolder());
         when(task.getParentTaskId()).thenReturn(new TaskId(""));
         dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID,
             false,
@@ -117,10 +117,16 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
             processFactory, auditor, trainedModelProvider, resultsPersisterService, 1);
     }
 
+    private StatsHolder newStatsHolder() {
+        return new StatsHolder(ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report(),
+            null,
+            null,
+            new DataCounts(CONFIG_ID));
+    }
+
     public void testRunJob_TaskIsStopping() {
         when(task.isStopping()).thenReturn(true);
-        when(task.getParams()).thenReturn(
-            new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, Collections.emptyList(), false));
+        when(task.getParams()).thenReturn(new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, false));
 
         processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, ActionListener.wrap(
             stepResponse -> {
@@ -209,7 +215,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
     public void testRunJob_ProcessNotAliveAfterStart() {
         when(process.isProcessAlive()).thenReturn(false);
         when(task.getParams()).thenReturn(
-            new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, Collections.emptyList(), false));
+            new StartDataFrameAnalyticsAction.TaskParams("data_frame_id", Version.CURRENT, false));
 
         processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, ActionListener.wrap(
             stepResponse -> fail("Expected error but listener got a response instead"),

+ 6 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
@@ -59,7 +60,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
 
     private AnalyticsProcess<AnalyticsResult> process;
     private DataFrameRowsJoiner dataFrameRowsJoiner;
-    private StatsHolder statsHolder = new StatsHolder(ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report());
+    private StatsHolder statsHolder = new StatsHolder(
+        ProgressTracker.fromZeroes(Collections.singletonList("analyzing"), false).report(),
+        null,
+        null,
+        new DataCounts(JOB_ID));
     private TrainedModelProvider trainedModelProvider;
     private DataFrameAnalyticsAuditor auditor;
     private StatsPersister statsPersister;

+ 31 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/DataCountsTrackerTests.java

@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.stats;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class DataCountsTrackerTests extends ESTestCase {
+
+    private static final String JOB_ID = "test";
+
+    public void testReset() {
+        DataCountsTracker dataCountsTracker = new DataCountsTracker(new DataCounts(JOB_ID, 10, 20, 30));
+        dataCountsTracker.reset();
+        DataCounts resetDataCounts = dataCountsTracker.report();
+        assertThat(resetDataCounts, equalTo(new DataCounts(JOB_ID)));
+    }
+
+    public void testResetTestDocsCount() {
+        DataCountsTracker dataCountsTracker = new DataCountsTracker(new DataCounts(JOB_ID, 10, 20, 30));
+        dataCountsTracker.resetTestDocsCount();
+        DataCounts resetDataCounts = dataCountsTracker.report();
+        assertThat(resetDataCounts, equalTo(new DataCounts(JOB_ID, 10, 0, 30)));
+    }
+}

+ 11 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.ml.dataframe.stats;
 
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
 import java.util.Arrays;
@@ -29,7 +30,7 @@ public class StatsHolderTests extends ESTestCase {
                 new PhaseProgress("writing_results", 0)
             )
         );
-        StatsHolder statsHolder = new StatsHolder(phases);
+        StatsHolder statsHolder = newStatsHolder(phases);
 
         statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), false);
 
@@ -55,7 +56,7 @@ public class StatsHolderTests extends ESTestCase {
                 new PhaseProgress("writing_results", 50)
             )
         );
-        StatsHolder statsHolder = new StatsHolder(phases);
+        StatsHolder statsHolder = newStatsHolder(phases);
 
         statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), false);
 
@@ -81,7 +82,7 @@ public class StatsHolderTests extends ESTestCase {
                 new PhaseProgress("writing_results", 50)
             )
         );
-        StatsHolder statsHolder = new StatsHolder(phases);
+        StatsHolder statsHolder = newStatsHolder(phases);
 
         statsHolder.adjustProgressTracker(Arrays.asList("c", "d"), false);
 
@@ -107,7 +108,7 @@ public class StatsHolderTests extends ESTestCase {
                 new PhaseProgress("writing_results", 50)
             )
         );
-        StatsHolder statsHolder = new StatsHolder(phases);
+        StatsHolder statsHolder = newStatsHolder(phases);
 
         statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), false);
 
@@ -133,7 +134,7 @@ public class StatsHolderTests extends ESTestCase {
                 new PhaseProgress("inference", 20)
             )
         );
-        StatsHolder statsHolder = new StatsHolder(phases);
+        StatsHolder statsHolder = newStatsHolder(phases);
 
         statsHolder.adjustProgressTracker(Arrays.asList("a", "b"), true);
 
@@ -159,7 +160,7 @@ public class StatsHolderTests extends ESTestCase {
                 new PhaseProgress("writing_results", 50)
             )
         );
-        StatsHolder statsHolder = new StatsHolder(phases);
+        StatsHolder statsHolder = newStatsHolder(phases);
 
         statsHolder.resetProgressTracker(Arrays.asList("a", "b"), false);
 
@@ -174,4 +175,8 @@ public class StatsHolderTests extends ESTestCase {
         assertThat(phaseProgresses.get(3).getProgressPercent(), equalTo(0));
         assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0));
     }
+
+    private static StatsHolder newStatsHolder(List<PhaseProgress> progress) {
+        return new StatsHolder(progress, null, null, new DataCounts("test_job"));
+    }
 }

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java

@@ -819,7 +819,7 @@ public class JobNodeSelectorTests extends ESTestCase {
     static void addDataFrameAnalyticsJobTask(String id, String nodeId, DataFrameAnalyticsState state,
                                              PersistentTasksCustomMetadata.Builder builder, boolean isStale, boolean allowLazyStart) {
         builder.addTask(MlTasks.dataFrameAnalyticsTaskId(id), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
-            new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, Collections.emptyList(), allowLazyStart),
+            new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, allowLazyStart),
             new PersistentTasksCustomMetadata.Assignment(nodeId, "test assignment"));
         if (state != null) {
             builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(id),
@@ -828,6 +828,6 @@ public class JobNodeSelectorTests extends ESTestCase {
     }
 
     private static TaskParams createTaskParams(String id) {
-        return new TaskParams(id, Version.CURRENT, Collections.emptyList(), false);
+        return new TaskParams(id, Version.CURRENT, false);
     }
 }

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

@@ -267,7 +267,7 @@ public class MlMemoryTrackerTests extends ESTestCase {
     PersistentTasksCustomMetadata.PersistentTask<StartDataFrameAnalyticsAction.TaskParams>
     makeTestDataFrameAnalyticsTask(String id, boolean allowLazyStart) {
         return new PersistentTasksCustomMetadata.PersistentTask<>(MlTasks.dataFrameAnalyticsTaskId(id),
-            MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT,
-            Collections.emptyList(), allowLazyStart), 0, PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT);
+            MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, new StartDataFrameAnalyticsAction.TaskParams(id, Version.CURRENT, allowLazyStart),
+            0, PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT);
     }
 }