Selaa lähdekoodia

Use the remaining scroll response documents on update by query bulk requests (#71430)

In update by query requests where max_docs < size and conflicts=proceed
we weren't using the remaining documents from the scroll response in
cases where there were conflicts and in the first bulk request the
successful updates < max_docs. This commit address that problem and
use the remaining documents from the scroll response instead of
requesting a new page.

Closes #63671
Francisco Fernández Castaño 4 vuotta sitten
vanhempi
commit
9d8fb9fba2

+ 4 - 1
docs/reference/docs/delete-by-query.asciidoc

@@ -84,7 +84,10 @@ and all failed requests are returned in the response. Any delete requests that
 completed successfully still stick, they are not rolled back.
 
 You can opt to count version conflicts instead of halting and returning by
-setting `conflicts` to `proceed`.
+setting `conflicts` to `proceed`. Note that if you opt to count version conflicts
+the operation could attempt to delete more documents from the source
+than `max_docs` until it has successfully deleted `max_docs` documents, or it has gone through
+every document in the source query.
 
 ===== Refreshing shards
 

+ 7 - 1
docs/reference/docs/reindex.asciidoc

@@ -130,6 +130,9 @@ By default, version conflicts abort the `_reindex` process.
 To continue reindexing if there are conflicts, set the `"conflicts"` request body parameter to `proceed`.
 In this case, the response includes a count of the version conflicts that were encountered.
 Note that the handling of other error types is unaffected by the `"conflicts"` parameter.
+Additionally, if you opt to count version conflicts the operation could attempt to reindex more documents
+from the source than `max_docs` until it has successfully indexed `max_docs` documents into the target, or it has gone
+through every document in the source query.
 
 [[docs-reindex-task-api]]
 ===== Running reindex asynchronously
@@ -497,6 +500,7 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=max_docs]
 [[docs-reindex-api-request-body]]
 ==== {api-request-body-title}
 
+[[conflicts]]
 `conflicts`::
 (Optional, enum) Set to `proceed` to continue reindexing even if there are conflicts.
 Defaults to `abort`.
@@ -507,7 +511,9 @@ Defaults to `abort`.
 Also accepts a comma-separated list to reindex from multiple sources.
 
 `max_docs`:::
-(Optional, integer) The maximum number of documents to reindex.
+(Optional, integer) The maximum number of documents to reindex. If <<conflicts, conflicts>> is equal to
+`proceed`, reindex could attempt to reindex more documents from the source than `max_docs` until it has successfully
+indexed `max_docs` documents into the target, or it has gone through every document in the source query.
 
 `query`:::
 (Optional, <<query-dsl, query object>>) Specifies the documents to reindex using the Query DSL.

+ 4 - 1
docs/reference/docs/update-by-query.asciidoc

@@ -69,7 +69,10 @@ When the versions match, the document is updated and the version number is incre
 If a document changes between the time that the snapshot is taken and
 the update operation is processed, it results in a version conflict and the operation fails.
 You can opt to count version conflicts instead of halting and returning by
-setting `conflicts` to `proceed`.
+setting `conflicts` to `proceed`. Note that if you opt to count
+version conflicts the operation could attempt to update more documents from the source than
+`max_docs` until it has successfully updated `max_docs` documents, or it has gone through every document
+in the source query.
 
 NOTE: Documents with a version equal to 0 cannot be updated using update by
 query because `internal` versioning does not support 0 as a valid

+ 309 - 0
modules/reindex/src/internalClusterTest/java/org/elasticsearch/index/reindex/BulkByScrollUsesAllScrollDocumentsAfterConflictsIntegTests.java

@@ -0,0 +1,309 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.index.reindex;
+
+import org.elasticsearch.action.ActionFuture;
+import org.elasticsearch.action.bulk.BulkItemResponse;
+import org.elasticsearch.action.bulk.BulkRequest;
+import org.elasticsearch.action.bulk.BulkResponse;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.action.search.SearchRequestBuilder;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.CollectionUtils;
+import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.index.VersionType;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.script.MockScriptPlugin;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptType;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+
+import static org.elasticsearch.common.lucene.uid.Versions.MATCH_DELETED;
+import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
+import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+
+@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
+public class BulkByScrollUsesAllScrollDocumentsAfterConflictsIntegTests extends ReindexTestCase {
+    private static final String SCRIPT_LANG = "fake_lang";
+    private static final String NOOP_GENERATOR = "modificationScript";
+    private static final String RETURN_NOOP_FIELD = "return_noop";
+    private static final String SORTING_FIELD = "num";
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return CollectionUtils.appendToCopy(super.nodePlugins(), CustomScriptPlugin.class);
+    }
+
+    public static class CustomScriptPlugin extends MockScriptPlugin {
+        @Override
+        @SuppressWarnings("unchecked")
+        protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
+            return Map.of(NOOP_GENERATOR, (vars) -> {
+                final Map<String, Object> ctx = (Map<String, Object>) vars.get("ctx");
+                final Map<String, Object> source = (Map<String, Object>) ctx.get("_source");
+                if (source.containsKey(RETURN_NOOP_FIELD)) {
+                    ctx.put("op", "noop");
+                }
+                return vars;
+            });
+        }
+
+        @Override
+        public String pluginScriptLang() {
+            return SCRIPT_LANG;
+        }
+    }
+
+    @Before
+    public void setUpCluster() {
+        internalCluster().startMasterOnlyNode();
+        // Use a single thread pool for writes so we can enforce a consistent ordering
+        internalCluster().startDataOnlyNode(Settings.builder().put("thread_pool.write.size", 1).build());
+    }
+
+    public void testUpdateByQuery() throws Exception {
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final boolean scriptEnabled = randomBoolean();
+        executeConcurrentUpdatesOnSubsetOfDocs(indexName,
+            indexName,
+            scriptEnabled,
+            updateByQuery(),
+            true,
+            (bulkByScrollResponse, updatedDocCount) -> {
+                assertThat(bulkByScrollResponse.getUpdated(), is((long) updatedDocCount));
+        });
+    }
+
+    public void testReindex() throws Exception {
+        final String sourceIndex = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final String targetIndex = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        createIndexWithSingleShard(targetIndex);
+
+        final ReindexRequestBuilder reindexRequestBuilder = reindex();
+        reindexRequestBuilder.destination(targetIndex);
+        reindexRequestBuilder.destination().setVersionType(VersionType.INTERNAL);
+        // Force MATCH_DELETE version so we get reindex conflicts
+        reindexRequestBuilder.destination().setVersion(MATCH_DELETED);
+
+        final boolean scriptEnabled = randomBoolean();
+        executeConcurrentUpdatesOnSubsetOfDocs(sourceIndex,
+            targetIndex,
+            scriptEnabled,
+            reindexRequestBuilder,
+            false,
+            (bulkByScrollResponse, reindexDocCount) -> {
+            assertThat(bulkByScrollResponse.getCreated(), is((long) reindexDocCount));
+        });
+    }
+
+    public void testDeleteByQuery() throws Exception {
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        executeConcurrentUpdatesOnSubsetOfDocs(indexName,
+            indexName,
+            false,
+            deleteByQuery(),
+            true,
+            (bulkByScrollResponse, deletedDocCount) -> {
+                assertThat(bulkByScrollResponse.getDeleted(), is((long) deletedDocCount));
+        });
+    }
+
+    <R extends AbstractBulkByScrollRequest<R>,
+     Self extends AbstractBulkByScrollRequestBuilder<R, Self>> void executeConcurrentUpdatesOnSubsetOfDocs(String sourceIndex,
+        String targetIndex,
+        boolean scriptEnabled,
+        AbstractBulkByScrollRequestBuilder<R, Self> requestBuilder,
+        boolean useOptimisticConcurrency,
+        BiConsumer<BulkByScrollResponse, Integer> resultConsumer) throws Exception {
+        createIndexWithSingleShard(sourceIndex);
+
+        final int numDocs = 100;
+        final int maxDocs = 10;
+        final int scrollSize = randomIntBetween(maxDocs, numDocs);
+
+        List<IndexRequestBuilder> indexRequests = new ArrayList<>(numDocs);
+        int noopDocs = 0;
+        for (int i = numDocs; i > 0; i--) {
+            Map<String, Object> source = new HashMap<>();
+            source.put(SORTING_FIELD, i);
+            // Force that the first maxDocs are transformed into a noop
+            if (scriptEnabled && noopDocs < maxDocs) {
+                // Add a marker on the document to signal that this
+                // document should return a noop in the script
+                source.put(RETURN_NOOP_FIELD, true);
+                noopDocs++;
+            }
+            indexRequests.add(client().prepareIndex(sourceIndex).setId(Integer.toString(i)).setSource(source));
+        }
+        indexRandom(true, indexRequests);
+
+        final ThreadPool threadPool = internalCluster().getDataNodeInstance(ThreadPool.class);
+
+        final int writeThreads = threadPool.info(ThreadPool.Names.WRITE).getMax();
+        assertThat(writeThreads, equalTo(1));
+        final EsThreadPoolExecutor writeThreadPool = (EsThreadPoolExecutor) threadPool.executor(ThreadPool.Names.WRITE);
+        final CyclicBarrier barrier = new CyclicBarrier(writeThreads + 1);
+        final CountDownLatch latch = new CountDownLatch(1);
+
+        // Block the write thread pool
+        writeThreadPool.submit(() -> {
+            try {
+                barrier.await();
+                latch.await();
+            } catch (Exception e) {
+                throw new AssertionError(e);
+            }
+        });
+        // Ensure that the write thread blocking task is currently executing
+        barrier.await();
+
+        final SearchResponse searchResponse = client().prepareSearch(sourceIndex)
+            .setSize(numDocs) // Get all indexed docs
+            .addSort(SORTING_FIELD, SortOrder.DESC)
+            .execute()
+            .actionGet();
+
+        // Modify a subset of the target documents concurrently
+        final List<SearchHit> originalDocs = Arrays.asList(searchResponse.getHits().getHits());
+        int conflictingOps = randomIntBetween(maxDocs, numDocs);
+        final List<SearchHit> docsModifiedConcurrently = randomSubsetOf(conflictingOps, originalDocs);
+
+        BulkRequest conflictingUpdatesBulkRequest = new BulkRequest();
+        for (SearchHit searchHit : docsModifiedConcurrently) {
+            if (scriptEnabled && searchHit.getSourceAsMap().containsKey(RETURN_NOOP_FIELD)) {
+                conflictingOps--;
+            }
+            conflictingUpdatesBulkRequest.add(createUpdatedIndexRequest(searchHit, targetIndex, useOptimisticConcurrency));
+        }
+
+        // The bulk request is enqueued before the update by query
+        final ActionFuture<BulkResponse> bulkFuture = client().bulk(conflictingUpdatesBulkRequest);
+
+        // Ensure that the concurrent writes are enqueued before the update by query request is sent
+        assertBusy(() -> assertThat(writeThreadPool.getQueue().size(), equalTo(1)));
+
+        requestBuilder.source(sourceIndex)
+            .maxDocs(maxDocs)
+            .abortOnVersionConflict(false);
+
+        if (scriptEnabled) {
+            final Script script = new Script(ScriptType.INLINE, SCRIPT_LANG, NOOP_GENERATOR, Collections.emptyMap());
+            ((AbstractBulkIndexByScrollRequestBuilder) requestBuilder).script(script);
+        }
+
+        final SearchRequestBuilder source = requestBuilder.source();
+        source.setSize(scrollSize);
+        source.addSort(SORTING_FIELD, SortOrder.DESC);
+        source.setQuery(QueryBuilders.matchAllQuery());
+        final ActionFuture<BulkByScrollResponse> updateByQueryResponse = requestBuilder.execute();
+
+        assertBusy(() -> assertThat(writeThreadPool.getQueue().size(), equalTo(2)));
+
+        // Allow tasks from the write thread to make progress
+        latch.countDown();
+
+        final BulkResponse bulkItemResponses = bulkFuture.actionGet();
+        for (BulkItemResponse bulkItemResponse : bulkItemResponses) {
+            assertThat(Strings.toString(bulkItemResponses), bulkItemResponse.isFailed(), is(false));
+        }
+
+        final BulkByScrollResponse bulkByScrollResponse = updateByQueryResponse.actionGet();
+        assertThat(bulkByScrollResponse.getVersionConflicts(), lessThanOrEqualTo((long) conflictingOps));
+        // When scripts are enabled, the first maxDocs are a NoOp
+        final int candidateOps = scriptEnabled ? numDocs - maxDocs : numDocs;
+        int successfulOps = Math.min(candidateOps - conflictingOps, maxDocs);
+        assertThat(bulkByScrollResponse.getNoops(), is((long) (scriptEnabled ? maxDocs : 0)));
+        resultConsumer.accept(bulkByScrollResponse, successfulOps);
+    }
+
+    private void createIndexWithSingleShard(String index) throws Exception {
+        final Settings indexSettings = Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .build();
+        final XContentBuilder mappings = jsonBuilder();
+        {
+            mappings.startObject();
+            {
+                mappings.startObject(SINGLE_MAPPING_NAME);
+                mappings.field("dynamic", "strict");
+                {
+                    mappings.startObject("properties");
+                    {
+                        mappings.startObject(SORTING_FIELD);
+                        mappings.field("type", "integer");
+                        mappings.endObject();
+                    }
+                    {
+                        mappings.startObject(RETURN_NOOP_FIELD);
+                        mappings.field("type", "boolean");
+                        mappings.endObject();
+                    }
+                    mappings.endObject();
+                }
+                mappings.endObject();
+            }
+            mappings.endObject();
+        }
+
+        // Use explicit mappings so we don't have to create those on demands and the task ordering
+        // can change to wait for mapping updates
+        assertAcked(
+            prepareCreate(index)
+                .setSettings(indexSettings)
+                .setMapping(mappings)
+        );
+    }
+
+    private IndexRequest createUpdatedIndexRequest(SearchHit searchHit, String targetIndex, boolean useOptimisticUpdate) {
+        final BytesReference sourceRef = searchHit.getSourceRef();
+        final XContentType xContentType = sourceRef != null ? XContentHelper.xContentType(sourceRef) : null;
+        IndexRequest indexRequest = new IndexRequest();
+        indexRequest.index(targetIndex);
+        indexRequest.id(searchHit.getId());
+        indexRequest.source(sourceRef, xContentType);
+        if (useOptimisticUpdate) {
+            indexRequest.setIfSeqNo(searchHit.getSeqNo());
+            indexRequest.setIfPrimaryTerm(searchHit.getPrimaryTerm());
+        } else {
+            indexRequest.version(MATCH_DELETED);
+        }
+        return indexRequest;
+    }
+}

+ 84 - 12
modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java

@@ -55,6 +55,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.BiFunction;
 
@@ -104,6 +105,14 @@ public abstract class AbstractAsyncBulkByScrollAction<Request extends AbstractBu
      */
     private final BiFunction<RequestWrapper<?>, ScrollableHitSource.Hit, RequestWrapper<?>> scriptApplier;
     private int lastBatchSize;
+    /**
+     * Keeps track of the total number of bulk operations performed
+     * from a single scroll response. It is possible that
+     * multiple bulk requests are performed from a single scroll
+     * response, meaning that we have to take into account the total
+     * in order to compute a correct scroll keep alive time.
+     */
+    private final AtomicInteger totalBatchSizeInSingleScrollResponse = new AtomicInteger();
 
     AbstractAsyncBulkByScrollAction(BulkByScrollTask task, boolean needsSourceDocumentVersions,
                                     boolean needsSourceDocumentSeqNoAndPrimaryTerm, Logger logger, ParentTaskAssigningClient client,
@@ -244,6 +253,10 @@ public abstract class AbstractAsyncBulkByScrollAction<Request extends AbstractBu
     }
 
     void onScrollResponse(ScrollableHitSource.AsyncResponse asyncResponse) {
+        onScrollResponse(new ScrollConsumableHitsResponse(asyncResponse));
+    }
+
+    void onScrollResponse(ScrollConsumableHitsResponse asyncResponse) {
         // lastBatchStartTime is essentially unused (see WorkerBulkByScrollTaskState.throttleWaitTime. Leaving it for now, since it seems
         // like a bug?
         onScrollResponse(System.nanoTime(), this.lastBatchSize, asyncResponse);
@@ -255,9 +268,9 @@ public abstract class AbstractAsyncBulkByScrollAction<Request extends AbstractBu
      * @param lastBatchSize the size of the last batch. Used to calculate the throttling delay.
      * @param asyncResponse the response to process from ScrollableHitSource
      */
-    void onScrollResponse(long lastBatchStartTimeNS, int lastBatchSize, ScrollableHitSource.AsyncResponse asyncResponse) {
+    void onScrollResponse(long lastBatchStartTimeNS, int lastBatchSize, ScrollConsumableHitsResponse asyncResponse) {
         ScrollableHitSource.Response response = asyncResponse.response();
-        logger.debug("[{}]: got scroll response with [{}] hits", task.getId(), response.getHits().size());
+        logger.debug("[{}]: got scroll response with [{}] hits", task.getId(), asyncResponse.remainingHits());
         if (task.isCancelled()) {
             logger.debug("[{}]: finishing early because the task was cancelled", task.getId());
             finishHim(null);
@@ -300,27 +313,29 @@ public abstract class AbstractAsyncBulkByScrollAction<Request extends AbstractBu
      * delay has been slept. Uses the generic thread pool because reindex is rare enough not to need its own thread pool and because the
      * thread may be blocked by the user script.
      */
-    void prepareBulkRequest(long thisBatchStartTimeNS, ScrollableHitSource.AsyncResponse asyncResponse) {
-        ScrollableHitSource.Response response = asyncResponse.response();
+    void prepareBulkRequest(long thisBatchStartTimeNS, ScrollConsumableHitsResponse asyncResponse) {
         logger.debug("[{}]: preparing bulk request", task.getId());
         if (task.isCancelled()) {
             logger.debug("[{}]: finishing early because the task was cancelled", task.getId());
             finishHim(null);
             return;
         }
-        if (response.getHits().isEmpty()) {
+        if (asyncResponse.hasRemainingHits() == false) {
             refreshAndFinish(emptyList(), emptyList(), false);
             return;
         }
         worker.countBatch();
-        List<? extends ScrollableHitSource.Hit> hits = response.getHits();
+        final List<? extends ScrollableHitSource.Hit> hits;
+
         if (mainRequest.getMaxDocs() != MAX_DOCS_ALL_MATCHES) {
             // Truncate the hits if we have more than the request max docs
-            long remaining = max(0, mainRequest.getMaxDocs() - worker.getSuccessfullyProcessed());
-            if (remaining < hits.size()) {
-                hits = hits.subList(0, (int) remaining);
-            }
+            long remainingDocsToProcess = max(0, mainRequest.getMaxDocs() - worker.getSuccessfullyProcessed());
+            hits = remainingDocsToProcess < asyncResponse.remainingHits() ? asyncResponse.consumeHits((int) remainingDocsToProcess)
+                                                                          : asyncResponse.consumeRemainingHits();
+        } else {
+            hits = asyncResponse.consumeRemainingHits();
         }
+
         BulkRequest request = buildBulk(hits);
         if (request.requests().isEmpty()) {
             /*
@@ -417,14 +432,24 @@ public abstract class AbstractAsyncBulkByScrollAction<Request extends AbstractBu
         }
     }
 
-    void notifyDone(long thisBatchStartTimeNS, ScrollableHitSource.AsyncResponse asyncResponse, int batchSize) {
+    void notifyDone(long thisBatchStartTimeNS,
+                    ScrollConsumableHitsResponse asyncResponse,
+                    int batchSize) {
         if (task.isCancelled()) {
             logger.debug("[{}]: finishing early because the task was cancelled", task.getId());
             finishHim(null);
             return;
         }
         this.lastBatchSize = batchSize;
-        asyncResponse.done(worker.throttleWaitTime(thisBatchStartTimeNS, System.nanoTime(), batchSize));
+        this.totalBatchSizeInSingleScrollResponse.addAndGet(batchSize);
+
+        if (asyncResponse.hasRemainingHits() == false) {
+            int totalBatchSize = totalBatchSizeInSingleScrollResponse.getAndSet(0);
+            asyncResponse.done(worker.throttleWaitTime(thisBatchStartTimeNS, System.nanoTime(), totalBatchSize));
+        } else {
+            onScrollResponse(asyncResponse);
+        }
+
     }
 
     private void recordFailure(Failure failure, List<Failure> failures) {
@@ -848,4 +873,51 @@ public abstract class AbstractAsyncBulkByScrollAction<Request extends AbstractBu
             return id.toLowerCase(Locale.ROOT);
         }
     }
+
+    static class ScrollConsumableHitsResponse {
+        private final ScrollableHitSource.AsyncResponse asyncResponse;
+        private final List<? extends ScrollableHitSource.Hit> hits;
+        private int consumedOffset = 0;
+
+        ScrollConsumableHitsResponse(ScrollableHitSource.AsyncResponse asyncResponse) {
+            this.asyncResponse = asyncResponse;
+            this.hits = asyncResponse.response().getHits();
+        }
+
+        ScrollableHitSource.Response response() {
+            return asyncResponse.response();
+        }
+
+        List<? extends ScrollableHitSource.Hit> consumeRemainingHits() {
+            return consumeHits(remainingHits());
+        }
+
+        List<? extends ScrollableHitSource.Hit> consumeHits(int numberOfHits) {
+            if (numberOfHits < 0) {
+                throw new IllegalArgumentException("Invalid number of hits to consume [" + numberOfHits + "]");
+            }
+
+            if (numberOfHits > remainingHits()) {
+                throw new IllegalArgumentException(
+                    "Unable to provide [" + numberOfHits + "] hits as there are only [" + remainingHits() + "] hits available"
+                );
+            }
+
+            int start = consumedOffset;
+            consumedOffset += numberOfHits;
+            return hits.subList(start, consumedOffset);
+        }
+
+        boolean hasRemainingHits() {
+            return remainingHits() > 0;
+        }
+
+        int remainingHits() {
+            return hits.size() - consumedOffset;
+        }
+
+        void done(TimeValue extraKeepAlive) {
+            asyncResponse.done(extraKeepAlive);
+        }
+    }
 }

+ 97 - 10
modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java

@@ -697,23 +697,110 @@ public class AsyncBulkByScrollActionTests extends ESTestCase {
         }
     }
 
+    public void testScrollConsumableHitsResponseCanBeConsumedInChunks() {
+        List<ScrollableHitSource.BasicHit> hits = new ArrayList<>();
+        int numberOfHits = randomIntBetween(0, 300);
+        for (int i = 0; i < numberOfHits; i++) {
+            hits.add(new ScrollableHitSource.BasicHit("idx", "id-" + i, -1));
+        }
+        final ScrollableHitSource.Response scrollResponse =
+            new ScrollableHitSource.Response(false, emptyList(), hits.size(), hits, "scrollid");
+        final AbstractAsyncBulkByScrollAction.ScrollConsumableHitsResponse response =
+            new AbstractAsyncBulkByScrollAction.ScrollConsumableHitsResponse(new ScrollableHitSource.AsyncResponse() {
+                @Override
+                public ScrollableHitSource.Response response() {
+                    return scrollResponse;
+                }
+
+                @Override
+                public void done(TimeValue extraKeepAlive) {
+                }
+            });
+
+        assertThat(response.remainingHits(), equalTo(numberOfHits));
+        assertThat(response.hasRemainingHits(), equalTo(numberOfHits > 0));
+
+        int totalConsumedHits = 0;
+        while (response.hasRemainingHits()) {
+            final int numberOfHitsToConsume;
+            final List<? extends ScrollableHitSource.Hit> consumedHits;
+            if (randomBoolean()) {
+                numberOfHitsToConsume = numberOfHits - totalConsumedHits;
+                consumedHits = response.consumeRemainingHits();
+            } else {
+                numberOfHitsToConsume = randomIntBetween(1, numberOfHits - totalConsumedHits);
+                consumedHits = response.consumeHits(numberOfHitsToConsume);
+            }
+
+            assertThat(consumedHits.size(), equalTo(numberOfHitsToConsume));
+            assertThat(consumedHits, equalTo(hits.subList(totalConsumedHits, totalConsumedHits + numberOfHitsToConsume)));
+            totalConsumedHits += numberOfHitsToConsume;
+
+            assertThat(response.remainingHits(), equalTo(numberOfHits - totalConsumedHits));
+        }
+
+        assertThat(response.consumeRemainingHits().isEmpty(), equalTo(true));
+    }
+
+    public void testScrollConsumableHitsResponseErrorHandling() {
+        List<ScrollableHitSource.BasicHit> hits = new ArrayList<>();
+        int numberOfHits = randomIntBetween(2, 300);
+        for (int i = 0; i < numberOfHits; i++) {
+            hits.add(new ScrollableHitSource.BasicHit("idx", "id-" + i, -1));
+        }
+
+        final ScrollableHitSource.Response scrollResponse =
+            new ScrollableHitSource.Response(false, emptyList(), hits.size(), hits, "scrollid");
+        final AbstractAsyncBulkByScrollAction.ScrollConsumableHitsResponse response =
+            new AbstractAsyncBulkByScrollAction.ScrollConsumableHitsResponse(new ScrollableHitSource.AsyncResponse() {
+                @Override
+                public ScrollableHitSource.Response response() {
+                    return scrollResponse;
+                }
+
+                @Override
+                public void done(TimeValue extraKeepAlive) {
+                }
+            });
+
+        assertThat(response.remainingHits(), equalTo(numberOfHits));
+        assertThat(response.hasRemainingHits(), equalTo(true));
+
+        expectThrows(IllegalArgumentException.class, () -> response.consumeHits(-1));
+        expectThrows(IllegalArgumentException.class, () -> response.consumeHits(numberOfHits + 1));
+
+        if (randomBoolean()) {
+            response.consumeHits(numberOfHits - 1);
+            // Unable to consume more than remaining hits
+            expectThrows(IllegalArgumentException.class, () -> response.consumeHits(response.remainingHits() + 1));
+            response.consumeHits(1);
+        } else {
+            response.consumeRemainingHits();
+        }
+
+        expectThrows(IllegalArgumentException.class, () -> response.consumeHits(1));
+    }
+
     /**
      * Simulate a scroll response by setting the scroll id and firing the onScrollResponse method.
      */
     private void simulateScrollResponse(DummyAsyncBulkByScrollAction action, long lastBatchTime, int lastBatchSize,
             ScrollableHitSource.Response response) {
         action.setScroll(scrollId());
-        action.onScrollResponse(lastBatchTime, lastBatchSize, new ScrollableHitSource.AsyncResponse() {
-            @Override
-            public ScrollableHitSource.Response response() {
-                return response;
-            }
+        action.onScrollResponse(lastBatchTime, lastBatchSize,
+            new AbstractAsyncBulkByScrollAction.ScrollConsumableHitsResponse(
+                new ScrollableHitSource.AsyncResponse() {
+                    @Override
+                    public ScrollableHitSource.Response response() {
+                        return response;
+                    }
 
-            @Override
-            public void done(TimeValue extraKeepAlive) {
-                fail();
-            }
-        });
+                    @Override
+                    public void done(TimeValue extraKeepAlive) {
+                        fail();
+                    }
+                })
+        );
     }
 
     private class DummyAsyncBulkByScrollAction