Browse Source

Use Search After job iterators (#57875)

Search after is a better choice for the delete expired data iterators
where processing takes a long time because, unlike scroll, a search
context does not have to be kept alive. Also changes the delete expired
data endpoint to return a 404 if the job is unknown.
David Kyle 5 years ago
parent
commit
96a6de22d9
17 changed files with 796 additions and 333 deletions
  1. 66 25
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataAction.java
  2. 30 10
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/SearchAfterJobsIterator.java
  3. 5 62
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemover.java
  4. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemover.java
  5. 3 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemover.java
  6. 3 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIterator.java
  7. 30 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedIterator.java
  8. 180 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/SearchAfterDocumentsIterator.java
  9. 61 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/WrappedBatchedJobsIterator.java
  10. 2 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataActionTests.java
  11. 17 75
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemoverTests.java
  12. 28 32
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemoverTests.java
  13. 33 55
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemoverTests.java
  14. 122 67
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java
  15. 131 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/SearchAfterDocumentsIteratorTests.java
  16. 77 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/WrappedBatchedJobsIteratorTests.java
  17. 6 1
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/delete_expired_data.yml

+ 66 - 25
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataAction.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.OriginSettingClient;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest;
 import org.elasticsearch.tasks.Task;
@@ -21,7 +22,10 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction;
+import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.job.persistence.JobConfigProvider;
+import org.elasticsearch.xpack.ml.job.persistence.SearchAfterJobsIterator;
 import org.elasticsearch.xpack.ml.job.retention.EmptyStateIndexRemover;
 import org.elasticsearch.xpack.ml.job.retention.ExpiredForecastsRemover;
 import org.elasticsearch.xpack.ml.job.retention.ExpiredModelSnapshotsRemover;
@@ -30,6 +34,7 @@ import org.elasticsearch.xpack.ml.job.retention.MlDataRemover;
 import org.elasticsearch.xpack.ml.job.retention.UnusedStateRemover;
 import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
 import org.elasticsearch.xpack.ml.utils.VolatileCursorIterator;
+import org.elasticsearch.xpack.ml.utils.persistence.WrappedBatchedJobsIterator;
 
 import java.time.Clock;
 import java.time.Duration;
@@ -38,9 +43,10 @@ import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
 public class TransportDeleteExpiredDataAction extends HandledTransportAction<DeleteExpiredDataAction.Request,
-        DeleteExpiredDataAction.Response> {
+    DeleteExpiredDataAction.Response> {
 
     private static final Logger logger = LogManager.getLogger(TransportDeleteExpiredDataAction.class);
 
@@ -51,22 +57,26 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
     private final OriginSettingClient client;
     private final ClusterService clusterService;
     private final Clock clock;
+    private final JobConfigProvider jobConfigProvider;
 
     @Inject
     public TransportDeleteExpiredDataAction(ThreadPool threadPool, TransportService transportService,
-                                            ActionFilters actionFilters, Client client, ClusterService clusterService) {
+                                            ActionFilters actionFilters, Client client, ClusterService clusterService,
+                                            JobConfigProvider jobConfigProvider) {
         this(threadPool, MachineLearning.UTILITY_THREAD_POOL_NAME, transportService, actionFilters, client, clusterService,
-            Clock.systemUTC());
+            jobConfigProvider, Clock.systemUTC());
     }
 
     TransportDeleteExpiredDataAction(ThreadPool threadPool, String executor, TransportService transportService,
-                                     ActionFilters actionFilters, Client client, ClusterService clusterService, Clock clock) {
+                                     ActionFilters actionFilters, Client client, ClusterService clusterService,
+                                     JobConfigProvider jobConfigProvider, Clock clock) {
         super(DeleteExpiredDataAction.NAME, transportService, actionFilters, DeleteExpiredDataAction.Request::new, executor);
         this.threadPool = threadPool;
         this.executor = executor;
         this.client = new OriginSettingClient(client, ClientHelper.ML_ORIGIN);
         this.clusterService = clusterService;
         this.clock = clock;
+        this.jobConfigProvider = jobConfigProvider;
     }
 
     @Override
@@ -78,22 +88,34 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
         );
 
         Supplier<Boolean> isTimedOutSupplier = () -> Instant.now(clock).isAfter(timeoutTime);
-        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(
-            () -> deleteExpiredData(request, listener, isTimedOutSupplier)
-        );
+        AnomalyDetectionAuditor auditor = new AnomalyDetectionAuditor(client, clusterService.getNodeName());
+
+        if (Strings.isNullOrEmpty(request.getJobId()) || Strings.isAllOrWildcard(new String[]{request.getJobId()})) {
+            List<MlDataRemover> dataRemovers = createDataRemovers(client, auditor);
+            threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(
+                () -> deleteExpiredData(request, dataRemovers, listener, isTimedOutSupplier)
+            );
+        } else {
+            jobConfigProvider.expandJobs(request.getJobId(), false, true, ActionListener.wrap(
+                jobBuilders -> {
+                    threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
+                            List<Job> jobs = jobBuilders.stream().map(Job.Builder::build).collect(Collectors.toList());
+                            List<MlDataRemover> dataRemovers = createDataRemovers(jobs, auditor);
+                            deleteExpiredData(request, dataRemovers, listener, isTimedOutSupplier);
+                        }
+                    );
+                },
+                listener::onFailure
+            ));
+        }
     }
 
     private void deleteExpiredData(DeleteExpiredDataAction.Request request,
+                                   List<MlDataRemover> dataRemovers,
                                    ActionListener<DeleteExpiredDataAction.Response> listener,
                                    Supplier<Boolean> isTimedOutSupplier) {
-        AnomalyDetectionAuditor auditor = new AnomalyDetectionAuditor(client, clusterService.getNodeName());
-        List<MlDataRemover> dataRemovers = Arrays.asList(
-                new ExpiredResultsRemover(client, request.getJobId(), auditor, threadPool),
-                new ExpiredForecastsRemover(client, threadPool),
-                new ExpiredModelSnapshotsRemover(client, request.getJobId(), threadPool),
-                new UnusedStateRemover(client, clusterService),
-                new EmptyStateIndexRemover(client)
-        );
+
+
         Iterator<MlDataRemover> dataRemoversIterator = new VolatileCursorIterator<>(dataRemovers);
         // If there is no throttle provided, default to none
         float requestsPerSec = request.getRequestsPerSecond() == null ? Float.POSITIVE_INFINITY : request.getRequestsPerSecond();
@@ -103,7 +125,7 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
             //   1 million documents over 5000 seconds ~= 83 minutes.
             // If we have > 5 data nodes, we don't set our throttling.
             requestsPerSec = numberOfDatanodes < 5 ?
-                (float)(AbstractBulkByScrollRequest.DEFAULT_SCROLL_SIZE / 5) * numberOfDatanodes :
+                (float) (AbstractBulkByScrollRequest.DEFAULT_SCROLL_SIZE / 5) * numberOfDatanodes :
                 Float.POSITIVE_INFINITY;
         }
         deleteExpiredData(dataRemoversIterator, requestsPerSec, listener, isTimedOutSupplier, true);
@@ -117,15 +139,15 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
         if (haveAllPreviousDeletionsCompleted && mlDataRemoversIterator.hasNext()) {
             MlDataRemover remover = mlDataRemoversIterator.next();
             ActionListener<Boolean> nextListener = ActionListener.wrap(
-                    booleanResponse ->
-                        deleteExpiredData(
-                            mlDataRemoversIterator,
-                            requestsPerSecond,
-                            listener,
-                            isTimedOutSupplier,
-                            booleanResponse
-                        ),
-                    listener::onFailure);
+                booleanResponse ->
+                    deleteExpiredData(
+                        mlDataRemoversIterator,
+                        requestsPerSecond,
+                        listener,
+                        isTimedOutSupplier,
+                        booleanResponse
+                    ),
+                listener::onFailure);
             // Removing expired ML data and artifacts requires multiple operations.
             // These are queued up and executed sequentially in the action listener,
             // the chained calls must all run the ML utility thread pool NOT the thread
@@ -142,4 +164,23 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
             listener.onResponse(new DeleteExpiredDataAction.Response(haveAllPreviousDeletionsCompleted));
         }
     }
+
+    private List<MlDataRemover> createDataRemovers(OriginSettingClient client, AnomalyDetectionAuditor auditor) {
+        return Arrays.asList(
+            new ExpiredResultsRemover(client, new WrappedBatchedJobsIterator(new SearchAfterJobsIterator(client)), auditor, threadPool),
+            new ExpiredForecastsRemover(client, threadPool),
+            new ExpiredModelSnapshotsRemover(client, new WrappedBatchedJobsIterator(new SearchAfterJobsIterator(client)), threadPool),
+            new UnusedStateRemover(client, clusterService),
+            new EmptyStateIndexRemover(client));
+    }
+
+    private List<MlDataRemover> createDataRemovers(List<Job> jobs, AnomalyDetectionAuditor auditor) {
+        return Arrays.asList(
+            new ExpiredResultsRemover(client, new VolatileCursorIterator<>(jobs), auditor, threadPool),
+            new ExpiredForecastsRemover(client, threadPool),
+            new ExpiredModelSnapshotsRemover(client, new VolatileCursorIterator<>(jobs), threadPool),
+            new UnusedStateRemover(client, clusterService),
+            new EmptyStateIndexRemover(client));
+    }
+
 }

+ 30 - 10
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/BatchedJobsIterator.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/SearchAfterJobsIterator.java

@@ -3,44 +3,64 @@
  * or more contributor license agreements. Licensed under the Elastic License;
  * you may not use this file except in compliance with the Elastic License.
  */
+
 package org.elasticsearch.xpack.ml.job.persistence;
 
 import org.elasticsearch.ElasticsearchParseException;
 import org.elasticsearch.client.OriginSettingClient;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.sort.FieldSortBuilder;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
-import org.elasticsearch.xpack.ml.utils.persistence.BatchedDocumentsIterator;
+import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
+import org.elasticsearch.xpack.ml.utils.persistence.SearchAfterDocumentsIterator;
 
 import java.io.IOException;
 import java.io.InputStream;
 
-public class BatchedJobsIterator extends BatchedDocumentsIterator<Job.Builder> {
+public class SearchAfterJobsIterator extends SearchAfterDocumentsIterator<Job.Builder> {
 
-    private final String jobIdExpression;
+    private String lastJobId;
 
-    public BatchedJobsIterator(OriginSettingClient client, String index, String jobIdExpression) {
-        super(client, index);
-        this.jobIdExpression = jobIdExpression;
+    public SearchAfterJobsIterator(OriginSettingClient client) {
+        super(client, AnomalyDetectorsIndex.configIndexName());
     }
 
     @Override
     protected QueryBuilder getQuery() {
-        String [] tokens = Strings.tokenizeToStringArray(jobIdExpression, ",");
-        return JobConfigProvider.buildJobWildcardQuery(tokens, true);
+        return new TermQueryBuilder(Job.JOB_TYPE.getPreferredName(), Job.ANOMALY_DETECTOR_JOB_TYPE);
+    }
+
+    @Override
+    protected FieldSortBuilder sortField() {
+        return new FieldSortBuilder(Job.ID.getPreferredName());
+    }
+
+    @Override
+    protected Object[] searchAfterFields() {
+        if (lastJobId == null) {
+            return null;
+        } else {
+            return new Object[] {lastJobId};
+        }
+    }
+
+    @Override
+    protected void extractSearchAfterFields(SearchHit lastSearchHit) {
+        lastJobId = Job.extractJobIdFromDocumentId(lastSearchHit.getId());
     }
 
     @Override
     protected Job.Builder map(SearchHit hit) {
         try (InputStream stream = hit.getSourceRef().streamInput();
              XContentParser parser = XContentFactory.xContent(XContentType.JSON)
-                     .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) {
+                 .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) {
             return Job.LENIENT_PARSER.apply(parser, null);
         } catch (IOException e) {
             throw new ElasticsearchParseException("failed to parse job document [" + hit.getId() + "]", e);

+ 5 - 62
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemover.java

@@ -10,17 +10,11 @@ import org.elasticsearch.client.OriginSettingClient;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
-import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.results.Result;
-import org.elasticsearch.xpack.ml.job.persistence.BatchedJobsIterator;
-import org.elasticsearch.xpack.ml.utils.VolatileCursorIterator;
 
-import java.util.Deque;
 import java.util.Iterator;
-import java.util.List;
 import java.util.Objects;
 import java.util.function.Supplier;
-import java.util.stream.Collectors;
 
 /**
  * Removes job data that expired with respect to their retention period.
@@ -31,22 +25,22 @@ import java.util.stream.Collectors;
  */
 abstract class AbstractExpiredJobDataRemover implements MlDataRemover {
 
-    private final String jobIdExpression;
     protected final OriginSettingClient client;
+    private final Iterator<Job> jobIterator;
 
-    AbstractExpiredJobDataRemover(String jobIdExpression, OriginSettingClient client) {
-        this.jobIdExpression = jobIdExpression;
+    AbstractExpiredJobDataRemover(OriginSettingClient client, Iterator<Job> jobIterator) {
         this.client = client;
+        this.jobIterator = jobIterator;
     }
 
     @Override
     public void remove(float requestsPerSecond,
                        ActionListener<Boolean> listener,
                        Supplier<Boolean> isTimedOutSupplier) {
-        removeData(newJobIterator(), requestsPerSecond, listener, isTimedOutSupplier);
+        removeData(jobIterator, requestsPerSecond, listener, isTimedOutSupplier);
     }
 
-    private void removeData(WrappedBatchedJobsIterator jobIterator,
+    private void removeData(Iterator<Job> jobIterator,
                             float requestsPerSecond,
                             ActionListener<Boolean> listener,
                             Supplier<Boolean> isTimedOutSupplier) {
@@ -86,11 +80,6 @@ abstract class AbstractExpiredJobDataRemover implements MlDataRemover {
         ));
     }
 
-    private WrappedBatchedJobsIterator newJobIterator() {
-        BatchedJobsIterator jobsIterator = new BatchedJobsIterator(client, AnomalyDetectorsIndex.configIndexName(), jobIdExpression);
-        return new WrappedBatchedJobsIterator(jobsIterator);
-    }
-
     abstract void calcCutoffEpochMs(String jobId, long retentionDays, ActionListener<CutoffDetails> listener);
 
     abstract Long getRetentionDays(Job job);
@@ -147,50 +136,4 @@ abstract class AbstractExpiredJobDataRemover implements MlDataRemover {
                 this.cutoffEpochMs == that.cutoffEpochMs;
         }
     }
-
-    /**
-     * A wrapper around {@link BatchedJobsIterator} that allows iterating jobs one
-     * at a time from the batches returned by {@code BatchedJobsIterator}
-     *
-     * This class abstracts away the logic of pulling one job at a time from
-     * multiple batches.
-     */
-    private static class WrappedBatchedJobsIterator implements Iterator<Job> {
-        private final BatchedJobsIterator batchedIterator;
-        private VolatileCursorIterator<Job> currentBatch;
-
-        WrappedBatchedJobsIterator(BatchedJobsIterator batchedIterator) {
-            this.batchedIterator = batchedIterator;
-        }
-
-        @Override
-        public boolean hasNext() {
-            return (currentBatch != null && currentBatch.hasNext()) || batchedIterator.hasNext();
-        }
-
-        /**
-         * Before BatchedJobsIterator has run a search it reports hasNext == true
-         * but the first search may return no results. In that case null is return
-         * and clients have to handle null.
-         */
-        @Override
-        public Job next() {
-            if (currentBatch != null && currentBatch.hasNext()) {
-                return currentBatch.next();
-            }
-
-            // currentBatch is either null or all its elements have been iterated.
-            // get the next currentBatch
-            currentBatch = createBatchIteratorFromBatch(batchedIterator.next());
-
-            // BatchedJobsIterator.hasNext maybe true if searching the first time
-            // but no results are returned.
-            return currentBatch.hasNext() ? currentBatch.next() : null;
-        }
-
-        private VolatileCursorIterator<Job> createBatchIteratorFromBatch(Deque<Job.Builder> builders) {
-            List<Job> jobs = builders.stream().map(Job.Builder::build).collect(Collectors.toList());
-            return new VolatileCursorIterator<>(jobs);
-        }
-    }
 }

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemover.java

@@ -65,8 +65,8 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover
 
     private final ThreadPool threadPool;
 
-    public ExpiredModelSnapshotsRemover(OriginSettingClient client, String jobIdExpression, ThreadPool threadPool) {
-        super(jobIdExpression, client);
+    public ExpiredModelSnapshotsRemover(OriginSettingClient client, Iterator<Job> jobIterator, ThreadPool threadPool) {
+        super(client, jobIterator);
         this.threadPool = Objects.requireNonNull(threadPool);
     }
 

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemover.java

@@ -50,6 +50,7 @@ import java.time.Instant;
 import java.time.ZoneOffset;
 import java.time.ZonedDateTime;
 import java.time.format.DateTimeFormatter;
+import java.util.Iterator;
 import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 
@@ -70,9 +71,9 @@ public class ExpiredResultsRemover extends AbstractExpiredJobDataRemover {
     private final AnomalyDetectionAuditor auditor;
     private final ThreadPool threadPool;
 
-    public ExpiredResultsRemover(OriginSettingClient client, String jobIdExpression,
+    public ExpiredResultsRemover(OriginSettingClient client, Iterator<Job> jobIterator,
                                  AnomalyDetectionAuditor auditor, ThreadPool threadPool) {
-        super(jobIdExpression, client);
+        super(client, jobIterator);
         this.auditor = Objects.requireNonNull(auditor);
         this.threadPool = Objects.requireNonNull(threadPool);
     }

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIterator.java

@@ -28,7 +28,7 @@ import java.util.Objects;
  * An iterator useful to fetch a big number of documents of type T
  * and iterate through them in batches.
  */
-public abstract class BatchedDocumentsIterator<T>  {
+public abstract class BatchedDocumentsIterator<T> implements BatchedIterator<T>  {
     private static final Logger LOGGER = LogManager.getLogger(BatchedDocumentsIterator.class);
 
     private static final String CONTEXT_ALIVE_DURATION = "5m";
@@ -56,6 +56,7 @@ public abstract class BatchedDocumentsIterator<T>  {
      *
      * @return {@code true} if the iteration has more elements
      */
+    @Override
     public boolean hasNext() {
         return !isScrollInitialised || count != totalHits;
     }
@@ -70,6 +71,7 @@ public abstract class BatchedDocumentsIterator<T>  {
      * @return a {@code Deque} with the next batch of documents
      * @throws NoSuchElementException if the iteration has no more elements
      */
+    @Override
     public Deque<T> next() {
         if (!hasNext()) {
             throw new NoSuchElementException();

+ 30 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedIterator.java

@@ -0,0 +1,30 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import java.util.Deque;
+
+/**
+ * An iterator of batches of objects
+ */
+public interface BatchedIterator<T> {
+
+    /**
+     * Returns {@code true} if the iteration has more elements.
+     * (In other words, returns {@code true} if {@link #next} would
+     * return an element rather than throwing an exception.)
+     *
+     * @return {@code true} if the iteration has more elements
+     */
+    boolean hasNext();
+
+    /**
+     * Get the next batch or throw.
+     * @return The next batch
+     */
+    Deque<T> next();
+}

+ 180 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/SearchAfterDocumentsIterator.java

@@ -0,0 +1,180 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.OriginSettingClient;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.sort.FieldSortBuilder;
+import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.NoSuchElementException;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * An iterator useful to fetch a large number of documents of type T
+ * and iterate through them in batches of 10,000.
+ *
+ * In terms of functionality this is very similar to {@link BatchedDocumentsIterator}
+ * the difference being that this uses search after rather than scroll.
+ *
+ * Search after has the advantage that the scroll context does not have to be kept
+ * alive so if processing each batch takes a long time search after should be
+ * preferred to scroll.
+ *
+ * Documents in the index may be deleted or updated between search after calls
+ * so it is possible that the total hits can change. For this reason the hit
+ * count isn't a reliable indicator of progress and the iterator will judge that
+ * it has reached the end of the search only when less than {@value #BATCH_SIZE}
+ * hits are returned.
+ */
+public abstract class SearchAfterDocumentsIterator<T> implements BatchedIterator<T> {
+
+    private static final int BATCH_SIZE = 10_000;
+
+    private final OriginSettingClient client;
+    private final String index;
+    private final AtomicBoolean lastSearchReturnedResults;
+    private int batchSize = BATCH_SIZE;
+
+    protected SearchAfterDocumentsIterator(OriginSettingClient client, String index) {
+        this.client = Objects.requireNonNull(client);
+        this.index = Objects.requireNonNull(index);
+        this.lastSearchReturnedResults = new AtomicBoolean(true);
+    }
+
+    /**
+     * Returns {@code true} if the iteration has more elements or
+ * no searches have been run and it is unknown if there is a next.
+     *
+     * Because the index may change between search after calls it is not possible
+     * to know how many results will be returned until all have been seen.
+ * For this reason it is possible that {@code hasNext} will return true even
+     * if the next search returns 0 search hits. In that case {@link #next()}
+     * will return an empty collection.
+     *
+     * @return {@code true} if the iteration has more elements or the first
+     * search has not been run
+     */
+    @Override
+    public boolean hasNext() {
+        return lastSearchReturnedResults.get();
+    }
+
+    /**
+     * The first time next() is called, the search will be performed and the first
+     * batch will be returned. Subsequent calls will return the following batches.
+     *
+     * Note it is possible that when there are no results at all, the first time
+     * this method is called an empty {@code Deque} is returned.
+     *
+     * @return a {@code Deque} with the next batch of documents
+     * @throws NoSuchElementException if the iteration has no more elements
+     */
+    @Override
+    public Deque<T> next() {
+        if (!hasNext()) {
+            throw new NoSuchElementException();
+        }
+
+        SearchResponse searchResponse = doSearch(searchAfterFields());
+        return mapHits(searchResponse);
+    }
+
+    private SearchResponse doSearch(Object [] searchAfterValues) {
+        SearchRequest searchRequest = new SearchRequest(index);
+        searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS));
+        SearchSourceBuilder sourceBuilder = (new SearchSourceBuilder()
+            .size(batchSize)
+            .query(getQuery())
+            .fetchSource(shouldFetchSource())
+            .sort(sortField()));
+
+        if (searchAfterValues != null) {
+            sourceBuilder.searchAfter(searchAfterValues);
+        }
+
+        searchRequest.source(sourceBuilder);
+        return client.search(searchRequest).actionGet();
+    }
+
+    private Deque<T> mapHits(SearchResponse searchResponse) {
+        Deque<T> results = new ArrayDeque<>();
+
+        SearchHit[] hits = searchResponse.getHits().getHits();
+        for (SearchHit hit : hits) {
+            T mapped = map(hit);
+            if (mapped != null) {
+                results.add(mapped);
+            }
+        }
+
+        // fewer hits than we requested, this is the end of the search
+        if (hits.length < batchSize) {
+            lastSearchReturnedResults.set(false);
+        }
+
+        if (hits.length > 0) {
+            extractSearchAfterFields(hits[hits.length - 1]);
+        }
+
+        return results;
+    }
+
+    /**
+     * Should fetch source? Defaults to {@code true}
+     * @return whether the source should be fetched
+     */
+    protected boolean shouldFetchSource() {
+        return true;
+    }
+
+    /**
+     * Get the query to use for the search
+     * @return the search query
+     */
+    protected abstract QueryBuilder getQuery();
+
+    /**
+     * The field to sort results on. This should have a unique value per document
+     * for search after.
+     * @return The sort field
+     */
+    protected abstract FieldSortBuilder sortField();
+
+    /**
+     * Maps the search hit to the document type
+     * @param hit
+     *            the search hit
+     * @return The mapped document or {@code null} if the mapping failed
+     */
+    protected abstract T map(SearchHit hit);
+
+    /**
+     * The field to be used in the next search
+     * @return The search after fields
+     */
+    protected abstract Object[] searchAfterFields();
+
+    /**
+     * Extract the fields used in search after from the search hit.
+     * The values are stashed and later returned by {@link #searchAfterFields()}
+     * @param lastSearchHit The last search hit in the previous search response
+     */
+    protected abstract void extractSearchAfterFields(SearchHit lastSearchHit);
+
+    // for testing
+    void setBatchSize(int batchSize) {
+        this.batchSize = batchSize;
+    }
+}

+ 61 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/WrappedBatchedJobsIterator.java

@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.ml.utils.VolatileCursorIterator;
+
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * A wrapper around {@link BatchedIterator} that allows iterating jobs one
+ * at a time from the batches returned by {@code BatchedIterator}
+ *
+ * This class abstracts away the logic of pulling one job at a time from
+ * multiple batches.
+ */
+public class WrappedBatchedJobsIterator implements Iterator<Job> {
+    private final BatchedIterator<Job.Builder> batchedIterator;
+    private VolatileCursorIterator<Job> currentBatch;
+
+    public WrappedBatchedJobsIterator(BatchedIterator<Job.Builder> batchedIterator) {
+        this.batchedIterator = batchedIterator;
+    }
+
+    @Override
+    public boolean hasNext() {
+        return (currentBatch != null && currentBatch.hasNext()) || batchedIterator.hasNext();
+    }
+
+    /**
+     * Before BatchedIterator has run a search it reports hasNext == true
+     * but the first search may return no results. In that case null is returned
+     * and clients have to handle null.
+     */
+    @Override
+    public Job next() {
+        if (currentBatch != null && currentBatch.hasNext()) {
+            return currentBatch.next();
+        }
+
+        // currentBatch is either null or all its elements have been iterated.
+        // get the next currentBatch
+        currentBatch = createBatchIteratorFromBatch(batchedIterator.next());
+
+        // batchedIterator.hasNext may be true when searching for the first time
+        // but no results are returned.
+        return currentBatch.hasNext() ? currentBatch.next() : null;
+    }
+
+    private VolatileCursorIterator<Job> createBatchIteratorFromBatch(Deque<Job.Builder> builders) {
+        List<Job> jobs = builders.stream().map(Job.Builder::build).collect(Collectors.toList());
+        return new VolatileCursorIterator<>(jobs);
+    }
+}

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataActionTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction;
+import org.elasticsearch.xpack.ml.job.persistence.JobConfigProvider;
 import org.elasticsearch.xpack.ml.job.retention.MlDataRemover;
 import org.junit.After;
 import org.junit.Before;
@@ -55,7 +56,7 @@ public class TransportDeleteExpiredDataActionTests extends ESTestCase {
         Client client = mock(Client.class);
         ClusterService clusterService = mock(ClusterService.class);
         transportDeleteExpiredDataAction = new TransportDeleteExpiredDataAction(threadPool, ThreadPool.Names.SAME, transportService,
-            new ActionFilters(Collections.emptySet()), client, clusterService, Clock.systemUTC());
+            new ActionFilters(Collections.emptySet()), client, clusterService, mock(JobConfigProvider.class), Clock.systemUTC());
     }
 
     @After

+ 17 - 75
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/AbstractExpiredJobDataRemoverTests.java

@@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.job.retention;
 
 import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.OriginSettingClient;
@@ -28,18 +27,15 @@ import org.junit.Before;
 import java.io.IOException;
 import java.time.Clock;
 import java.time.Instant;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.hamcrest.Matchers.is;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -51,8 +47,8 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
 
         private int getRetentionDaysCallCount = 0;
 
-        ConcreteExpiredJobDataRemover(String jobId, OriginSettingClient client) {
-            super(jobId, client);
+        ConcreteExpiredJobDataRemover(OriginSettingClient client, Iterator<Job> jobIterator) {
+            super(client, jobIterator);
         }
 
         @Override
@@ -81,11 +77,10 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
     }
 
     private OriginSettingClient originSettingClient;
-    private Client client;
 
     @Before
     public void setUpTests() {
-        client = mock(Client.class);
+        Client client = mock(Client.class);
         originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
     }
 
@@ -94,7 +89,7 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
     }
 
     static SearchResponse createSearchResponseFromHits(List<SearchHit> hits) {
-        SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[] {}),
+        SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[]{}),
             new TotalHits(hits.size(), TotalHits.Relation.EQUAL_TO), 1.0f);
         SearchResponse searchResponse = mock(SearchResponse.class);
         when(searchResponse.getHits()).thenReturn(searchHits);
@@ -115,91 +110,38 @@ public class AbstractExpiredJobDataRemoverTests extends ESTestCase {
         return searchResponse;
     }
 
-    public void testRemoveGivenNoJobs() throws IOException {
-        SearchResponse response = createSearchResponse(Collections.emptyList());
-        mockSearchResponse(response);
-
+    public void testRemoveGivenNoJobs() {
         TestListener listener = new TestListener();
-        ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover("*", originSettingClient);
-        remover.remove(1.0f,listener, () -> false);
+        Iterator<Job> jobIterator = Collections.emptyIterator();
+        ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient, jobIterator);
+        remover.remove(1.0f, listener, () -> false);
 
         listener.waitToCompletion();
         assertThat(listener.success, is(true));
         assertEquals(0, remover.getRetentionDaysCallCount);
     }
 
-    @SuppressWarnings("unchecked")
-    public void testRemoveGivenMultipleBatches() throws IOException {
-        // This is testing AbstractExpiredJobDataRemover.WrappedBatchedJobsIterator
-        int totalHits = 7;
-        List<SearchResponse> responses = new ArrayList<>();
-        responses.add(createSearchResponse(Arrays.asList(
-                JobTests.buildJobBuilder("job1").build(),
-                JobTests.buildJobBuilder("job2").build(),
-                JobTests.buildJobBuilder("job3").build()
-        ), totalHits));
-
-        responses.add(createSearchResponse(Arrays.asList(
-                JobTests.buildJobBuilder("job4").build(),
-                JobTests.buildJobBuilder("job5").build(),
-                JobTests.buildJobBuilder("job6").build()
-        ), totalHits));
-
-        responses.add(createSearchResponse(Collections.singletonList(
-                JobTests.buildJobBuilder("job7").build()
-        ), totalHits));
-
-
-        AtomicInteger searchCount = new AtomicInteger(0);
-
-        doAnswer(invocationOnMock -> {
-            ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
-            listener.onResponse(responses.get(searchCount.getAndIncrement()));
-            return null;
-        }).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
-
-        TestListener listener = new TestListener();
-        ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover("*", originSettingClient);
-        remover.remove(1.0f,listener, () -> false);
-
-        listener.waitToCompletion();
-        assertThat(listener.success, is(true));
-        assertEquals(3, searchCount.get());
-        assertEquals(7, remover.getRetentionDaysCallCount);
-    }
-
-    public void testRemoveGivenTimeOut() throws IOException {
+    public void testRemoveGivenTimeOut() {
 
         int totalHits = 3;
-        SearchResponse response = createSearchResponse(Arrays.asList(
-                JobTests.buildJobBuilder("job1").build(),
-                JobTests.buildJobBuilder("job2").build(),
-                JobTests.buildJobBuilder("job3").build()
-            ), totalHits);
+        List<Job> jobs = Arrays.asList(
+            JobTests.buildJobBuilder("job1").build(),
+            JobTests.buildJobBuilder("job2").build(),
+            JobTests.buildJobBuilder("job3").build()
+        );
 
         final int timeoutAfter = randomIntBetween(0, totalHits - 1);
         AtomicInteger attemptsLeft = new AtomicInteger(timeoutAfter);
 
-        mockSearchResponse(response);
-
         TestListener listener = new TestListener();
-        ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover("*", originSettingClient);
-        remover.remove(1.0f,listener, () -> attemptsLeft.getAndDecrement() <= 0);
+        ConcreteExpiredJobDataRemover remover = new ConcreteExpiredJobDataRemover(originSettingClient, jobs.iterator());
+        remover.remove(1.0f, listener, () -> attemptsLeft.getAndDecrement() <= 0);
 
         listener.waitToCompletion();
         assertThat(listener.success, is(false));
         assertEquals(timeoutAfter, remover.getRetentionDaysCallCount);
     }
 
-    @SuppressWarnings("unchecked")
-    private void mockSearchResponse(SearchResponse searchResponse) {
-        doAnswer(invocationOnMock -> {
-            ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
-            listener.onResponse(searchResponse);
-            return null;
-        }).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
-    }
-
     static class TestListener implements ActionListener<Boolean> {
 
         boolean success;

+ 28 - 32
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemoverTests.java

@@ -35,6 +35,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Date;
+import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -71,27 +72,25 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
     }
 
     public void testRemove_GivenJobWithoutActiveSnapshot() throws IOException {
-        List<SearchResponse> responses = Arrays.asList(
-                AbstractExpiredJobDataRemoverTests.createSearchResponse(Collections.singletonList(JobTests.buildJobBuilder("foo")
-                        .setModelSnapshotRetentionDays(7L).build())),
-                AbstractExpiredJobDataRemoverTests.createSearchResponse(Collections.emptyList()));
+        List<Job> jobs = Collections.singletonList(JobTests.buildJobBuilder("foo").setModelSnapshotRetentionDays(7L).build());
 
+        List<SearchResponse> responses = Collections.singletonList(
+                AbstractExpiredJobDataRemoverTests.createSearchResponse(Collections.emptyList()));
         givenClientRequestsSucceed(responses);
 
-        createExpiredModelSnapshotsRemover().remove(1.0f, listener, () -> false);
+        createExpiredModelSnapshotsRemover(jobs.iterator()).remove(1.0f, listener, () -> false);
 
         listener.waitToCompletion();
         assertThat(listener.success, is(true));
-        verify(client, times(2)).execute(eq(SearchAction.INSTANCE), any(), any());
+        verify(client, times(1)).execute(eq(SearchAction.INSTANCE), any(), any());
     }
 
-    public void testRemove_GivenJobsWithMixedRetentionPolicies() throws IOException {
+    public void testRemove_GivenJobsWithMixedRetentionPolicies() {
         List<SearchResponse> searchResponses = new ArrayList<>();
-        searchResponses.add(
-                AbstractExpiredJobDataRemoverTests.createSearchResponse(Arrays.asList(
+        List<Job> jobs = Arrays.asList(
                         JobTests.buildJobBuilder("job-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
                         JobTests.buildJobBuilder("job-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
-        )));
+        );
 
         Date now = new Date();
         Date oneDayAgo = new Date(now.getTime() - TimeValue.timeValueDays(1).getMillis());
@@ -111,12 +110,12 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         searchResponses.add(AbstractExpiredJobDataRemoverTests.createSearchResponseFromHits(Collections.emptyList()));
 
         givenClientRequestsSucceed(searchResponses);
-        createExpiredModelSnapshotsRemover().remove(1.0f, listener, () -> false);
+        createExpiredModelSnapshotsRemover(jobs.iterator()).remove(1.0f, listener, () -> false);
 
         listener.waitToCompletion();
         assertThat(listener.success, is(true));
 
-        assertThat(capturedSearchRequests.size(), equalTo(5));
+        assertThat(capturedSearchRequests.size(), equalTo(4));
         SearchRequest searchRequest = capturedSearchRequests.get(1);
         assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("job-1")}));
         searchRequest = capturedSearchRequests.get(3);
@@ -130,11 +129,10 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
     public void testRemove_GivenTimeout() throws IOException {
         List<SearchResponse> searchResponses = new ArrayList<>();
-        searchResponses.add(
-                AbstractExpiredJobDataRemoverTests.createSearchResponse(Arrays.asList(
+        List<Job> jobs = Arrays.asList(
             JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
             JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
-        )));
+        );
 
         Date now = new Date();
         List<ModelSnapshot> snapshots1JobSnapshots = Arrays.asList(createModelSnapshot("snapshots-1", "snapshots-1_1", now),
@@ -148,40 +146,38 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         final int timeoutAfter = randomIntBetween(0, 1);
         AtomicInteger attemptsLeft = new AtomicInteger(timeoutAfter);
 
-        createExpiredModelSnapshotsRemover().remove(1.0f, listener, () -> (attemptsLeft.getAndDecrement() <= 0));
+        createExpiredModelSnapshotsRemover(jobs.iterator()).remove(1.0f, listener, () -> (attemptsLeft.getAndDecrement() <= 0));
 
         listener.waitToCompletion();
         assertThat(listener.success, is(false));
     }
 
-    public void testRemove_GivenClientSearchRequestsFail() throws IOException {
+    public void testRemove_GivenClientSearchRequestsFail() {
         List<SearchResponse> searchResponses = new ArrayList<>();
-        searchResponses.add(
-                AbstractExpiredJobDataRemoverTests.createSearchResponse(Arrays.asList(
+        List<Job> jobs = Arrays.asList(
                 JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
                 JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
-        )));
+        );
 
         givenClientSearchRequestsFail(searchResponses);
-        createExpiredModelSnapshotsRemover().remove(1.0f, listener, () -> false);
+        createExpiredModelSnapshotsRemover(jobs.iterator()).remove(1.0f, listener, () -> false);
 
         listener.waitToCompletion();
         assertThat(listener.success, is(false));
 
-        assertThat(capturedSearchRequests.size(), equalTo(2));
-        SearchRequest searchRequest = capturedSearchRequests.get(1);
+        assertThat(capturedSearchRequests.size(), equalTo(1));
+        SearchRequest searchRequest = capturedSearchRequests.get(0);
         assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));
 
         assertThat(capturedDeleteModelSnapshotRequests.size(), equalTo(0));
     }
 
-    public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOException {
+    public void testRemove_GivenClientDeleteSnapshotRequestsFail() {
         List<SearchResponse> searchResponses = new ArrayList<>();
-        searchResponses.add(
-                AbstractExpiredJobDataRemoverTests.createSearchResponse(Arrays.asList(
+        List<Job> jobs = Arrays.asList(
                 JobTests.buildJobBuilder("snapshots-1").setModelSnapshotRetentionDays(7L).setModelSnapshotId("active").build(),
                 JobTests.buildJobBuilder("snapshots-2").setModelSnapshotRetentionDays(17L).setModelSnapshotId("active").build()
-        )));
+        );
 
         Date now = new Date();
         Date oneDayAgo = new Date(new Date().getTime() - TimeValue.timeValueDays(1).getMillis());
@@ -199,12 +195,12 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         searchResponses.add(AbstractExpiredJobDataRemoverTests.createSearchResponseFromHits(Collections.singletonList(snapshot2_2)));
 
         givenClientDeleteModelSnapshotRequestsFail(searchResponses);
-        createExpiredModelSnapshotsRemover().remove(1.0f, listener, () -> false);
+        createExpiredModelSnapshotsRemover(jobs.iterator()).remove(1.0f, listener, () -> false);
 
         listener.waitToCompletion();
         assertThat(listener.success, is(false));
 
-        assertThat(capturedSearchRequests.size(), equalTo(3));
+        assertThat(capturedSearchRequests.size(), equalTo(2));
         SearchRequest searchRequest = capturedSearchRequests.get(1);
         assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));
 
@@ -226,7 +222,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
         long retentionDays = 3L;
         ActionListener<AbstractExpiredJobDataRemover.CutoffDetails> cutoffListener = mock(ActionListener.class);
-        createExpiredModelSnapshotsRemover().calcCutoffEpochMs("job-1", retentionDays, cutoffListener);
+        createExpiredModelSnapshotsRemover(Collections.emptyIterator()).calcCutoffEpochMs("job-1", retentionDays, cutoffListener);
 
         long dayInMills = 60 * 60 * 24 * 1000;
         long expectedCutoffTime = oneDayAgo.getTime() - (dayInMills * retentionDays);
@@ -244,7 +240,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         assertTrue(id.hasNullValue());
     }
 
-    private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover() {
+    private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover(Iterator<Job> jobIterator) {
         ThreadPool threadPool = mock(ThreadPool.class);
         ExecutorService executor = mock(ExecutorService.class);
 
@@ -256,7 +252,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
                     return null;
                 }
         ).when(executor).execute(any());
-        return new ExpiredModelSnapshotsRemover(originSettingClient, "*", threadPool);
+        return new ExpiredModelSnapshotsRemover(originSettingClient, jobIterator, threadPool);
     }
 
     private static ModelSnapshot createModelSnapshot(String jobId, String snapshotId, Date date) {

+ 33 - 55
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredResultsRemoverTests.java

@@ -7,7 +7,6 @@ package org.elasticsearch.xpack.ml.job.retention;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchAction;
-import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.OriginSettingClient;
@@ -26,11 +25,11 @@ import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
 import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
 import org.junit.Before;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Date;
+import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -61,40 +60,36 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
         listener = mock(ActionListener.class);
     }
 
-    public void testRemove_GivenNoJobs() throws IOException {
+    public void testRemove_GivenNoJobs() {
         givenDBQRequestsSucceed();
-        givenJobs(client, Collections.emptyList());
 
-        createExpiredResultsRemover().remove(1.0f, listener, () -> false);
+        createExpiredResultsRemover(Collections.emptyIterator()).remove(1.0f, listener, () -> false);
 
-        verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
         verify(listener).onResponse(true);
     }
 
-    public void testRemove_GivenJobsWithoutRetentionPolicy() throws IOException {
+    public void testRemove_GivenJobsWithoutRetentionPolicy() {
         givenDBQRequestsSucceed();
-        givenJobs(client,
-                Arrays.asList(
+        List<Job> jobs = Arrays.asList(
                 JobTests.buildJobBuilder("foo").build(),
                 JobTests.buildJobBuilder("bar").build()
-        ));
+        );
 
-        createExpiredResultsRemover().remove(1.0f, listener, () -> false);
+        createExpiredResultsRemover(jobs.iterator()).remove(1.0f, listener, () -> false);
 
         verify(listener).onResponse(true);
-        verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
     }
 
     public void testRemove_GivenJobsWithAndWithoutRetentionPolicy() {
         givenDBQRequestsSucceed();
+        givenBucket(new Bucket("id_not_important", new Date(), 60));
 
-        givenSearchResponses(Arrays.asList(
-                JobTests.buildJobBuilder("none").build(),
-                JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
-                JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()),
-                new Bucket("id_not_important", new Date(), 60));
+        List<Job> jobs = Arrays.asList(
+            JobTests.buildJobBuilder("none").build(),
+            JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
+            JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build());
 
-        createExpiredResultsRemover().remove(1.0f, listener, () -> false);
+        createExpiredResultsRemover(jobs.iterator()).remove(1.0f, listener, () -> false);
 
         assertThat(capturedDeleteByQueryRequests.size(), equalTo(2));
         DeleteByQueryRequest dbqRequest = capturedDeleteByQueryRequests.get(0);
@@ -106,15 +101,17 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
 
     public void testRemove_GivenTimeout() {
         givenDBQRequestsSucceed();
-        givenSearchResponses(Arrays.asList(
-                JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
-                JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
-        ), new Bucket("id_not_important", new Date(), 60));
+        givenBucket(new Bucket("id_not_important", new Date(), 60));
+
+        List<Job> jobs = Arrays.asList(
+            JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
+            JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()
+        );
 
         final int timeoutAfter = randomIntBetween(0, 1);
         AtomicInteger attemptsLeft = new AtomicInteger(timeoutAfter);
 
-        createExpiredResultsRemover().remove(1.0f, listener, () -> (attemptsLeft.getAndDecrement() <= 0));
+        createExpiredResultsRemover(jobs.iterator()).remove(1.0f, listener, () -> (attemptsLeft.getAndDecrement() <= 0));
 
         assertThat(capturedDeleteByQueryRequests.size(), equalTo(timeoutAfter));
         verify(listener).onResponse(false);
@@ -122,14 +119,13 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
 
     public void testRemove_GivenClientRequestsFailed() {
         givenDBQRequestsFailed();
-        givenSearchResponses(
-                Arrays.asList(
-                        JobTests.buildJobBuilder("none").build(),
-                        JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
-                        JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build()),
-                new Bucket("id_not_important", new Date(), 60));
+        givenBucket(new Bucket("id_not_important", new Date(), 60));
 
-        createExpiredResultsRemover().remove(1.0f, listener, () -> false);
+        List<Job> jobs = Arrays.asList(
+            JobTests.buildJobBuilder("none").build(),
+            JobTests.buildJobBuilder("results-1").setResultsRetentionDays(10L).build(),
+            JobTests.buildJobBuilder("results-2").setResultsRetentionDays(20L).build());
+        createExpiredResultsRemover(jobs.iterator()).remove(1.0f, listener, () -> false);
 
         assertThat(capturedDeleteByQueryRequests.size(), equalTo(1));
         DeleteByQueryRequest dbqRequest = capturedDeleteByQueryRequests.get(0);
@@ -142,28 +138,17 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
         String jobId = "calc-cutoff";
         Date latest = new Date();
 
-        givenSearchResponses(Collections.singletonList(JobTests.buildJobBuilder(jobId).setResultsRetentionDays(1L).build()),
-                new Bucket(jobId, latest, 60));
+        givenBucket(new Bucket(jobId, latest, 60));
+        List<Job> jobs = Collections.singletonList(JobTests.buildJobBuilder(jobId).setResultsRetentionDays(1L).build());
 
         ActionListener<AbstractExpiredJobDataRemover.CutoffDetails> cutoffListener = mock(ActionListener.class);
-        createExpiredResultsRemover().calcCutoffEpochMs(jobId, 1L, cutoffListener);
+        createExpiredResultsRemover(jobs.iterator()).calcCutoffEpochMs(jobId, 1L, cutoffListener);
 
         long dayInMills = 60 * 60 * 24 * 1000;
         long expectedCutoffTime = latest.getTime() - dayInMills;
         verify(cutoffListener).onResponse(eq(new AbstractExpiredJobDataRemover.CutoffDetails(latest.getTime(), expectedCutoffTime)));
     }
 
-    @SuppressWarnings("unchecked")
-    static void givenJobs(Client client, List<Job> jobs) throws IOException {
-        SearchResponse response = AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs);
-
-        doAnswer(invocationOnMock -> {
-            ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
-            listener.onResponse(response);
-            return null;
-        }).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
-    }
-
     private void givenDBQRequestsSucceed() {
         givenDBQRequest(true);
     }
@@ -191,22 +176,15 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
     }
 
     @SuppressWarnings("unchecked")
-    private void givenSearchResponses(List<Job> jobs, Bucket bucket) {
+    private void givenBucket(Bucket bucket) {
         doAnswer(invocationOnMock -> {
-            SearchRequest request = (SearchRequest) invocationOnMock.getArguments()[1];
             ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
-
-            if (request.indices()[0].startsWith(AnomalyDetectorsIndex.jobResultsIndexPrefix())) {
-                // asking for the bucket result
-                listener.onResponse(AbstractExpiredJobDataRemoverTests.createSearchResponse(Collections.singletonList(bucket)));
-            } else {
-                listener.onResponse(AbstractExpiredJobDataRemoverTests.createSearchResponse(jobs));
-            }
+            listener.onResponse(AbstractExpiredJobDataRemoverTests.createSearchResponse(Collections.singletonList(bucket)));
             return null;
         }).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
     }
 
-    private ExpiredResultsRemover createExpiredResultsRemover() {
+    private ExpiredResultsRemover createExpiredResultsRemover(Iterator<Job> jobIterator) {
         ThreadPool threadPool = mock(ThreadPool.class);
         ExecutorService executor = mock(ExecutorService.class);
 
@@ -219,6 +197,6 @@ public class ExpiredResultsRemoverTests extends ESTestCase {
             }
         ).when(executor).execute(any());
 
-        return new ExpiredResultsRemover(originSettingClient, "*", mock(AnomalyDetectionAuditor.class), threadPool);
+        return new ExpiredResultsRemover(originSettingClient, jobIterator, mock(AnomalyDetectionAuditor.class), threadPool);
     }
 }

+ 122 - 67
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java

@@ -46,42 +46,41 @@ import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
+
 public class BatchedDocumentsIteratorTests extends ESTestCase {
 
     private static final String INDEX_NAME = ".ml-anomalies-foo";
     private static final String SCROLL_ID = "someScrollId";
 
     private Client client;
-    private OriginSettingClient originSettingClient;
     private boolean wasScrollCleared;
 
     private TestIterator testIterator;
 
-    private ArgumentCaptor<SearchRequest> searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
-    private ArgumentCaptor<SearchScrollRequest> searchScrollRequestCaptor = ArgumentCaptor.forClass(SearchScrollRequest.class);
-
     @Before
     public void setUpMocks() {
         client = Mockito.mock(Client.class);
-        originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
+        OriginSettingClient originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
         wasScrollCleared = false;
         testIterator = new TestIterator(originSettingClient, INDEX_NAME);
         givenClearScrollRequest();
     }
 
     public void testQueryReturnsNoResults() {
-        new ScrollResponsesMocker().finishMock();
+        ResponsesMocker scrollResponsesMocker = new ScrollResponsesMocker(client).finishMock();
 
         assertTrue(testIterator.hasNext());
         assertTrue(testIterator.next().isEmpty());
         assertFalse(testIterator.hasNext());
         assertTrue(wasScrollCleared);
-        assertSearchRequest();
-        assertSearchScrollRequests(0);
+        scrollResponsesMocker.assertSearchRequest(INDEX_NAME);
+        scrollResponsesMocker.assertSearchScrollRequests(0, SCROLL_ID);
     }
 
     public void testCallingNextWhenHasNextIsFalseThrows() {
-        new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
+        new ScrollResponsesMocker(client)
+            .addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
+            .finishMock();
         testIterator.next();
         assertFalse(testIterator.hasNext());
 
@@ -89,7 +88,9 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
     }
 
     public void testQueryReturnsSingleBatch() {
-        new ScrollResponsesMocker().addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c")).finishMock();
+        ResponsesMocker scrollResponsesMocker = new ScrollResponsesMocker(client)
+            .addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
+            .finishMock();
 
         assertTrue(testIterator.hasNext());
         Deque<String> batch = testIterator.next();
@@ -98,16 +99,16 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         assertFalse(testIterator.hasNext());
         assertTrue(wasScrollCleared);
 
-        assertSearchRequest();
-        assertSearchScrollRequests(0);
+        scrollResponsesMocker.assertSearchRequest(INDEX_NAME);
+        scrollResponsesMocker.assertSearchScrollRequests(0, SCROLL_ID);
     }
 
     public void testQueryReturnsThreeBatches() {
-        new ScrollResponsesMocker()
-        .addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
-        .addBatch(createJsonDoc("d"), createJsonDoc("e"))
-        .addBatch(createJsonDoc("f"))
-        .finishMock();
+        ResponsesMocker responsesMocker = new ScrollResponsesMocker(client)
+            .addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
+            .addBatch(createJsonDoc("d"), createJsonDoc("e"))
+            .addBatch(createJsonDoc("f"))
+            .finishMock();
 
         assertTrue(testIterator.hasNext());
 
@@ -126,8 +127,8 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         assertFalse(testIterator.hasNext());
         assertTrue(wasScrollCleared);
 
-        assertSearchRequest();
-        assertSearchScrollRequests(2);
+        responsesMocker.assertSearchRequest(INDEX_NAME);
+        responsesMocker.assertSearchScrollRequests(2, SCROLL_ID);
     }
 
     private String createJsonDoc(String value) {
@@ -144,55 +145,94 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         }).when(client).execute(eq(ClearScrollAction.INSTANCE), any(), any());
     }
 
-    private void assertSearchRequest() {
-        List<SearchRequest> searchRequests = searchRequestCaptor.getAllValues();
-        assertThat(searchRequests.size(), equalTo(1));
-        SearchRequest searchRequest = searchRequests.get(0);
-        assertThat(searchRequest.indices(), equalTo(new String[] {INDEX_NAME}));
-        assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
-        assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
-        assertThat(searchRequest.source().trackTotalHitsUpTo(), is(SearchContext.TRACK_TOTAL_HITS_ACCURATE));
-    }
 
-    private void assertSearchScrollRequests(int expectedCount) {
-        List<SearchScrollRequest> searchScrollRequests = searchScrollRequestCaptor.getAllValues();
-        assertThat(searchScrollRequests.size(), equalTo(expectedCount));
-        for (SearchScrollRequest request : searchScrollRequests) {
-            assertThat(request.scrollId(), equalTo(SCROLL_ID));
-            assertThat(request.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
-        }
-    }
+    /**
+     * Shared scaffolding for mocking search responses on the mocked {@code Client}.
+     * Batches of hit sources are recorded via {@link #addBatch}; subclasses wire
+     * them to the client in {@link #finishMock()} (scroll API vs plain search).
+     * The captors record every request the iterator under test issues, so tests
+     * can assert on them afterwards.
+     */
+    abstract static class ResponsesMocker {
+        protected Client client;
+        protected List<String[]> batches = new ArrayList<>();
+        protected long totalHits = 0;
+        protected List<SearchResponse> responses = new ArrayList<>();
 
-    private class ScrollResponsesMocker {
-        private List<String[]> batches = new ArrayList<>();
-        private long totalHits = 0;
-        private List<SearchResponse> responses = new ArrayList<>();
+        // Index of the next canned response to hand out from the doAnswer stubs.
+        protected AtomicInteger responseIndex = new AtomicInteger(0);
 
-        private AtomicInteger responseIndex = new AtomicInteger(0);
+        protected ArgumentCaptor<SearchRequest> searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
+        protected ArgumentCaptor<SearchScrollRequest> searchScrollRequestCaptor = ArgumentCaptor.forClass(SearchScrollRequest.class);
 
-        ScrollResponsesMocker addBatch(String... hits) {
+        ResponsesMocker(Client client) {
+            this.client = client;
+        }
+
+        // Records one batch of hit sources to be returned as a single response.
+        ResponsesMocker addBatch(String... hits) {
             totalHits += hits.length;
             batches.add(hits);
             return this;
         }
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        void finishMock() {
+        // Wires the recorded batches onto the client mock; returns this for chaining.
+        abstract ResponsesMocker finishMock();
+
+
+        // Builds a mock SearchResponse wrapping the given JSON hit sources.
+        // Always reports SCROLL_ID so scroll-based iterators can continue.
+        protected SearchResponse createSearchResponseWithHits(String... hits) {
+            SearchHits searchHits = createHits(hits);
+            SearchResponse searchResponse = mock(SearchResponse.class);
+            when(searchResponse.getScrollId()).thenReturn(SCROLL_ID);
+            when(searchResponse.getHits()).thenReturn(searchHits);
+            return searchResponse;
+        }
+
+        // NOTE: the reported total is the cumulative count across ALL added
+        // batches (totalHits), not just the hits in this response.
+        protected SearchHits createHits(String... values) {
+            List<SearchHit> hits = new ArrayList<>();
+            for (String value : values) {
+                hits.add(new SearchHitBuilder(randomInt()).setSource(value).build());
+            }
+            return new SearchHits(hits.toArray(new SearchHit[hits.size()]), new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), 1.0f);
+        }
+
+        // Asserts a single initial search was issued against the given index with
+        // a match_all query, 5-minute scroll keep-alive and accurate hit tracking.
+        void assertSearchRequest(String indexName) {
+            List<SearchRequest> searchRequests = searchRequestCaptor.getAllValues();
+            assertThat(searchRequests.size(), equalTo(1));
+            SearchRequest searchRequest = searchRequests.get(0);
+            assertThat(searchRequest.indices(), equalTo(new String[] {indexName}));
+            assertThat(searchRequest.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
+            assertThat(searchRequest.source().query(), equalTo(QueryBuilders.matchAllQuery()));
+            assertThat(searchRequest.source().trackTotalHitsUpTo(), is(SearchContext.TRACK_TOTAL_HITS_ACCURATE));
+        }
+
+        // Asserts the expected number of follow-up scroll requests, each carrying
+        // the given scroll id and the 5-minute keep-alive.
+        void assertSearchScrollRequests(int expectedCount, String scrollId) {
+            List<SearchScrollRequest> searchScrollRequests = searchScrollRequestCaptor.getAllValues();
+            assertThat(searchScrollRequests.size(), equalTo(expectedCount));
+            for (SearchScrollRequest request : searchScrollRequests) {
+                assertThat(request.scrollId(), equalTo(scrollId));
+                assertThat(request.scroll().keepAlive(), equalTo(TimeValue.timeValueMinutes(5)));
+            }
+        }
+    }
+
+    static class ScrollResponsesMocker extends ResponsesMocker {
+
+        ScrollResponsesMocker(Client client) {
+            super(client);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        ResponsesMocker finishMock()
+        {
             if (batches.isEmpty()) {
                 givenInitialResponse();
-                return;
+                return this;
             }
+
             givenInitialResponse(batches.get(0));
             for (int i = 1; i < batches.size(); ++i) {
-                givenNextResponse(batches.get(i));
-            }
-            if (responses.size() > 0) {
-                doAnswer(invocationOnMock -> {
-                    ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
-                    listener.onResponse(responses.get(responseIndex.getAndIncrement()));
-                    return null;
-                }).when(client).execute(eq(SearchScrollAction.INSTANCE), searchScrollRequestCaptor.capture(), any());
+                responses.add(createSearchResponseWithHits(batches.get(i)));
             }
+
+            doAnswer(invocationOnMock -> {
+                ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
+                listener.onResponse(responses.get(responseIndex.getAndIncrement()));
+                return null;
+            }).when(client).execute(eq(SearchScrollAction.INSTANCE), searchScrollRequestCaptor.capture(), any());
+
+            return this;
         }
 
         @SuppressWarnings("unchecked")
@@ -205,28 +245,43 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
                 return null;
             }).when(client).execute(eq(SearchAction.INSTANCE), searchRequestCaptor.capture(), any());
         }
+    }
 
-        private void givenNextResponse(String... hits) {
-            responses.add(createSearchResponseWithHits(hits));
-        }
+    /**
+     * ResponsesMocker variant for search-after style iteration: every recorded
+     * batch is served through the plain SearchAction (no scroll requests), one
+     * response per execute() call, in the order the batches were added.
+     */
+    static class SearchResponsesMocker extends ResponsesMocker {
 
-        private SearchResponse createSearchResponseWithHits(String... hits) {
-            SearchHits searchHits = createHits(hits);
-            SearchResponse searchResponse = mock(SearchResponse.class);
-            when(searchResponse.getScrollId()).thenReturn(SCROLL_ID);
-            when(searchResponse.getHits()).thenReturn(searchHits);
-            return searchResponse;
+        SearchResponsesMocker(Client client) {
+            super(client);
         }
 
-        private SearchHits createHits(String... values) {
-            List<SearchHit> hits = new ArrayList<>();
-            for (String value : values) {
-                hits.add(new SearchHitBuilder(randomInt()).setSource(value).build());
+        @Override
+        @SuppressWarnings("unchecked")
+        ResponsesMocker finishMock()
+        {
+            // No batches recorded: every search answers with an empty hit set.
+            if (batches.isEmpty()) {
+                doAnswer(invocationOnMock -> {
+                    ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
+                    listener.onResponse(createSearchResponseWithHits());
+                    return null;
+                }).when(client).execute(eq(SearchAction.INSTANCE), searchRequestCaptor.capture(), any());
+
+                return this;
             }
-            return new SearchHits(hits.toArray(new SearchHit[hits.size()]), new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), 1.0f);
+
+            // Pre-build one response per batch, then hand them out sequentially
+            // via responseIndex as each SearchAction call arrives.
+            for (String[] batch : batches) {
+                responses.add(createSearchResponseWithHits(batch));
+            }
+
+            doAnswer(invocationOnMock -> {
+                ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
+                listener.onResponse(responses.get(responseIndex.getAndIncrement()));
+                return null;
+            }).when(client).execute(eq(SearchAction.INSTANCE), searchRequestCaptor.capture(), any());
+
+            return this;
         }
     }
 
+
     private static class TestIterator extends BatchedDocumentsIterator<String> {
         TestIterator(OriginSettingClient client, String jobId) {
             super(client, jobId);

+ 131 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/SearchAfterDocumentsIteratorTests.java

@@ -0,0 +1,131 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.elasticsearch.client.Client;
+import org.elasticsearch.client.OriginSettingClient;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.sort.FieldSortBuilder;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ClientHelper;
+import org.elasticsearch.xpack.ml.test.MockOriginSettingClient;
+import org.junit.Before;
+import org.mockito.Mockito;
+
+import java.util.Deque;
+import java.util.NoSuchElementException;
+
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
+
+/**
+ * Unit tests for {@code SearchAfterDocumentsIterator}: batch iteration via plain
+ * (non-scroll) searches, end-of-iteration behaviour, and extraction of the
+ * search-after sort values from the last hit of each batch.
+ */
+public class SearchAfterDocumentsIteratorTests extends ESTestCase {
+
+    private static final String INDEX_NAME = "test-index";
+    private Client client;
+    private OriginSettingClient originSettingClient;
+
+    @Before
+    public void setUpMocks() {
+        client = Mockito.mock(Client.class);
+        originSettingClient = MockOriginSettingClient.mockOriginSettingClient(client, ClientHelper.ML_ORIGIN);
+    }
+
+    public void testHasNext()
+    {
+        // Two batches of 3 and 2 docs; batch size 3 means the second, smaller
+        // batch signals the end of iteration.
+        new BatchedDocumentsIteratorTests.SearchResponsesMocker(client)
+            .addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
+            .addBatch(createJsonDoc("d"), createJsonDoc("e"))
+            .finishMock();
+
+        TestIterator testIterator = new TestIterator(originSettingClient, INDEX_NAME);
+        testIterator.setBatchSize(3);
+        assertTrue(testIterator.hasNext());
+        Deque<String> batch = testIterator.next();
+        assertThat(batch, hasSize(3));
+
+        assertTrue(testIterator.hasNext());
+        batch = testIterator.next();
+        assertThat(batch, hasSize(2));
+
+        // Exhausted: further next() calls must throw.
+        assertFalse(testIterator.hasNext());
+        ESTestCase.expectThrows(NoSuchElementException.class, testIterator::next);
+    }
+
+    public void testFirstBatchIsEmpty()
+    {
+        // An empty first response still counts as one batch, then iteration ends.
+        new BatchedDocumentsIteratorTests.SearchResponsesMocker(client)
+            .addBatch()
+            .finishMock();
+
+        TestIterator testIterator = new TestIterator(originSettingClient, INDEX_NAME);
+        assertTrue(testIterator.hasNext());
+        Deque<String> next = testIterator.next();
+        assertThat(next, empty());
+        assertFalse(testIterator.hasNext());
+    }
+
+    public void testExtractSearchAfterValuesSet()
+    {
+        new BatchedDocumentsIteratorTests.SearchResponsesMocker(client)
+            .addBatch(createJsonDoc("a"), createJsonDoc("b"), createJsonDoc("c"))
+            .addBatch(createJsonDoc("d"), createJsonDoc("e"))
+            .finishMock();
+
+        TestIterator testIterator = new TestIterator(originSettingClient, INDEX_NAME);
+        testIterator.setBatchSize(3);
+        Deque<String> next = testIterator.next();
+        assertThat(next, not(empty()));
+        // After each batch the search-after value is the "name" of the last hit.
+        Object[] values = testIterator.searchAfterFields();
+        assertArrayEquals(new Object[] {"c"}, values);
+
+        next = testIterator.next();
+        assertThat(next, not(empty()));
+        values = testIterator.searchAfterFields();
+        assertArrayEquals(new Object[] {"e"}, values);
+    }
+
+    /** Minimal concrete iterator sorting on the "name" source field. */
+    private static class TestIterator extends SearchAfterDocumentsIterator<String> {
+
+        // Last seen "name" value; used as the search_after marker for the next page.
+        private String searchAfterValue;
+
+        TestIterator(OriginSettingClient client, String index) {
+            super(client, index);
+        }
+
+        @Override
+        protected QueryBuilder getQuery() {
+            return QueryBuilders.matchAllQuery();
+        }
+
+        @Override
+        protected FieldSortBuilder sortField() {
+            return new FieldSortBuilder("name");
+        }
+
+        @Override
+        protected String map(SearchHit hit) {
+            return hit.getSourceAsString();
+        }
+
+        @Override
+        protected Object[] searchAfterFields() {
+            return new Object[] {searchAfterValue};
+        }
+
+        @Override
+        protected void extractSearchAfterFields(SearchHit lastSearchHit) {
+            // Remember the "name" field of the last hit in the batch.
+            searchAfterValue = (String)lastSearchHit.getSourceAsMap().get("name");
+        }
+    }
+
+    private String createJsonDoc(String value) {
+        return "{\"name\":\"" + value + "\"}";
+    }
+}

+ 77 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/WrappedBatchedJobsIteratorTests.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.core.ml.job.config.JobTests;
+
+import java.util.ArrayDeque;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Tests that {@code WrappedBatchedJobsIterator} flattens a batched iterator of
+ * job builders into a plain per-job iterator, preserving order across batches.
+ */
+public class WrappedBatchedJobsIteratorTests extends ESTestCase {
+
+    /** Adapts a plain Iterator of batches to the BatchedIterator interface. */
+    static class TestBatchedIterator implements BatchedIterator<Job.Builder> {
+
+        private Iterator<Deque<Job.Builder>> batches;
+
+        TestBatchedIterator(Iterator<Deque<Job.Builder>> batches) {
+            this.batches = batches;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return batches.hasNext();
+        }
+
+        @Override
+        public Deque<Job.Builder> next() {
+            return batches.next();
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testBatchedIteration() {
+
+        // Three batches of sizes 3, 3 and 1: the wrapper should yield the seven
+        // jobs one at a time, in batch order.
+        Deque<Job.Builder> batch1 = new ArrayDeque<>();
+        batch1.add(JobTests.buildJobBuilder("job1"));
+        batch1.add(JobTests.buildJobBuilder("job2"));
+        batch1.add(JobTests.buildJobBuilder("job3"));
+
+        Deque<Job.Builder> batch2 = new ArrayDeque<>();
+        batch2.add(JobTests.buildJobBuilder("job4"));
+        batch2.add(JobTests.buildJobBuilder("job5"));
+        batch2.add(JobTests.buildJobBuilder("job6"));
+
+        Deque<Job.Builder> batch3 = new ArrayDeque<>();
+        batch3.add(JobTests.buildJobBuilder("job7"));
+
+        List<Deque<Job.Builder>> allBatches = Arrays.asList(batch1, batch2, batch3);
+
+        TestBatchedIterator batchedIterator = new TestBatchedIterator(allBatches.iterator());
+        WrappedBatchedJobsIterator wrappedIterator = new WrappedBatchedJobsIterator(batchedIterator);
+
+        assertTrue(wrappedIterator.hasNext());
+        assertEquals("job1", wrappedIterator.next().getId());
+        assertTrue(wrappedIterator.hasNext());
+        assertEquals("job2", wrappedIterator.next().getId());
+        assertTrue(wrappedIterator.hasNext());
+        assertEquals("job3", wrappedIterator.next().getId());
+        assertTrue(wrappedIterator.hasNext());
+        assertEquals("job4", wrappedIterator.next().getId());
+        assertTrue(wrappedIterator.hasNext());
+        assertEquals("job5", wrappedIterator.next().getId());
+        assertTrue(wrappedIterator.hasNext());
+        assertEquals("job6", wrappedIterator.next().getId());
+        assertTrue(wrappedIterator.hasNext());
+        assertEquals("job7", wrappedIterator.next().getId());
+        assertFalse(wrappedIterator.hasNext());
+    }
+}

+ 6 - 1
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/delete_expired_data.yml

@@ -63,7 +63,12 @@ setup:
         timeout: "10h"
         requests_per_second: 100000.0
   - match: { deleted: true}
-
+---
+"Test delete expired data with unknown job id":
+  - do:
+      catch: missing
+      ml.delete_expired_data:
+        job_id: not-a-job
 ---
 "Test delete expired data with job id":
   - do: