
[ML] remove thread sleep from results persister (#65904)

* [ML] remove thread sleep from results persister
A thread sleep inside a recurring action can cause problems on node shutdown: if the thread is sleeping
while a graceful shutdown is in progress, the node has to wait for it to wake up. Since these retry timeouts
can stretch over a long period of time, we should instead use scheduled tasks plus the threadpool.
This allows the retries to be effectively canceled instead of waiting for a sleeping thread to wake back up.

closes https://github.com/elastic/elasticsearch/issues/65890
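
For orientation, here is a minimal standalone sketch of the pattern this change adopts: scheduling the next retry on an executor instead of blocking with Thread.sleep, so a pending retry can simply be dropped on shutdown. The class and method names below are illustrative only; the actual implementation builds on RetryableAction and the ML utility threadpool, as shown in the diffs that follow.

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

// Illustrative only: the general shape of a cancellable, scheduled retry.
final class ScheduledRetrySketch {
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();

    // Schedules one attempt; on failure, reschedules itself with a doubled (capped) delay.
    ScheduledFuture<?> retryWithBackoff(Runnable attempt, Supplier<Boolean> shouldRetry, long delayMs) {
        return scheduler.schedule(() -> {
            try {
                attempt.run();
            } catch (Exception e) {
                if (shouldRetry.get()) {
                    retryWithBackoff(attempt, shouldRetry, Math.min(delayMs * 2, 15 * 60 * 1000L));
                }
            }
        }, delayMs, TimeUnit.MILLISECONDS);
    }

    // Unlike a thread blocked in Thread.sleep, a scheduled task that has not started yet is simply dropped.
    void shutdown() {
        scheduler.shutdownNow();
    }
}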
Benjamin Trent 4 years ago
commit eb91e35b1c
13 changed files with 407 additions and 144 deletions
  1. +10 -2
      server/src/main/java/org/elasticsearch/action/support/RetryableAction.java
  2. +2 -0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/BulkFailureRetryIT.java
  3. +7 -3
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java
  4. +2 -3
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/EstablishedMemUsageIT.java
  5. +2 -3
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobResultsProviderIT.java
  6. +2 -3
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobStorageDeletionTaskIT.java
  7. +6 -1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  8. +237 -115
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java
  9. +76 -0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java
  10. +1 -1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/annotations/AnnotationPersisterTests.java
  11. +15 -1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java
  12. +23 -0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java
  13. +24 -12
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java

+ 10 - 2
server/src/main/java/org/elasticsearch/action/support/RetryableAction.java

@@ -114,6 +114,14 @@ public abstract class RetryableAction<Response> {
 
     public abstract boolean shouldRetry(Exception e);
 
+    protected long calculateDelay(long previousDelay) {
+        return Math.min(previousDelay * 2, Integer.MAX_VALUE);
+    }
+
+    protected long minimumDelayMillis() {
+        return 1L;
+    }
+
     public void onFinished() {
     }
 
@@ -148,10 +156,10 @@ public abstract class RetryableAction<Response> {
                 } else {
                     addException(e);
 
-                    final long nextDelayMillisBound = Math.min(delayMillisBound * 2, Integer.MAX_VALUE);
+                    final long nextDelayMillisBound = calculateDelay(delayMillisBound);
                     final RetryingListener retryingListener = new RetryingListener(nextDelayMillisBound, caughtExceptions);
                     final Runnable runnable = createRunnable(retryingListener);
-                    final long delayMillis = Randomness.get().nextInt(Math.toIntExact(delayMillisBound)) + 1;
+                    final long delayMillis = Randomness.get().nextInt(Math.toIntExact(delayMillisBound)) + minimumDelayMillis();
                     if (isDone.get() == false) {
                         final TimeValue delay = TimeValue.timeValueMillis(delayMillis);
                         logger.debug(() -> new ParameterizedMessage("retrying action that failed in {}", delay), e);

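The two new protected hooks generalize the backoff calculation so a subclass can supply its own delay growth and floor; the ResultsPersisterService hunk further down does exactly that. As a standalone illustration of the resulting jittered, floored delay, here is a plain-Java sketch with an assumed cap rather than anything taken from RetryableAction itself:

import java.util.Random;

// Standalone sketch of the delay logic the hooks above enable; the cap below is an assumption.
final class BackoffDelaySketch {
    private static final long MAX_DELAY_MS = 15 * 60 * 1000L; // assumed cap for the sketch
    private final Random random = new Random();

    // cf. calculateDelay(): double the bound, but never beyond the cap
    long nextDelayBound(long previousBound) {
        return Math.min(previousBound * 2, MAX_DELAY_MS);
    }

    // cf. the jitter in RetryingListener: random within the bound, never below the floor
    long jitteredDelayMillis(long delayBound, long minimumDelayMillis) {
        return random.nextInt((int) delayBound) + minimumDelayMillis; // delayBound assumed > 0 and int-sized
    }
}
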
+ 2 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/BulkFailureRetryIT.java

@@ -65,6 +65,7 @@ public class BulkFailureRetryIT extends MlNativeAutodetectIntegTestCase {
                 .putNull("logger.org.elasticsearch.xpack.ml.datafeed.DatafeedJob")
                 .putNull("logger.org.elasticsearch.xpack.ml.job.persistence.JobResultsPersister")
                 .putNull("logger.org.elasticsearch.xpack.ml.job.process.autodetect.output")
+                .putNull("logger.org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService")
                 .build()).get();
         cleanUp();
     }
@@ -121,6 +122,7 @@ public class BulkFailureRetryIT extends MlNativeAutodetectIntegTestCase {
             .setTransientSettings(Settings.builder()
                 .put("logger.org.elasticsearch.xpack.ml.datafeed.DatafeedJob", "TRACE")
                 .put("logger.org.elasticsearch.xpack.ml.job.persistence.JobResultsPersister", "TRACE")
+                .put("logger.org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService", "TRACE")
                 .put("logger.org.elasticsearch.xpack.ml.job.process.autodetect.output", "TRACE")
                 .put("xpack.ml.persist_results_max_retries", "15")
                 .build()).get();

+ 7 - 3
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java

@@ -28,6 +28,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.reindex.ReindexPlugin;
+import org.elasticsearch.ingest.common.IngestCommonPlugin;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -56,6 +57,7 @@ import org.elasticsearch.xpack.core.ml.job.results.Bucket;
 import org.elasticsearch.xpack.core.ml.job.results.CategoryDefinition;
 import org.elasticsearch.xpack.core.ml.job.results.Influencer;
 import org.elasticsearch.xpack.core.ml.job.results.ModelPlot;
+import org.elasticsearch.xpack.datastreams.DataStreamsPlugin;
 import org.elasticsearch.xpack.ilm.IndexLifecycle;
 import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
@@ -126,7 +128,10 @@ public class AutodetectResultProcessorIT extends MlSingleNodeTestCase {
     protected Collection<Class<? extends Plugin>> getPlugins() {
         return pluginList(
             LocalStateMachineLearning.class,
+            DataStreamsPlugin.class,
+            IngestCommonPlugin.class,
             ReindexPlugin.class,
+            MockPainlessScriptEngine.TestPlugin.class,
             // ILM is required for .ml-state template index settings
             IndexLifecycle.class);
     }
@@ -141,7 +146,7 @@ public class AutodetectResultProcessorIT extends MlSingleNodeTestCase {
         renormalizer = mock(Renormalizer.class);
         process = mock(AutodetectProcess.class);
         capturedUpdateModelSnapshotOnJobRequests = new ArrayList<>();
-        ThreadPool tp = mock(ThreadPool.class);
+        ThreadPool tp = mockThreadPool();
         Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build();
         ClusterSettings clusterSettings = new ClusterSettings(settings,
             new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
@@ -151,9 +156,8 @@ public class AutodetectResultProcessorIT extends MlSingleNodeTestCase {
                 ResultsPersisterService.PERSIST_RESULTS_MAX_RETRIES,
                 ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
         ClusterService clusterService = new ClusterService(settings, clusterSettings, tp);
-
         OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN);
-        resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, settings);
+        resultsPersisterService = new ResultsPersisterService(tp, originSettingClient, clusterService, settings);
         resultProcessor = new AutodetectResultProcessor(
                 client(),
                 auditor,

+ 2 - 3
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/EstablishedMemUsageIT.java

@@ -36,7 +36,6 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.hamcrest.CoreMatchers.equalTo;
-import static org.mockito.Mockito.mock;
 
 public class EstablishedMemUsageIT extends BaseMlIntegTestCase {
 
@@ -48,7 +47,7 @@ public class EstablishedMemUsageIT extends BaseMlIntegTestCase {
     @Before
     public void createComponents() {
         Settings settings = nodeSettings(0);
-        ThreadPool tp = mock(ThreadPool.class);
+        ThreadPool tp = mockThreadPool();
         ClusterSettings clusterSettings = new ClusterSettings(settings,
             new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
                 MasterService.MASTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING,
@@ -59,7 +58,7 @@ public class EstablishedMemUsageIT extends BaseMlIntegTestCase {
         ClusterService clusterService = new ClusterService(settings, clusterSettings, tp);
 
         OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN);
-        ResultsPersisterService resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, settings);
+        ResultsPersisterService resultsPersisterService = new ResultsPersisterService(tp, originSettingClient, clusterService, settings);
         jobResultsProvider = new JobResultsProvider(client(), settings, new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)));
         jobResultsPersister = new JobResultsPersister(
             originSettingClient, resultsPersisterService, new AnomalyDetectionAuditor(client(), clusterService));

+ 2 - 3
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobResultsProviderIT.java

@@ -100,7 +100,6 @@ import static org.hamcrest.Matchers.in;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.collection.IsEmptyCollection.empty;
 import static org.hamcrest.core.Is.is;
-import static org.mockito.Mockito.mock;
 
 
 public class JobResultsProviderIT extends MlSingleNodeTestCase {
@@ -114,7 +113,7 @@ public class JobResultsProviderIT extends MlSingleNodeTestCase {
         Settings.Builder builder = Settings.builder()
                 .put(UnassignedInfo.INDEX_DELAYED_NODE_LEFT_TIMEOUT_SETTING.getKey(), TimeValue.timeValueSeconds(1));
         jobProvider = new JobResultsProvider(client(), builder.build(), new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)));
-        ThreadPool tp = mock(ThreadPool.class);
+        ThreadPool tp = mockThreadPool();
         ClusterSettings clusterSettings = new ClusterSettings(builder.build(),
             new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
                 MasterService.MASTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING,
@@ -125,7 +124,7 @@ public class JobResultsProviderIT extends MlSingleNodeTestCase {
         ClusterService clusterService = new ClusterService(builder.build(), clusterSettings, tp);
 
         OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN);
-        resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, builder.build());
+        resultsPersisterService = new ResultsPersisterService(tp, originSettingClient, clusterService, builder.build());
         auditor = new AnomalyDetectionAuditor(client(), clusterService);
         waitForMlTemplates();
     }

+ 2 - 3
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobStorageDeletionTaskIT.java

@@ -47,7 +47,6 @@ import java.util.concurrent.atomic.AtomicReference;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
-import static org.mockito.Mockito.mock;
 
 /**
  * Test that ML does not touch unnecessary indices when removing job index aliases
@@ -63,7 +62,7 @@ public class JobStorageDeletionTaskIT extends BaseMlIntegTestCase {
     @Before
     public void createComponents() {
         Settings settings = nodeSettings(0);
-        ThreadPool tp = mock(ThreadPool.class);
+        ThreadPool tp = mockThreadPool();
         ClusterSettings clusterSettings = new ClusterSettings(settings,
             new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
                 MasterService.MASTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING,
@@ -73,7 +72,7 @@ public class JobStorageDeletionTaskIT extends BaseMlIntegTestCase {
                 ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
         ClusterService clusterService = new ClusterService(settings, clusterSettings, tp);
         OriginSettingClient originSettingClient = new OriginSettingClient(client(), ClientHelper.ML_ORIGIN);
-        ResultsPersisterService resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, settings);
+        ResultsPersisterService resultsPersisterService = new ResultsPersisterService(tp, originSettingClient, clusterService, settings);
         jobResultsProvider = new JobResultsProvider(client(), settings, new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY)));
         jobResultsPersister = new JobResultsPersister(
             originSettingClient, resultsPersisterService, new AnomalyDetectionAuditor(client(), clusterService));

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -635,7 +635,12 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService);
         this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor);
         OriginSettingClient originSettingClient = new OriginSettingClient(client, ClientHelper.ML_ORIGIN);
-        ResultsPersisterService resultsPersisterService = new ResultsPersisterService(originSettingClient, clusterService, settings);
+        ResultsPersisterService resultsPersisterService = new ResultsPersisterService(
+            threadPool,
+            originSettingClient,
+            clusterService,
+            settings
+        );
         AnnotationPersister anomalyDetectionAnnotationPersister = new AnnotationPersister(resultsPersisterService);
         JobResultsProvider jobResultsProvider = new JobResultsProvider(client, settings, indexNameExpressionResolver);
         JobResultsPersister jobResultsPersister =

+ 237 - 115
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterService.java

@@ -8,9 +8,9 @@ package org.elasticsearch.xpack.ml.utils.persistence;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ExceptionsHelper;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.bulk.BulkAction;
 import org.elasticsearch.action.bulk.BulkItemResponse;
 import org.elasticsearch.action.bulk.BulkRequest;
@@ -18,11 +18,11 @@ import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.RetryableAction;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.OriginSettingClient;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.CheckedConsumer;
-import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
@@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
 
 import java.io.IOException;
@@ -38,14 +39,14 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Map;
-import java.util.Random;
 import java.util.Set;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
-import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.ExceptionsHelper.status;
+import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
 
 public class ResultsPersisterService {
     /**
@@ -79,20 +80,16 @@ public class ResultsPersisterService {
     // Having an exponent higher than this causes integer overflow
     private static final int MAX_RETRY_EXPONENT = 24;
 
-    private final CheckedConsumer<Integer, InterruptedException> sleeper;
+    private final ThreadPool threadPool;
     private final OriginSettingClient client;
     private volatile int maxFailureRetries;
 
-    public ResultsPersisterService(OriginSettingClient client, ClusterService clusterService, Settings settings) {
-        this(Thread::sleep, client, clusterService, settings);
-    }
-
     // Visible for testing
-    ResultsPersisterService(CheckedConsumer<Integer, InterruptedException> sleeper,
-                            OriginSettingClient client,
-                            ClusterService clusterService,
-                            Settings settings) {
-        this.sleeper = sleeper;
+    public ResultsPersisterService(ThreadPool threadPool,
+                                   OriginSettingClient client,
+                                   ClusterService clusterService,
+                                   Settings settings) {
+        this.threadPool = threadPool;
         this.client = client;
         this.maxFailureRetries = PERSIST_RESULTS_MAX_RETRIES.get(settings);
         clusterService.getClusterSettings()
@@ -123,8 +120,11 @@ public class ResultsPersisterService {
                                            String jobId,
                                            Supplier<Boolean> shouldRetry,
                                            Consumer<String> retryMsgHandler) {
-        return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, retryMsgHandler,
-            providedBulkRequest -> client.bulk(providedBulkRequest).actionGet());
+        return bulkIndexWithRetry(bulkRequest,
+            jobId,
+            shouldRetry,
+            retryMsgHandler,
+            client::bulk);
     }
 
     public BulkResponse bulkIndexWithHeadersWithRetry(Map<String, String> headers,
@@ -132,73 +132,58 @@ public class ResultsPersisterService {
                                                       String jobId,
                                                       Supplier<Boolean> shouldRetry,
                                                       Consumer<String> retryMsgHandler) {
-        return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, retryMsgHandler,
-            providedBulkRequest -> ClientHelper.executeWithHeaders(
-                headers, ClientHelper.ML_ORIGIN, client, () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()));
+        return bulkIndexWithRetry(bulkRequest,
+            jobId,
+            shouldRetry,
+            retryMsgHandler,
+            (providedBulkRequest, listener) -> ClientHelper.executeWithHeadersAsync(
+                headers,
+                ClientHelper.ML_ORIGIN,
+                client,
+                BulkAction.INSTANCE,
+                providedBulkRequest,
+                listener));
     }
 
     private BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
                                             String jobId,
                                             Supplier<Boolean> shouldRetry,
                                             Consumer<String> retryMsgHandler,
-                                            Function<BulkRequest, BulkResponse> actionExecutor) {
-        RetryContext retryContext = new RetryContext(jobId, shouldRetry, retryMsgHandler);
-        while (true) {
-            BulkResponse bulkResponse = actionExecutor.apply(bulkRequest);
-            if (bulkResponse.hasFailures() == false) {
-                return bulkResponse;
-            }
-            for (BulkItemResponse itemResponse : bulkResponse.getItems()) {
-                if (itemResponse.isFailed() && isIrrecoverable(itemResponse.getFailure().getCause())) {
-                    Throwable unwrappedParticular = ExceptionsHelper.unwrapCause(itemResponse.getFailure().getCause());
-                    LOGGER.warn(new ParameterizedMessage(
-                        "[{}] experienced failure that cannot be automatically retried. Bulk failure message [{}]",
-                            jobId,
-                            bulkResponse.buildFailureMessage()),
-                        unwrappedParticular);
-                    throw new ElasticsearchStatusException(
-                        "{} experienced failure that cannot be automatically retried. See logs for bulk failures",
-                        status(unwrappedParticular),
-                        unwrappedParticular,
-                        jobId);
-                }
-            }
-            retryContext.nextIteration("index", bulkResponse.buildFailureMessage());
-            // We should only retry the docs that failed.
-            bulkRequest = buildNewRequestFromFailures(bulkRequest, bulkResponse);
-        }
+                                            BiConsumer<BulkRequest, ActionListener<BulkResponse>> actionExecutor) {
+        PlainActionFuture<BulkResponse> getResponse = PlainActionFuture.newFuture();
+        BulkRetryableAction bulkRetryableAction = new BulkRetryableAction(
+            jobId,
+            new BulkRequestRewriter(bulkRequest),
+            shouldRetry,
+            retryMsgHandler,
+            actionExecutor,
+            getResponse
+        );
+        bulkRetryableAction.run();
+        return getResponse.actionGet();
     }
 
     public SearchResponse searchWithRetry(SearchRequest searchRequest,
                                           String jobId,
                                           Supplier<Boolean> shouldRetry,
                                           Consumer<String> retryMsgHandler) {
-        RetryContext retryContext = new RetryContext(jobId, shouldRetry, retryMsgHandler);
-        while (true) {
-            String failureMessage;
-            try {
-                SearchResponse searchResponse = client.search(searchRequest).actionGet();
-                if (RestStatus.OK.equals(searchResponse.status())) {
-                    return searchResponse;
-                }
-                failureMessage = searchResponse.status().toString();
-            } catch (ElasticsearchException e) {
-                LOGGER.warn("[" + jobId + "] Exception while executing search action", e);
-                failureMessage = e.getDetailedMessage();
-                if (isIrrecoverable(e)) {
-                    LOGGER.warn(new ParameterizedMessage("[{}] experienced failure that cannot be automatically retried", jobId), e);
-                    throw new ElasticsearchStatusException(
-                        "{} experienced failure that cannot be automatically retried",
-                        status(e),
-                        e,
-                        jobId);
-                }
-            }
+        PlainActionFuture<SearchResponse> getResponse = PlainActionFuture.newFuture();
+        SearchRetryableAction mlRetryableAction = new SearchRetryableAction(
+            jobId,
+            searchRequest,
+            shouldRetry,
+            retryMsgHandler,
+            getResponse);
+        mlRetryableAction.run();
+        return getResponse.actionGet();
+    }
 
-            retryContext.nextIteration("search", failureMessage);
+    static class RecoverableException extends Exception { }
+    static class IrrecoverableException extends ElasticsearchStatusException {
+        IrrecoverableException(String msg, RestStatus status, Throwable cause, Object... args) {
+            super(msg, status, cause, args);
         }
     }
-
     /**
      * @param ex The exception to check
      * @return true when the failure will persist no matter how many times we retry.
@@ -208,48 +193,194 @@ public class ResultsPersisterService {
         return IRRECOVERABLE_REST_STATUSES.contains(status(t));
     }
 
-    /**
-     * {@link RetryContext} object handles logic that is executed between consecutive retries of an action.
-     *
-     * Note that it does not execute the action itself.
-     */
-    private class RetryContext {
+    @SuppressWarnings("NonAtomicOperationOnVolatileField")
+    private static class BulkRequestRewriter {
+        private volatile BulkRequest bulkRequest;
+        BulkRequestRewriter(BulkRequest initialRequest) {
+            this.bulkRequest = initialRequest;
+        }
+
+        void rewriteRequest(BulkResponse bulkResponse) {
+            if (bulkResponse.hasFailures() == false) {
+                return;
+            }
+            bulkRequest = buildNewRequestFromFailures(bulkRequest, bulkResponse);
+        }
+
+        BulkRequest getBulkRequest() {
+            return bulkRequest;
+        }
+
+    }
+
+    private class BulkRetryableAction extends MlRetryableAction<BulkRequest, BulkResponse> {
+        private final BulkRequestRewriter bulkRequestRewriter;
+        BulkRetryableAction(String jobId,
+                            BulkRequestRewriter bulkRequestRewriter,
+                            Supplier<Boolean> shouldRetry,
+                            Consumer<String> msgHandler,
+                            BiConsumer<BulkRequest, ActionListener<BulkResponse>> actionExecutor,
+                            ActionListener<BulkResponse> listener) {
+            super(jobId,
+                shouldRetry,
+                msgHandler,
+                (request, retryableListener) -> actionExecutor.accept(request, ActionListener.wrap(
+                    bulkResponse -> {
+                        if (bulkResponse.hasFailures() == false) {
+                            retryableListener.onResponse(bulkResponse);
+                            return;
+                        }
+                        for (BulkItemResponse itemResponse : bulkResponse.getItems()) {
+                            if (itemResponse.isFailed() && isIrrecoverable(itemResponse.getFailure().getCause())) {
+                                Throwable unwrappedParticular = ExceptionsHelper.unwrapCause(itemResponse.getFailure().getCause());
+                                LOGGER.warn(new ParameterizedMessage(
+                                        "[{}] experienced failure that cannot be automatically retried. Bulk failure message [{}]",
+                                        jobId,
+                                        bulkResponse.buildFailureMessage()),
+                                    unwrappedParticular);
+                                retryableListener.onFailure(new IrrecoverableException(
+                                    "{} experienced failure that cannot be automatically retried. See logs for bulk failures",
+                                    status(unwrappedParticular),
+                                    unwrappedParticular,
+                                    jobId));
+                                return;
+                            }
+                        }
+                        bulkRequestRewriter.rewriteRequest(bulkResponse);
+                        // Let the listener attempt again with the new bulk request
+                        retryableListener.onFailure(new RecoverableException());
+                    },
+                    retryableListener::onFailure
+                )),
+                listener);
+            this.bulkRequestRewriter = bulkRequestRewriter;
+        }
+
+        @Override
+        public BulkRequest buildRequest() {
+            return bulkRequestRewriter.getBulkRequest();
+        }
+
+        @Override
+        public String getName() {
+            return "index";
+        }
+
+    }
+
+    private class SearchRetryableAction extends MlRetryableAction<SearchRequest, SearchResponse> {
+
+        private final SearchRequest searchRequest;
+        SearchRetryableAction(String jobId,
+                              SearchRequest searchRequest,
+                              Supplier<Boolean> shouldRetry,
+                              Consumer<String> msgHandler,
+                              ActionListener<SearchResponse> listener) {
+            super(jobId,
+                shouldRetry,
+                msgHandler,
+                (request, retryableListener) -> client.search(request, ActionListener.wrap(
+                    searchResponse -> {
+                        if (RestStatus.OK.equals(searchResponse.status())) {
+                            retryableListener.onResponse(searchResponse);
+                            return;
+                        }
+                        retryableListener.onFailure(
+                            new ElasticsearchStatusException(
+                                "search failed with status {}",
+                                searchResponse.status(),
+                                searchResponse.status())
+                        );
+                    },
+                    retryableListener::onFailure
+                )),
+                listener);
+            this.searchRequest = searchRequest;
+        }
+
+        @Override
+        public SearchRequest buildRequest() {
+            return searchRequest;
+        }
+
+        @Override
+        public String getName() {
+            return "search";
+        }
+    }
+
+    // This encapsulates a retryable action that implements our custom backoff retry logic
+    private abstract class MlRetryableAction<Request, Response> extends RetryableAction<Response> {
 
         final String jobId;
         final Supplier<Boolean> shouldRetry;
         final Consumer<String> msgHandler;
-        final Random random = Randomness.get();
-
-        int currentAttempt = 0;
-        int currentMin = MIN_RETRY_SLEEP_MILLIS;
-        int currentMax = MIN_RETRY_SLEEP_MILLIS;
+        final BiConsumer<Request, ActionListener<Response>> action;
+        volatile int currentAttempt = 0;
+        volatile long currentMin = MIN_RETRY_SLEEP_MILLIS;
+        volatile long currentMax = MIN_RETRY_SLEEP_MILLIS;
 
-        RetryContext(String jobId, Supplier<Boolean> shouldRetry, Consumer<String> msgHandler) {
+        MlRetryableAction(String jobId,
+                          Supplier<Boolean> shouldRetry,
+                          Consumer<String> msgHandler,
+                          BiConsumer<Request, ActionListener<Response>> action,
+                          ActionListener<Response> listener) {
+            super(
+                LOGGER,
+                threadPool,
+                TimeValue.timeValueMillis(MIN_RETRY_SLEEP_MILLIS),
+                TimeValue.MAX_VALUE,
+                listener,
+                UTILITY_THREAD_POOL_NAME);
             this.jobId = jobId;
             this.shouldRetry = shouldRetry;
             this.msgHandler = msgHandler;
+            this.action = action;
         }
 
-        void nextIteration(String actionName, String failureMessage) {
+        public abstract Request buildRequest();
+
+        public abstract String getName();
+
+        @Override
+        public void tryAction(ActionListener<Response> listener) {
             currentAttempt++;
+            action.accept(buildRequest(), listener);
+        }
+
+        @Override
+        public boolean shouldRetry(Exception e) {
+            if (isIrrecoverable(e)) {
+                LOGGER.warn(new ParameterizedMessage("[{}] experienced failure that cannot be automatically retried", jobId), e);
+                return false;
+            }
 
             // If the outside conditions have changed and retries are no longer needed, do not retry.
             if (shouldRetry.get() == false) {
-                String msg = new ParameterizedMessage(
-                    "[{}] should not retry {} after [{}] attempts. {}", jobId, actionName, currentAttempt, failureMessage)
-                    .getFormattedMessage();
-                LOGGER.info(msg);
-                throw new ElasticsearchException(msg);
+                LOGGER.info(() -> new ParameterizedMessage(
+                    "[{}] should not retry {} after [{}] attempts",
+                    jobId,
+                    getName(),
+                    currentAttempt
+                ), e);
+                return false;
             }
 
             // If the configured maximum number of retries has been reached, do not retry.
             if (currentAttempt > maxFailureRetries) {
-                String msg = new ParameterizedMessage(
-                    "[{}] failed to {} after [{}] attempts. {}", jobId, actionName, currentAttempt, failureMessage).getFormattedMessage();
-                LOGGER.warn(msg);
-                throw new ElasticsearchException(msg);
+                LOGGER.warn(() -> new ParameterizedMessage(
+                    "[{}] failed to {} after [{}] attempts.",
+                    jobId,
+                    getName(),
+                    currentAttempt
+                ), e);
+                return false;
             }
+            return true;
+        }
 
+        @Override
+        protected long calculateDelay(long previousDelay) {
             // Since we exponentially increase, we don't want force randomness to have an excessively long sleep
             if (currentMax < MAX_RETRY_SLEEP_MILLIS) {
                 currentMin = currentMax;
@@ -259,33 +390,24 @@ public class ResultsPersisterService {
             currentMax = Math.min(uncappedBackoff, MAX_RETRY_SLEEP_MILLIS);
             // Its good to have a random window along the exponentially increasing curve
             // so that not all bulk requests rest for the same amount of time
-            int randBound = 1 + (currentMax - currentMin);
-            int randSleep = currentMin + random.nextInt(randBound);
-            {
-                String msg = new ParameterizedMessage(
-                    "failed to {} after [{}] attempts. Will attempt again in [{}].",
-                    actionName,
-                    currentAttempt,
-                    TimeValue.timeValueMillis(randSleep).getStringRep())
-                    .getFormattedMessage();
-                LOGGER.warn(() -> new ParameterizedMessage("[{}] {}", jobId, msg));
-                msgHandler.accept(msg);
-            }
-            try {
-                sleeper.accept(randSleep);
-            } catch (InterruptedException interruptedException) {
-                LOGGER.warn(
-                    new ParameterizedMessage("[{}] failed to {} after [{}] attempts due to interruption",
-                        jobId,
-                        actionName,
-                        currentAttempt),
-                    interruptedException);
-                Thread.currentThread().interrupt();
-            }
+            int randBound = (int)(1 + (currentMax - currentMin));
+            String msg = new ParameterizedMessage(
+                "failed to {} after [{}] attempts. Will attempt again.",
+                getName(),
+                currentAttempt)
+                .getFormattedMessage();
+            LOGGER.warn(() -> new ParameterizedMessage("[{}] {}", jobId, msg));
+            msgHandler.accept(msg);
+            return randBound;
+        }
+
+        @Override
+        protected long minimumDelayMillis() {
+            return currentMin;
         }
     }
 
-    private BulkRequest buildNewRequestFromFailures(BulkRequest bulkRequest, BulkResponse bulkResponse) {
+    private static BulkRequest buildNewRequestFromFailures(BulkRequest bulkRequest, BulkResponse bulkResponse) {
         // If we failed, lets set the bulkRequest to be a collection of the failed requests
         BulkRequest bulkRequestOfFailures = new BulkRequest();
         Set<String> failedDocIds = Arrays.stream(bulkResponse.getItems())

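On the caller side the blocking entry points are unchanged; only the internals move from sleeping in a loop to scheduled, cancellable retries. A hedged usage sketch based on the searchWithRetry signature above (the service instance, request, and logger are assumed to already exist, and the job id is made up):

import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

// Hedged sketch only -- constructing the service itself requires a client, cluster service and threadpool,
// as shown in the MachineLearning.java hunk above.
final class SearchWithRetryUsageSketch {
    static SearchResponse searchResults(ResultsPersisterService resultsPersisterService,
                                        SearchRequest searchRequest,
                                        Logger logger) {
        return resultsPersisterService.searchWithRetry(
            searchRequest,
            "my-job",                                // jobId, only used in log and retry messages
            () -> true,                              // keep retrying while the caller still wants the result
            msg -> logger.warn("[my-job] {}", msg)); // retryMsgHandler receives one message per retry
    }
}
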
+ 76 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java

@@ -9,24 +9,42 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.ingest.common.IngestCommonPlugin;
 import org.elasticsearch.license.LicenseService;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.script.IngestScript;
+import org.elasticsearch.script.MockDeterministicScript;
+import org.elasticsearch.script.MockScriptEngine;
+import org.elasticsearch.script.MockScriptPlugin;
+import org.elasticsearch.script.ScoreScript;
+import org.elasticsearch.script.ScriptContext;
+import org.elasticsearch.script.ScriptEngine;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.ESSingleNodeTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.XPackSettings;
 import org.elasticsearch.xpack.core.ilm.LifecycleSettings;
 import org.elasticsearch.xpack.core.ml.MachineLearningField;
+import org.elasticsearch.xpack.datastreams.DataStreamsPlugin;
 import org.elasticsearch.xpack.ilm.IndexLifecycle;
 
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Map;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
+import java.util.function.Function;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * An extension to {@link ESSingleNodeTestCase} that adds node settings specifically needed for ML test cases.
@@ -60,6 +78,9 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
     protected Collection<Class<? extends Plugin>> getPlugins() {
         return pluginList(
             LocalStateMachineLearning.class,
+            DataStreamsPlugin.class,
+            IngestCommonPlugin.class,
+            MockPainlessScriptEngine.TestPlugin.class,
             // ILM is required for .ml-state template index settings
             IndexLifecycle.class);
     }
@@ -128,6 +149,24 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
         return responseHolder.get();
     }
 
+    protected static ThreadPool mockThreadPool() {
+        ThreadPool tp = mock(ThreadPool.class);
+        ExecutorService executor = mock(ExecutorService.class);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(executor).execute(any(Runnable.class));
+        when(tp.executor(any(String.class))).thenReturn(executor);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(tp).schedule(
+            any(Runnable.class), any(TimeValue.class), any(String.class)
+        );
+        return tp;
+    }
+
+
     public static void assertNoException(AtomicReference<Exception> error) throws Exception {
         if (error.get() == null) {
             return;
@@ -135,4 +174,41 @@ public abstract class MlSingleNodeTestCase extends ESSingleNodeTestCase {
         throw error.get();
     }
 
+    public static class MockPainlessScriptEngine extends MockScriptEngine {
+
+        public static final String NAME = "painless";
+
+        public static class TestPlugin extends MockScriptPlugin {
+            @Override
+            public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
+                return new MockPainlessScriptEngine();
+            }
+
+            @Override
+            protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
+                return Collections.emptyMap();
+            }
+        }
+
+        @Override
+        public String getType() {
+            return NAME;
+        }
+
+        @Override
+        public <T> T compile(String name, String script, ScriptContext<T> context, Map<String, String> options) {
+            if (context.instanceClazz.equals(ScoreScript.class)) {
+                return context.factoryClazz.cast(new MockScoreScript(MockDeterministicScript.asDeterministic(p -> 0.0)));
+            }
+            if (context.name.equals("ingest")) {
+                IngestScript.Factory factory = vars -> new IngestScript(vars) {
+                    @Override
+                    public void execute(Map<String, Object> ctx) {
+                    }
+                };
+                return context.factoryClazz.cast(factory);
+            }
+            throw new IllegalArgumentException("mock painless does not know how to handle context [" + context.name + "]");
+        }
+    }
 }

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/annotations/AnnotationPersisterTests.java

@@ -189,7 +189,7 @@ public class AnnotationPersisterTests extends ESTestCase {
             .persistAnnotation("1", AnnotationTests.randomAnnotation(JOB_ID))
             .persistAnnotation("2", AnnotationTests.randomAnnotation(JOB_ID));
         ElasticsearchException e = expectThrows(ElasticsearchException.class, persisterBuilder::executeRequest);
-        assertThat(e.getMessage(), containsString("failed to index after"));
+        assertThat(e.getMessage(), containsString("Failed execution"));
 
         verify(client, atLeastOnce()).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any());
 

+ 15 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsPersisterTests.java

@@ -23,6 +23,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.MasterService;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
@@ -55,6 +56,7 @@ import java.util.Date;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ExecutorService;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Matchers.any;
@@ -392,8 +394,20 @@ public class JobResultsPersisterTests extends ESTestCase {
                 ClusterService.USER_DEFINED_METADATA,
                 ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
         ClusterService clusterService = new ClusterService(Settings.EMPTY, clusterSettings, tp);
+        ExecutorService executor = mock(ExecutorService.class);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(executor).execute(any(Runnable.class));
+        when(tp.executor(any(String.class))).thenReturn(executor);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(tp).schedule(
+            any(Runnable.class), any(TimeValue.class), any(String.class)
+        );
 
-        return new ResultsPersisterService(client, clusterService, Settings.EMPTY);
+        return new ResultsPersisterService(tp, client, clusterService, Settings.EMPTY);
     }
 
     private AnomalyDetectionAuditor makeAuditor() {

+ 23 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java

@@ -32,6 +32,7 @@ import org.elasticsearch.persistent.PersistentTasksClusterService;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.MockHttpTransport;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.XPackSettings;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ilm.LifecycleSettings;
@@ -72,11 +73,16 @@ import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * A base class for testing datafeed and job lifecycle specifics.
@@ -239,6 +245,23 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
         });
     }
 
+    protected static ThreadPool mockThreadPool() {
+        ThreadPool tp = mock(ThreadPool.class);
+        ExecutorService executor = mock(ExecutorService.class);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(executor).execute(any(Runnable.class));
+        when(tp.executor(any(String.class))).thenReturn(executor);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(tp).schedule(
+            any(Runnable.class), any(TimeValue.class), any(String.class)
+        );
+        return tp;
+    }
+
     public static void indexDocs(Logger logger, String index, long numDocs, long start, long end) {
         int maxDelta = (int) (end - start - 1);
         BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();

+ 24 - 12
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/ResultsPersisterServiceTests.java

@@ -25,9 +25,9 @@ import org.elasticsearch.cluster.routing.OperationRouting;
 import org.elasticsearch.cluster.service.ClusterApplierService;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.MasterService;
-import org.elasticsearch.common.CheckedConsumer;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.IndexPrimaryShardNotAllocatedException;
@@ -48,6 +48,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Supplier;
 
@@ -63,6 +64,7 @@ import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class ResultsPersisterServiceTests extends ESTestCase {
 
@@ -156,7 +158,7 @@ public class ResultsPersisterServiceTests extends ESTestCase {
             expectThrows(
                 ElasticsearchException.class,
                 () -> resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> true, messages::add));
-        assertThat(e.getMessage(), containsString("failed to search after [" + (maxFailureRetries + 1) + "] attempts."));
+        assertThat(e.getMessage(), containsString("search failed with status"));
         assertThat(messages, hasSize(maxFailureRetries));
 
         verify(client, times(maxFailureRetries + 1)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any());
@@ -183,7 +185,7 @@ public class ResultsPersisterServiceTests extends ESTestCase {
             expectThrows(
                 ElasticsearchException.class,
                 () -> resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> false, messages::add));
-        assertThat(e.getMessage(), containsString("should not retry search after [1] attempts. SERVICE_UNAVAILABLE"));
+        assertThat(e.getMessage(), containsString("search failed with status SERVICE_UNAVAILABLE"));
         assertThat(messages, empty());
 
         verify(client, times(1)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any());
@@ -203,7 +205,7 @@ public class ResultsPersisterServiceTests extends ESTestCase {
                 ElasticsearchException.class,
                 () -> resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, shouldRetryUntil(maxRetries), messages::add));
         assertThat(
-            e.getMessage(), containsString("should not retry search after [" + (maxRetries + 1) + "] attempts. SERVICE_UNAVAILABLE"));
+            e.getMessage(), containsString("search failed with status SERVICE_UNAVAILABLE"));
         assertThat(messages, hasSize(maxRetries));
 
         verify(client, times(maxRetries + 1)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any());
@@ -219,9 +221,9 @@ public class ResultsPersisterServiceTests extends ESTestCase {
             expectThrows(
                 ElasticsearchException.class,
                 () -> resultsPersisterService.searchWithRetry(SEARCH_REQUEST, JOB_ID, () -> true, (s) -> {}));
-        assertThat(e.getMessage(), containsString("experienced failure that cannot be automatically retried"));
+        assertThat(e.getMessage(), containsString("bad search request"));
 
-        verify(client).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any());
+        verify(client, times(1)).execute(eq(SearchAction.INSTANCE), eq(SEARCH_REQUEST), any());
     }
 
     private static Supplier<Boolean> shouldRetryUntil(int maxRetries) {
@@ -255,7 +257,7 @@ public class ResultsPersisterServiceTests extends ESTestCase {
 
         assertThat(requests.get(0).numberOfActions(), equalTo(2));
         assertThat(requests.get(1).numberOfActions(), equalTo(1));
-        assertThat(lastMessage.get(), containsString("failed to index after [1] attempts. Will attempt again in"));
+        assertThat(lastMessage.get(), containsString("failed to index after [1] attempts. Will attempt again"));
     }
 
     public void testBulkRequestChangeOnIrrecoverableFailures() {
@@ -316,7 +318,7 @@ public class ResultsPersisterServiceTests extends ESTestCase {
             () -> resultsPersisterService.bulkIndexWithRetry(bulkRequest, JOB_ID, () -> true, lastMessage::set));
         verify(client, times(maxFailureRetries + 1)).execute(eq(BulkAction.INSTANCE), any(), any());
 
-        assertThat(lastMessage.get(), containsString("failed to index after [10] attempts. Will attempt again in"));
+        assertThat(lastMessage.get(), containsString("failed to index after [10] attempts. Will attempt again"));
     }
 
     public void testBulkRequestRetriesMsgHandlerIsCalled() {
@@ -340,7 +342,7 @@ public class ResultsPersisterServiceTests extends ESTestCase {
 
         assertThat(requests.get(0).numberOfActions(), equalTo(2));
         assertThat(requests.get(1).numberOfActions(), equalTo(1));
-        assertThat(lastMessage.get(), containsString("failed to index after [1] attempts. Will attempt again in"));
+        assertThat(lastMessage.get(), containsString("failed to index after [1] attempts. Will attempt again"));
     }
 
     private static <Response> Stubber doAnswerWithResponses(Response response1, Response response2) {
@@ -366,7 +368,6 @@ public class ResultsPersisterServiceTests extends ESTestCase {
     }
 
     public static ResultsPersisterService buildResultsPersisterService(OriginSettingClient client) {
-        CheckedConsumer<Integer, InterruptedException> sleeper = millis -> {};
         ThreadPool tp = mock(ThreadPool.class);
         ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
             new HashSet<>(Arrays.asList(InferenceProcessor.MAX_INFERENCE_PROCESSORS,
@@ -376,7 +377,18 @@ public class ResultsPersisterServiceTests extends ESTestCase {
                 ResultsPersisterService.PERSIST_RESULTS_MAX_RETRIES,
                 ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING)));
         ClusterService clusterService = new ClusterService(Settings.EMPTY, clusterSettings, tp);
-
-        return new ResultsPersisterService(sleeper, client, clusterService, Settings.EMPTY);
+        ExecutorService executor = mock(ExecutorService.class);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(executor).execute(any(Runnable.class));
+        when(tp.executor(any(String.class))).thenReturn(executor);
+        doAnswer(invocationOnMock -> {
+            ((Runnable) invocationOnMock.getArguments()[0]).run();
+            return null;
+        }).when(tp).schedule(
+            any(Runnable.class), any(TimeValue.class), any(String.class)
+        );
+        return new ResultsPersisterService(tp, client, clusterService, Settings.EMPTY);
     }
 }