Browse Source

During ML maintenance, reset jobs in the reset state without a corresponding task. (#106062)

* During ML maintenance, reset jobs in the reset state without a corresponding task.

* Update docs/changelog/106062.yaml

* Fix race condition in MlDailyMaintenanceServiceTests

* Fix log level
Jan Kuipers 1 year ago
parent
commit
24228cd6ea

+ 6 - 0
docs/changelog/106062.yaml

@@ -0,0 +1,6 @@
+pr: 106062
+summary: "During ML maintenance, reset jobs in the reset state without a corresponding\
+  \ task"
+area: Machine Learning
+type: bug
+issues: []

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java

@@ -485,6 +485,10 @@ public class Job implements SimpleDiffable<Job>, Writeable, ToXContentObject {
         return deleting;
     }
 
+    public boolean isResetting() {
+        return blocked != null && Blocked.Reason.RESET.equals(blocked.getReason());
+    }
+
     public boolean allowLazyOpen() {
         return allowLazyOpen;
     }

+ 99 - 44
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlDailyMaintenanceService.java

@@ -10,9 +10,11 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequest;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
 import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterName;
@@ -27,12 +29,15 @@ import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
+import org.elasticsearch.tasks.TaskInfo;
 import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
+import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsAction;
+import org.elasticsearch.xpack.core.ml.action.ResetJobAction;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
 
@@ -42,6 +47,8 @@ import java.util.List;
 import java.util.Objects;
 import java.util.Random;
 import java.util.Set;
+import java.util.function.Function;
+import java.util.function.Predicate;
 import java.util.function.Supplier;
 
 import static java.util.stream.Collectors.toList;
@@ -206,24 +213,34 @@ public class MlDailyMaintenanceService implements Releasable {
     }
 
     private void triggerAnomalyDetectionMaintenance() {
-        // Step 3: Log any error that could have happened
+        // Step 4: Log any error that could have happened
         ActionListener<AcknowledgedResponse> finalListener = ActionListener.wrap(
             unused -> {},
-            e -> logger.error("An error occurred during [ML] maintenance tasks execution", e)
+            e -> logger.warn("An error occurred during [ML] maintenance tasks execution", e)
         );
 
-        // Step 2: Delete expired data
+        // Step 3: Delete expired data
         ActionListener<AcknowledgedResponse> deleteJobsListener = ActionListener.wrap(
             unused -> triggerDeleteExpiredDataTask(finalListener),
             e -> {
-                logger.info("[ML] maintenance task: triggerDeleteJobsInStateDeletingWithoutDeletionTask failed", e);
-                // Note: Steps 1 and 2 are independent of each other and step 2 is executed even if step 1 failed.
+                logger.warn("[ML] maintenance task: triggerResetJobsInStateResetWithoutResetTask failed", e);
+                // Note: Steps 1-3 are independent, so continue upon errors.
                 triggerDeleteExpiredDataTask(finalListener);
             }
         );
 
-        // Step 1: Delete jobs that are in deleting state
-        triggerDeleteJobsInStateDeletingWithoutDeletionTask(deleteJobsListener);
+        // Step 2: Reset jobs that are in resetting state without task
+        ActionListener<AcknowledgedResponse> resetJobsListener = ActionListener.wrap(
+            unused -> triggerResetJobsInStateResetWithoutResetTask(deleteJobsListener),
+            e -> {
+                logger.warn("[ML] maintenance task: triggerDeleteJobsInStateDeletingWithoutDeletionTask failed", e);
+                // Note: Steps 1-3 are independent, so continue upon errors.
+                triggerResetJobsInStateResetWithoutResetTask(deleteJobsListener);
+            }
+        );
+
+        // Step 1: Delete jobs that are in deleting state without task
+        triggerDeleteJobsInStateDeletingWithoutDeletionTask(resetJobsListener);
     }
 
     private void triggerDataFrameAnalyticsMaintenance() {
@@ -257,73 +274,111 @@ public class MlDailyMaintenanceService implements Releasable {
 
     // Visible for testing
     public void triggerDeleteJobsInStateDeletingWithoutDeletionTask(ActionListener<AcknowledgedResponse> finalListener) {
-        SetOnce<Set<String>> jobsInStateDeletingHolder = new SetOnce<>();
-
-        ActionListener<List<Tuple<DeleteJobAction.Request, AcknowledgedResponse>>> deleteJobsActionListener = finalListener
-            .delegateFailureAndWrap((delegate, deleteJobsResponses) -> {
-                List<String> jobIds = deleteJobsResponses.stream()
-                    .filter(t -> t.v2().isAcknowledged() == false)
-                    .map(Tuple::v1)
-                    .map(DeleteJobAction.Request::getJobId)
-                    .collect(toList());
+        triggerJobsInStateWithoutMatchingTask(
+            "triggerDeleteJobsInStateDeletingWithoutDeletionTask",
+            Job::isDeleting,
+            DeleteJobAction.NAME,
+            taskInfo -> stripPrefixOrNull(taskInfo.description(), DeleteJobAction.DELETION_TASK_DESCRIPTION_PREFIX),
+            DeleteJobAction.INSTANCE,
+            DeleteJobAction.Request::new,
+            finalListener
+        );
+    }
+
+    public void triggerResetJobsInStateResetWithoutResetTask(ActionListener<AcknowledgedResponse> finalListener) {
+        triggerJobsInStateWithoutMatchingTask(
+            "triggerResetJobsInStateResetWithoutResetTask",
+            Job::isResetting,
+            ResetJobAction.NAME,
+            taskInfo -> stripPrefixOrNull(taskInfo.description(), MlTasks.JOB_TASK_ID_PREFIX),
+            ResetJobAction.INSTANCE,
+            ResetJobAction.Request::new,
+            finalListener
+        );
+    }
+
+    /**
+     * @return If the string starts with the prefix, this returns the string without the prefix.
+     *         Otherwise, this return null.
+     */
+    private static String stripPrefixOrNull(String str, String prefix) {
+        return str == null || str.startsWith(prefix) == false ? null : str.substring(prefix.length());
+    }
+
+    /**
+     * Executes a request for each job in a state, while missing the corresponding task. This
+     * usually indicates the node originally executing the task has died, so retry the request.
+     *
+     * @param maintenanceTaskName Name of ML maintenance task; used only for logging.
+     * @param jobFilter           Predicate for filtering the jobs.
+     * @param taskActionName      Action name of the tasks corresponding to the jobs.
+     * @param jobIdExtractor      Function to extract the job ID from the task info (in order to match to the job).
+     * @param actionType          Action type of the request that should be (re)executed.
+     * @param requestCreator      Function to create the request from the job ID.
+     * @param finalListener       Listener that captures the final response.
+     */
+    private void triggerJobsInStateWithoutMatchingTask(
+        String maintenanceTaskName,
+        Predicate<Job> jobFilter,
+        String taskActionName,
+        Function<TaskInfo, String> jobIdExtractor,
+        ActionType<AcknowledgedResponse> actionType,
+        Function<String, AcknowledgedRequest<?>> requestCreator,
+        ActionListener<AcknowledgedResponse> finalListener
+    ) {
+        SetOnce<Set<String>> jobsInStateHolder = new SetOnce<>();
+
+        ActionListener<List<Tuple<String, AcknowledgedResponse>>> jobsActionListener = finalListener.delegateFailureAndWrap(
+            (delegate, jobsResponses) -> {
+                List<String> jobIds = jobsResponses.stream().filter(t -> t.v2().isAcknowledged() == false).map(Tuple::v1).collect(toList());
                 if (jobIds.isEmpty()) {
-                    logger.info("Successfully completed [ML] maintenance task: triggerDeleteJobsInStateDeletingWithoutDeletionTask");
+                    logger.info("Successfully completed [ML] maintenance task: {}", maintenanceTaskName);
                 } else {
-                    logger.info("The following ML jobs could not be deleted: [" + String.join(",", jobIds) + "]");
+                    logger.info("[ML] maintenance task {} failed for jobs: {}", maintenanceTaskName, jobIds);
                 }
                 delegate.onResponse(AcknowledgedResponse.TRUE);
-            });
+            }
+        );
 
         ActionListener<ListTasksResponse> listTasksActionListener = ActionListener.wrap(listTasksResponse -> {
-            Set<String> jobsInStateDeleting = jobsInStateDeletingHolder.get();
-            Set<String> jobsWithDeletionTask = listTasksResponse.getTasks()
-                .stream()
-                .filter(t -> t.description() != null)
-                .filter(t -> t.description().startsWith(DeleteJobAction.DELETION_TASK_DESCRIPTION_PREFIX))
-                .map(t -> t.description().substring(DeleteJobAction.DELETION_TASK_DESCRIPTION_PREFIX.length()))
-                .collect(toSet());
-            Set<String> jobsInStateDeletingWithoutDeletionTask = Sets.difference(jobsInStateDeleting, jobsWithDeletionTask);
-            if (jobsInStateDeletingWithoutDeletionTask.isEmpty()) {
+            Set<String> jobsInState = jobsInStateHolder.get();
+            Set<String> jobsWithTask = listTasksResponse.getTasks().stream().map(jobIdExtractor).filter(Objects::nonNull).collect(toSet());
+            Set<String> jobsInStateWithoutTask = Sets.difference(jobsInState, jobsWithTask);
+            if (jobsInStateWithoutTask.isEmpty()) {
                 finalListener.onResponse(AcknowledgedResponse.TRUE);
                 return;
             }
-            TypedChainTaskExecutor<Tuple<DeleteJobAction.Request, AcknowledgedResponse>> chainTaskExecutor = new TypedChainTaskExecutor<>(
+            TypedChainTaskExecutor<Tuple<String, AcknowledgedResponse>> chainTaskExecutor = new TypedChainTaskExecutor<>(
                 EsExecutors.DIRECT_EXECUTOR_SERVICE,
                 Predicates.always(),
                 Predicates.always()
             );
-            for (String jobId : jobsInStateDeletingWithoutDeletionTask) {
-                DeleteJobAction.Request request = new DeleteJobAction.Request(jobId);
+            for (String jobId : jobsInStateWithoutTask) {
                 chainTaskExecutor.add(
                     listener -> executeAsyncWithOrigin(
                         client,
                         ML_ORIGIN,
-                        DeleteJobAction.INSTANCE,
-                        request,
-                        listener.delegateFailureAndWrap((l, response) -> l.onResponse(Tuple.tuple(request, response)))
+                        actionType,
+                        requestCreator.apply(jobId),
+                        listener.delegateFailureAndWrap((l, response) -> l.onResponse(Tuple.tuple(jobId, response)))
                     )
                 );
             }
-            chainTaskExecutor.execute(deleteJobsActionListener);
+            chainTaskExecutor.execute(jobsActionListener);
         }, finalListener::onFailure);
 
         ActionListener<GetJobsAction.Response> getJobsActionListener = ActionListener.wrap(getJobsResponse -> {
-            Set<String> jobsInStateDeleting = getJobsResponse.getResponse()
-                .results()
-                .stream()
-                .filter(Job::isDeleting)
-                .map(Job::getId)
-                .collect(toSet());
-            if (jobsInStateDeleting.isEmpty()) {
+            Set<String> jobsInState = getJobsResponse.getResponse().results().stream().filter(jobFilter).map(Job::getId).collect(toSet());
+            if (jobsInState.isEmpty()) {
                 finalListener.onResponse(AcknowledgedResponse.TRUE);
                 return;
             }
-            jobsInStateDeletingHolder.set(jobsInStateDeleting);
+            jobsInStateHolder.set(jobsInState);
             executeAsyncWithOrigin(
                 client,
                 ML_ORIGIN,
                 TransportListTasksAction.TYPE,
-                new ListTasksRequest().setActions(DeleteJobAction.NAME),
+                new ListTasksRequest().setActions(taskActionName),
                 listTasksActionListener
             );
         }, finalListener::onFailure);

+ 109 - 61
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlDailyMaintenanceServiceTests.java

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsAction;
+import org.elasticsearch.xpack.core.ml.action.ResetJobAction;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.junit.After;
 import org.junit.Before;
@@ -38,8 +39,10 @@ import org.mockito.Mockito;
 import org.mockito.stubbing.Answer;
 
 import java.util.Collections;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
 
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.same;
@@ -79,29 +82,21 @@ public class MlDailyMaintenanceServiceTests extends ESTestCase {
         doAnswer(withResponse(new GetJobsAction.Response(new QueryPage<>(Collections.emptyList(), 0, new ParseField(""))))).when(client)
             .execute(same(GetJobsAction.INSTANCE), any(), any());
 
-        int triggerCount = randomIntBetween(2, 4);
-        CountDownLatch latch = new CountDownLatch(triggerCount);
-        try (MlDailyMaintenanceService service = createService(latch, client)) {
-            service.start();
-            latch.await(5, TimeUnit.SECONDS);
-        }
+        int triggerCount = randomIntBetween(1, 3);
+        executeMaintenanceTriggers(triggerCount);
 
-        verify(client, times(triggerCount - 1)).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
-        verify(client, times(triggerCount - 1)).execute(same(GetJobsAction.INSTANCE), any(), any());
-        verify(mlAssignmentNotifier, times(triggerCount - 1)).auditUnassignedMlTasks(any(), any());
+        verify(client, times(triggerCount)).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
+        verify(client, times(2 * triggerCount)).execute(same(GetJobsAction.INSTANCE), any(), any());
+        verify(mlAssignmentNotifier, times(triggerCount)).auditUnassignedMlTasks(any(), any());
     }
 
     public void testScheduledTriggeringWhileUpgradeModeIsEnabled() throws InterruptedException {
         when(clusterService.state()).thenReturn(createClusterState(true));
 
-        int triggerCount = randomIntBetween(2, 4);
-        CountDownLatch latch = new CountDownLatch(triggerCount);
-        try (MlDailyMaintenanceService service = createService(latch, client)) {
-            service.start();
-            latch.await(5, TimeUnit.SECONDS);
-        }
+        int triggerCount = randomIntBetween(1, 3);
+        executeMaintenanceTriggers(triggerCount);
 
-        verify(clusterService, times(triggerCount - 1)).state();
+        verify(clusterService, times(triggerCount)).state();
         verifyNoMoreInteractions(client, clusterService, mlAssignmentNotifier);
     }
 
@@ -143,11 +138,7 @@ public class MlDailyMaintenanceServiceTests extends ESTestCase {
     public void testNoAnomalyDetectionTasksWhenDisabled() throws InterruptedException {
         when(clusterService.state()).thenReturn(createClusterState(false));
 
-        CountDownLatch latch = new CountDownLatch(2);
-        try (MlDailyMaintenanceService service = createService(latch, client, false, randomBoolean(), randomBoolean())) {
-            service.start();
-            latch.await(5, TimeUnit.SECONDS);
-        }
+        executeMaintenanceTriggers(1, false, randomBoolean(), randomBoolean());
 
         verify(client, never()).threadPool();
         verify(client, never()).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
@@ -160,15 +151,11 @@ public class MlDailyMaintenanceServiceTests extends ESTestCase {
         doAnswer(deleteExpiredDataAnswer).when(client).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
         doAnswer(getJobsAnswer).when(client).execute(same(GetJobsAction.INSTANCE), any(), any());
 
-        CountDownLatch latch = new CountDownLatch(2);
-        try (MlDailyMaintenanceService service = createService(latch, client)) {
-            service.start();
-            latch.await(5, TimeUnit.SECONDS);
-        }
+        executeMaintenanceTriggers(1);
 
-        verify(client, Mockito.atLeast(2)).threadPool();
-        verify(client, Mockito.atLeast(1)).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
-        verify(client, Mockito.atLeast(1)).execute(same(GetJobsAction.INSTANCE), any(), any());
+        verify(client, times(3)).threadPool();
+        verify(client, times(1)).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
+        verify(client, times(2)).execute(same(GetJobsAction.INSTANCE), any(), any());
         verify(mlAssignmentNotifier, Mockito.atLeast(1)).auditUnassignedMlTasks(any(), any());
     }
 
@@ -202,14 +189,10 @@ public class MlDailyMaintenanceServiceTests extends ESTestCase {
             .when(client)
             .execute(same(TransportListTasksAction.TYPE), any(), any());
 
-        CountDownLatch latch = new CountDownLatch(2);
-        try (MlDailyMaintenanceService service = createService(latch, client)) {
-            service.start();
-            latch.await(5, TimeUnit.SECONDS);
-        }
+        executeMaintenanceTriggers(1);
 
-        verify(client, times(3)).threadPool();
-        verify(client).execute(same(GetJobsAction.INSTANCE), any(), any());
+        verify(client, times(4)).threadPool();
+        verify(client, times(2)).execute(same(GetJobsAction.INSTANCE), any(), any());
         verify(client).execute(same(TransportListTasksAction.TYPE), any(), any());
         verify(client).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
         verify(mlAssignmentNotifier).auditUnassignedMlTasks(any(), any());
@@ -240,14 +223,10 @@ public class MlDailyMaintenanceServiceTests extends ESTestCase {
         ).execute(same(TransportListTasksAction.TYPE), any(), any());
         doAnswer(withResponse(AcknowledgedResponse.of(deleted))).when(client).execute(same(DeleteJobAction.INSTANCE), any(), any());
 
-        CountDownLatch latch = new CountDownLatch(2);
-        try (MlDailyMaintenanceService service = createService(latch, client)) {
-            service.start();
-            latch.await(5, TimeUnit.SECONDS);
-        }
+        executeMaintenanceTriggers(1);
 
-        verify(client, times(4)).threadPool();
-        verify(client).execute(same(GetJobsAction.INSTANCE), any(), any());
+        verify(client, times(5)).threadPool();
+        verify(client, times(2)).execute(same(GetJobsAction.INSTANCE), any(), any());
         verify(client).execute(same(TransportListTasksAction.TYPE), any(), any());
         verify(client).execute(same(DeleteJobAction.INSTANCE), any(), any());
         verify(client).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
@@ -255,29 +234,98 @@ public class MlDailyMaintenanceServiceTests extends ESTestCase {
         verifyNoMoreInteractions(client, mlAssignmentNotifier);
     }
 
-    private MlDailyMaintenanceService createService(CountDownLatch latch, Client client) {
-        return createService(latch, client, true, true, true);
+    public void testJobInResettingState_doesNotHaveResetTask() throws InterruptedException {
+        testJobInResettingState(false);
+    }
+
+    public void testJobInResettingState_hasResetTask() throws InterruptedException {
+        testJobInResettingState(true);
+    }
+
+    private void testJobInResettingState(boolean hasResetTask) throws InterruptedException {
+        String jobId = "job-in-state-resetting";
+        when(clusterService.state()).thenReturn(createClusterState(false));
+        doAnswer(withResponse(new DeleteExpiredDataAction.Response(true))).when(client)
+            .execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
+        Job job = mock(Job.class);
+        when(job.getId()).thenReturn(jobId);
+        when(job.isDeleting()).thenReturn(false);
+        when(job.isResetting()).thenReturn(true);
+        doAnswer(withResponse(new GetJobsAction.Response(new QueryPage<>(List.of(job), 1, new ParseField(""))))).when(client)
+            .execute(same(GetJobsAction.INSTANCE), any(), any());
+        List<TaskInfo> tasks = hasResetTask
+            ? List.of(
+                new TaskInfo(
+                    new TaskId("test", 123),
+                    "test",
+                    "test",
+                    ResetJobAction.NAME,
+                    "job-" + jobId,
+                    null,
+                    0,
+                    0,
+                    true,
+                    false,
+                    new TaskId("test", 456),
+                    Collections.emptyMap()
+                )
+            )
+            : List.of();
+        doAnswer(withResponse(new ListTasksResponse(tasks, List.of(), List.of()))).when(client)
+            .execute(same(TransportListTasksAction.TYPE), any(), any());
+        doAnswer(withResponse(AcknowledgedResponse.of(true))).when(client).execute(same(ResetJobAction.INSTANCE), any(), any());
+
+        executeMaintenanceTriggers(1);
+
+        verify(client, times(hasResetTask ? 4 : 5)).threadPool();
+        verify(client, times(2)).execute(same(GetJobsAction.INSTANCE), any(), any());
+        verify(client).execute(same(TransportListTasksAction.TYPE), any(), any());
+        if (hasResetTask == false) {
+            verify(client).execute(same(ResetJobAction.INSTANCE), any(), any());
+        }
+        verify(client).execute(same(DeleteExpiredDataAction.INSTANCE), any(), any());
+        verify(mlAssignmentNotifier).auditUnassignedMlTasks(any(), any());
+        verifyNoMoreInteractions(client, mlAssignmentNotifier);
+    }
+
+    private void executeMaintenanceTriggers(int triggerCount) throws InterruptedException {
+        executeMaintenanceTriggers(triggerCount, true, true, true);
     }
 
-    private MlDailyMaintenanceService createService(
-        CountDownLatch latch,
-        Client client,
+    private void executeMaintenanceTriggers(
+        int triggerCount,
         boolean isAnomalyDetectionEnabled,
         boolean isDataFrameAnalyticsEnabled,
         boolean isNlpEnabled
-    ) {
-        return new MlDailyMaintenanceService(Settings.EMPTY, threadPool, client, clusterService, mlAssignmentNotifier, () -> {
-            // We need to be careful that an unexpected iteration doesn't get squeezed in by the maintenance threadpool in
-            // between the latch getting counted down to zero and the main test thread stopping the maintenance service.
-            // This could happen if the main test thread happens to be waiting for a CPU for the whole 100ms after the
-            // latch counts down to zero.
-            if (latch.getCount() > 0) {
-                latch.countDown();
-                return TimeValue.timeValueMillis(100);
-            } else {
-                return TimeValue.timeValueHours(1);
-            }
-        }, isAnomalyDetectionEnabled, isDataFrameAnalyticsEnabled, isNlpEnabled);
+    ) throws InterruptedException {
+        // The scheduleProvider is called upon scheduling. The latch waits for (triggerCount + 1)
+        // schedules to happen, which means that the maintenance task is executed triggerCount
+        // times. The first triggerCount invocations of the scheduleProvider return 100ms, which
+        // is the time between the executed maintenance tasks.
+        // After that, maintenance task (triggerCount + 1) is scheduled after 100sec, the latch is
+        // released, the service is closed, and the method returns. Task (triggerCount + 1) is
+        // therefore never executed.
+        CountDownLatch latch = new CountDownLatch(triggerCount + 1);
+        Supplier<TimeValue> scheduleProvider = () -> {
+            latch.countDown();
+            return TimeValue.timeValueMillis(latch.getCount() > 0 ? 100 : 100_000);
+        };
+        try (
+            MlDailyMaintenanceService service = new MlDailyMaintenanceService(
+                Settings.EMPTY,
+                threadPool,
+                client,
+                clusterService,
+                mlAssignmentNotifier,
+                scheduleProvider,
+                isAnomalyDetectionEnabled,
+                isDataFrameAnalyticsEnabled,
+                isNlpEnabled
+            )
+        ) {
+            service.start();
+            latch.await(5, TimeUnit.SECONDS);
+        }
     }
 
     private static ClusterState createClusterState(boolean isUpgradeMode) {