1
0
Эх сурвалжийг харах

Reindex share retry between hit sources (#44203)

The client and remote hit sources each had their own retry mechanism,
both of which did the same thing. To support resiliency we would have to
expand on the retry mechanisms, so as preparation for that the retry
mechanism is now shared, such that each subclass is only responsible for
sending requests and converting responses/failures to a common format.

Part of #42612
Henning Andersen 6 жил өмнө
parent
commit
c8977a44d1

+ 70 - 85
modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java

@@ -39,20 +39,19 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
-import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.index.reindex.RejectAwareActionListener;
 import org.elasticsearch.index.reindex.ScrollableHitSource;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.io.IOException;
 import java.io.InputStream;
-import java.util.Iterator;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 
@@ -77,31 +76,31 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
     }
 
     @Override
-    protected void doStart(Consumer<? super Response> onResponse) {
-        lookupRemoteVersion(version -> {
+    protected void doStart(RejectAwareActionListener<Response> searchListener) {
+        lookupRemoteVersion(RejectAwareActionListener.withResponseHandler(searchListener, version -> {
             remoteVersion = version;
             execute(RemoteRequestBuilders.initialSearch(searchRequest, query, remoteVersion),
-                    RESPONSE_PARSER, r -> onStartResponse(onResponse, r));
-        });
+                RESPONSE_PARSER, RejectAwareActionListener.withResponseHandler(searchListener, r -> onStartResponse(searchListener, r)));
+        }));
     }
 
-    void lookupRemoteVersion(Consumer<Version> onVersion) {
-        execute(new Request("GET", ""), MAIN_ACTION_PARSER, onVersion);
+    void lookupRemoteVersion(RejectAwareActionListener<Version> listener) {
+        execute(new Request("GET", ""), MAIN_ACTION_PARSER, listener);
     }
 
-    private void onStartResponse(Consumer<? super Response> onResponse, Response response) {
+    private void onStartResponse(RejectAwareActionListener<Response> searchListener, Response response) {
         if (Strings.hasLength(response.getScrollId()) && response.getHits().isEmpty()) {
             logger.debug("First response looks like a scan response. Jumping right to the second. scroll=[{}]", response.getScrollId());
-            doStartNextScroll(response.getScrollId(), timeValueMillis(0), onResponse);
+            doStartNextScroll(response.getScrollId(), timeValueMillis(0), searchListener);
         } else {
-            onResponse.accept(response);
+            searchListener.onResponse(response);
         }
     }
 
     @Override
-    protected void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, Consumer<? super Response> onResponse) {
+    protected void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, RejectAwareActionListener<Response> searchListener) {
         TimeValue keepAlive = timeValueNanos(searchRequest.scroll().keepAlive().nanos() + extraKeepAlive.nanos());
-        execute(RemoteRequestBuilders.scroll(scrollId, keepAlive, remoteVersion), RESPONSE_PARSER, onResponse);
+        execute(RemoteRequestBuilders.scroll(scrollId, keepAlive, remoteVersion), RESPONSE_PARSER, searchListener);
     }
 
     @Override
@@ -153,91 +152,77 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
     }
 
     private <T> void execute(Request request,
-                             BiFunction<XContentParser, XContentType, T> parser, Consumer<? super T> listener) {
+                             BiFunction<XContentParser, XContentType, T> parser, RejectAwareActionListener<? super T> listener) {
         // Preserve the thread context so headers survive after the call
         java.util.function.Supplier<ThreadContext.StoredContext> contextSupplier = threadPool.getThreadContext().newRestorableContext(true);
-        class RetryHelper extends AbstractRunnable {
-            private final Iterator<TimeValue> retries = backoffPolicy.iterator();
-
-            @Override
-            protected void doRun() throws Exception {
-                client.performRequestAsync(request, new ResponseListener() {
-                    @Override
-                    public void onSuccess(org.elasticsearch.client.Response response) {
-                        // Restore the thread context to get the precious headers
-                        try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
-                            assert ctx != null; // eliminates compiler warning
-                            T parsedResponse;
-                            try {
-                                HttpEntity responseEntity = response.getEntity();
-                                InputStream content = responseEntity.getContent();
-                                XContentType xContentType = null;
-                                if (responseEntity.getContentType() != null) {
-                                    final String mimeType = ContentType.parse(responseEntity.getContentType().getValue()).getMimeType();
-                                    xContentType = XContentType.fromMediaType(mimeType);
-                                }
-                                if (xContentType == null) {
-                                    try {
-                                        throw new ElasticsearchException(
-                                            "Response didn't include Content-Type: " + bodyMessage(response.getEntity()));
-                                    } catch (IOException e) {
-                                        ElasticsearchException ee = new ElasticsearchException("Error extracting body from response");
-                                        ee.addSuppressed(e);
-                                        throw ee;
-                                    }
+        try {
+            client.performRequestAsync(request, new ResponseListener() {
+                @Override
+                public void onSuccess(org.elasticsearch.client.Response response) {
+                    // Restore the thread context to get the precious headers
+                    try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
+                        assert ctx != null; // eliminates compiler warning
+                        T parsedResponse;
+                        try {
+                            HttpEntity responseEntity = response.getEntity();
+                            InputStream content = responseEntity.getContent();
+                            XContentType xContentType = null;
+                            if (responseEntity.getContentType() != null) {
+                                final String mimeType = ContentType.parse(responseEntity.getContentType().getValue()).getMimeType();
+                                xContentType = XContentType.fromMediaType(mimeType);
+                            }
+                            if (xContentType == null) {
+                                try {
+                                    throw new ElasticsearchException(
+                                        "Response didn't include Content-Type: " + bodyMessage(response.getEntity()));
+                                } catch (IOException e) {
+                                    ElasticsearchException ee = new ElasticsearchException("Error extracting body from response");
+                                    ee.addSuppressed(e);
+                                    throw ee;
                                 }
-                                // EMPTY is safe here because we don't call namedObject
-                                try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY,
-                                    LoggingDeprecationHandler.INSTANCE, content)) {
-                                    parsedResponse = parser.apply(xContentParser, xContentType);
-                                } catch (XContentParseException e) {
+                            }
+                            // EMPTY is safe here because we don't call namedObject
+                            try (XContentParser xContentParser = xContentType.xContent().createParser(NamedXContentRegistry.EMPTY,
+                                LoggingDeprecationHandler.INSTANCE, content)) {
+                                parsedResponse = parser.apply(xContentParser, xContentType);
+                            } catch (XContentParseException e) {
                                 /* Because we're streaming the response we can't get a copy of it here. The best we can do is hint that it
                                  * is totally wrong and we're probably not talking to Elasticsearch. */
-                                    throw new ElasticsearchException(
-                                        "Error parsing the response, remote is likely not an Elasticsearch instance", e);
-                                }
-                            } catch (IOException e) {
                                 throw new ElasticsearchException(
-                                    "Error deserializing response, remote is likely not an Elasticsearch instance", e);
+                                    "Error parsing the response, remote is likely not an Elasticsearch instance", e);
                             }
-                            listener.accept(parsedResponse);
+                        } catch (IOException e) {
+                            throw new ElasticsearchException(
+                                "Error deserializing response, remote is likely not an Elasticsearch instance", e);
                         }
+                        listener.onResponse(parsedResponse);
                     }
+                }
 
-                    @Override
-                    public void onFailure(Exception e) {
-                        try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
-                            assert ctx != null; // eliminates compiler warning
-                            if (e instanceof ResponseException) {
-                                ResponseException re = (ResponseException) e;
-                                if (RestStatus.TOO_MANY_REQUESTS.getStatus() == re.getResponse().getStatusLine().getStatusCode()) {
-                                    if (retries.hasNext()) {
-                                        TimeValue delay = retries.next();
-                                        logger.trace(
-                                            (Supplier<?>) () -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e);
-                                        countSearchRetry.run();
-                                        threadPool.schedule(RetryHelper.this, delay, ThreadPool.Names.SAME);
-                                        return;
-                                    }
-                                }
-                                e = wrapExceptionToPreserveStatus(re.getResponse().getStatusLine().getStatusCode(),
-                                    re.getResponse().getEntity(), re);
-                            } else if (e instanceof ContentTooLongException) {
-                                e = new IllegalArgumentException(
-                                    "Remote responded with a chunk that was too large. Use a smaller batch size.", e);
+                @Override
+                public void onFailure(Exception e) {
+                    try (ThreadContext.StoredContext ctx = contextSupplier.get()) {
+                        assert ctx != null; // eliminates compiler warning
+                        if (e instanceof ResponseException) {
+                            ResponseException re = (ResponseException) e;
+                            int statusCode = re.getResponse().getStatusLine().getStatusCode();
+                            e = wrapExceptionToPreserveStatus(statusCode,
+                                re.getResponse().getEntity(), re);
+                            if (RestStatus.TOO_MANY_REQUESTS.getStatus() == statusCode) {
+                                listener.onRejection(e);
+                                return;
                             }
-                            fail.accept(e);
+                        } else if (e instanceof ContentTooLongException) {
+                            e = new IllegalArgumentException(
+                                "Remote responded with a chunk that was too large. Use a smaller batch size.", e);
                         }
+                        listener.onFailure(e);
                     }
-                });
-            }
-
-            @Override
-            public void onFailure(Exception t) {
-                fail.accept(t);
-            }
+                }
+            });
+        } catch (Exception e) {
+            listener.onFailure(e);
         }
-        new RetryHelper().run();
     }
 
     /**
@@ -261,7 +246,7 @@ public class RemoteScrollableHitSource extends ScrollableHitSource {
         }
     }
 
-    static String bodyMessage(@Nullable HttpEntity entity) throws IOException {
+    private static String bodyMessage(@Nullable HttpEntity entity) throws IOException {
         if (entity == null) {
             return "No error body.";
         } else {

+ 97 - 62
modules/reindex/src/test/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSourceTests.java

@@ -52,6 +52,8 @@ import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.index.reindex.RejectAwareActionListener;
+import org.elasticsearch.index.reindex.ScrollableHitSource;
 import org.elasticsearch.index.reindex.ScrollableHitSource.Response;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -67,10 +69,14 @@ import java.io.IOException;
 import java.io.InputStreamReader;
 import java.net.URL;
 import java.nio.charset.StandardCharsets;
+import java.util.Queue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.common.unit.TimeValue.timeValueMillis;
 import static org.elasticsearch.common.unit.TimeValue.timeValueMinutes;
@@ -91,6 +97,8 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
     private SearchRequest searchRequest;
     private int retriesAllowed;
 
+    private final Queue<ScrollableHitSource.AsyncResponse> responseQueue = new LinkedBlockingQueue<>();
+
     @Before
     @Override
     public void setUp() throws Exception {
@@ -122,6 +130,11 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
         terminate(threadPool);
     }
 
+    @After
+    public void validateAllConsumed() {
+        assertTrue(responseQueue.isEmpty());
+    }
+
     public void testLookupRemoteVersion() throws Exception {
         assertLookupRemoteVersion(Version.fromString("0.20.5"), "main/0_20_5.json");
         assertLookupRemoteVersion(Version.fromString("0.90.13"), "main/0_90_13.json");
@@ -135,16 +148,17 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
 
     private void assertLookupRemoteVersion(Version expected, String s) throws Exception {
         AtomicBoolean called = new AtomicBoolean();
-        sourceWithMockedRemoteCall(false, ContentType.APPLICATION_JSON, s).lookupRemoteVersion(v -> {
-            assertEquals(expected, v);
-            called.set(true);
-        });
+        sourceWithMockedRemoteCall(false, ContentType.APPLICATION_JSON, s)
+            .lookupRemoteVersion(wrapAsListener(v -> {
+                assertEquals(expected, v);
+                called.set(true);
+        }));
         assertTrue(called.get());
     }
 
     public void testParseStartOk() throws Exception {
         AtomicBoolean called = new AtomicBoolean();
-        sourceWithMockedRemoteCall("start_ok.json").doStart(r -> {
+        sourceWithMockedRemoteCall("start_ok.json").doStart(wrapAsListener(r -> {
             assertFalse(r.isTimedOut());
             assertEquals(FAKE_SCROLL_ID, r.getScrollId());
             assertEquals(4, r.getTotalHits());
@@ -156,13 +170,13 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
             assertEquals("{\"test\":\"test2\"}", r.getHits().get(0).getSource().utf8ToString());
             assertNull(r.getHits().get(0).getRouting());
             called.set(true);
-        });
+        }));
         assertTrue(called.get());
     }
 
     public void testParseScrollOk() throws Exception {
         AtomicBoolean called = new AtomicBoolean();
-        sourceWithMockedRemoteCall("scroll_ok.json").doStartNextScroll("", timeValueMillis(0), r -> {
+        sourceWithMockedRemoteCall("scroll_ok.json").doStartNextScroll("", timeValueMillis(0), wrapAsListener(r -> {
             assertFalse(r.isTimedOut());
             assertEquals(FAKE_SCROLL_ID, r.getScrollId());
             assertEquals(4, r.getTotalHits());
@@ -174,7 +188,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
             assertEquals("{\"test\":\"test3\"}", r.getHits().get(0).getSource().utf8ToString());
             assertNull(r.getHits().get(0).getRouting());
             called.set(true);
-        });
+        }));
         assertTrue(called.get());
     }
 
@@ -183,12 +197,12 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
      */
     public void testParseScrollFullyLoaded() throws Exception {
         AtomicBoolean called = new AtomicBoolean();
-        sourceWithMockedRemoteCall("scroll_fully_loaded.json").doStartNextScroll("", timeValueMillis(0), r -> {
+        sourceWithMockedRemoteCall("scroll_fully_loaded.json").doStartNextScroll("", timeValueMillis(0), wrapAsListener(r -> {
             assertEquals("AVToMiDL50DjIiBO3yKA", r.getHits().get(0).getId());
             assertEquals("{\"test\":\"test3\"}", r.getHits().get(0).getSource().utf8ToString());
             assertEquals("testrouting", r.getHits().get(0).getRouting());
             called.set(true);
-        });
+        }));
         assertTrue(called.get());
     }
 
@@ -197,12 +211,12 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
      */
     public void testParseScrollFullyLoadedFrom1_7() throws Exception {
         AtomicBoolean called = new AtomicBoolean();
-        sourceWithMockedRemoteCall("scroll_fully_loaded_1_7.json").doStartNextScroll("", timeValueMillis(0), r -> {
+        sourceWithMockedRemoteCall("scroll_fully_loaded_1_7.json").doStartNextScroll("", timeValueMillis(0), wrapAsListener(r -> {
             assertEquals("AVToMiDL50DjIiBO3yKA", r.getHits().get(0).getId());
             assertEquals("{\"test\":\"test3\"}", r.getHits().get(0).getSource().utf8ToString());
             assertEquals("testrouting", r.getHits().get(0).getRouting());
             called.set(true);
-        });
+        }));
         assertTrue(called.get());
     }
 
@@ -212,7 +226,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
      */
     public void testScanJumpStart() throws Exception {
         AtomicBoolean called = new AtomicBoolean();
-        sourceWithMockedRemoteCall("start_scan.json", "scroll_ok.json").doStart(r -> {
+        sourceWithMockedRemoteCall("start_scan.json", "scroll_ok.json").doStart(wrapAsListener(r -> {
             assertFalse(r.isTimedOut());
             assertEquals(FAKE_SCROLL_ID, r.getScrollId());
             assertEquals(4, r.getTotalHits());
@@ -224,7 +238,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
             assertEquals("{\"test\":\"test3\"}", r.getHits().get(0).getSource().utf8ToString());
             assertNull(r.getHits().get(0).getRouting());
             called.set(true);
-        });
+        }));
         assertTrue(called.get());
     }
 
@@ -252,10 +266,10 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
             assertEquals("{\"test\":\"test1\"}", r.getHits().get(0).getSource().utf8ToString());
             called.set(true);
         };
-        sourceWithMockedRemoteCall("rejection.json").doStart(checkResponse);
+        sourceWithMockedRemoteCall("rejection.json").doStart(wrapAsListener(checkResponse));
         assertTrue(called.get());
         called.set(false);
-        sourceWithMockedRemoteCall("rejection.json").doStartNextScroll("scroll", timeValueMillis(0), checkResponse);
+        sourceWithMockedRemoteCall("rejection.json").doStartNextScroll("scroll", timeValueMillis(0), wrapAsListener(checkResponse));
         assertTrue(called.get());
     }
 
@@ -281,10 +295,11 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
             assertEquals("{\"test\":\"test10000\"}", r.getHits().get(0).getSource().utf8ToString());
             called.set(true);
         };
-        sourceWithMockedRemoteCall("failure_with_status.json").doStart(checkResponse);
+        sourceWithMockedRemoteCall("failure_with_status.json").doStart(wrapAsListener(checkResponse));
         assertTrue(called.get());
         called.set(false);
-        sourceWithMockedRemoteCall("failure_with_status.json").doStartNextScroll("scroll", timeValueMillis(0), checkResponse);
+        sourceWithMockedRemoteCall("failure_with_status.json").doStartNextScroll("scroll", timeValueMillis(0),
+            wrapAsListener(checkResponse));
         assertTrue(called.get());
     }
 
@@ -302,48 +317,51 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
             assertEquals(14, failure.getColumnNumber());
             called.set(true);
         };
-        sourceWithMockedRemoteCall("request_failure.json").doStart(checkResponse);
+        sourceWithMockedRemoteCall("request_failure.json").doStart(wrapAsListener(checkResponse));
         assertTrue(called.get());
         called.set(false);
-        sourceWithMockedRemoteCall("request_failure.json").doStartNextScroll("scroll", timeValueMillis(0), checkResponse);
+        sourceWithMockedRemoteCall("request_failure.json").doStartNextScroll("scroll", timeValueMillis(0), wrapAsListener(checkResponse));
         assertTrue(called.get());
     }
 
     public void testRetryAndSucceed() throws Exception {
-        AtomicBoolean called = new AtomicBoolean();
-        Consumer<Response> checkResponse = r -> {
-            assertThat(r.getFailures(), hasSize(0));
-            called.set(true);
-        };
         retriesAllowed = between(1, Integer.MAX_VALUE);
-        sourceWithMockedRemoteCall("fail:rejection.json", "start_ok.json").doStart(checkResponse);
-        assertTrue(called.get());
+        sourceWithMockedRemoteCall("fail:rejection.json", "start_ok.json", "fail:rejection.json", "scroll_ok.json").start();
+        ScrollableHitSource.AsyncResponse response = responseQueue.poll();
+        assertNotNull(response);
+        assertThat(response.response().getFailures(), empty());
+        assertTrue(responseQueue.isEmpty());
         assertEquals(1, retries);
         retries = 0;
-        called.set(false);
-        sourceWithMockedRemoteCall("fail:rejection.json", "scroll_ok.json").doStartNextScroll("scroll", timeValueMillis(0),
-                checkResponse);
-        assertTrue(called.get());
+        response.done(timeValueMillis(0));
+        response = responseQueue.poll();
+        assertNotNull(response);
+        assertThat(response.response().getFailures(), empty());
+        assertTrue(responseQueue.isEmpty());
         assertEquals(1, retries);
     }
 
     public void testRetryUntilYouRunOutOfTries() throws Exception {
-        AtomicBoolean called = new AtomicBoolean();
-        Consumer<Response> checkResponse = r -> called.set(true);
         retriesAllowed = between(0, 10);
         String[] paths = new String[retriesAllowed + 2];
         for (int i = 0; i < retriesAllowed + 2; i++) {
             paths[i] = "fail:rejection.json";
         }
-        RuntimeException e = expectThrows(RuntimeException.class, () -> sourceWithMockedRemoteCall(paths).doStart(checkResponse));
+        RuntimeException e = expectThrows(RuntimeException.class, () -> sourceWithMockedRemoteCall(paths).start());
         assertEquals("failed", e.getMessage());
-        assertFalse(called.get());
+        assertTrue(responseQueue.isEmpty());
         assertEquals(retriesAllowed, retries);
         retries = 0;
-        e = expectThrows(RuntimeException.class,
-                () -> sourceWithMockedRemoteCall(paths).doStartNextScroll("scroll", timeValueMillis(0), checkResponse));
+        String[] searchOKPaths = Stream.concat(Stream.of("start_ok.json"), Stream.of(paths)).toArray(String[]::new);
+        sourceWithMockedRemoteCall(searchOKPaths).start();
+        ScrollableHitSource.AsyncResponse response = responseQueue.poll();
+        assertNotNull(response);
+        assertThat(response.response().getFailures(), empty());
+        assertTrue(responseQueue.isEmpty());
+
+        e = expectThrows(RuntimeException.class, () -> response.done(timeValueMillis(0)));
         assertEquals("failed", e.getMessage());
-        assertFalse(called.get());
+        assertTrue(responseQueue.isEmpty());
         assertEquals(retriesAllowed, retries);
     }
 
@@ -351,10 +369,10 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
         String header = randomAlphaOfLength(5);
         threadPool.getThreadContext().putHeader("test", header);
         AtomicBoolean called = new AtomicBoolean();
-        sourceWithMockedRemoteCall("start_ok.json").doStart(r -> {
+        sourceWithMockedRemoteCall("start_ok.json").doStart(wrapAsListener(r -> {
             assertEquals(header, threadPool.getThreadContext().getHeader("test"));
             called.set(true);
-        });
+        }));
         assertTrue(called.get());
     }
 
@@ -424,10 +442,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
         });
         RemoteScrollableHitSource source = sourceWithMockedClient(true, httpClient);
 
-        AtomicBoolean called = new AtomicBoolean();
-        Consumer<Response> checkResponse = r -> called.set(true);
-        Throwable e = expectThrows(RuntimeException.class,
-                () -> source.doStartNextScroll(FAKE_SCROLL_ID, timeValueMillis(0), checkResponse));
+        Throwable e = expectThrows(RuntimeException.class, source::start);
         // Unwrap the some artifacts from the test
         while (e.getMessage().equals("failed")) {
             e = e.getCause();
@@ -436,24 +451,24 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
         assertEquals("Remote responded with a chunk that was too large. Use a smaller batch size.", e.getMessage());
         // And that exception is reported as being caused by the underlying exception returned by the client
         assertSame(tooLong, e.getCause());
-        assertFalse(called.get());
+        assertTrue(responseQueue.isEmpty());
     }
 
-    public void testNoContentTypeIsError() throws Exception {
-        Exception e = expectThrows(RuntimeException.class, () ->
-                sourceWithMockedRemoteCall(false, null, "main/0_20_5.json").lookupRemoteVersion(null));
-        assertThat(e.getCause().getCause().getCause().getMessage(), containsString("Response didn't include Content-Type: body={"));
+    public void testNoContentTypeIsError() {
+        RuntimeException e = expectListenerFailure(RuntimeException.class, (RejectAwareActionListener<Version> listener) ->
+                sourceWithMockedRemoteCall(false, null, "main/0_20_5.json").lookupRemoteVersion(listener));
+        assertThat(e.getMessage(), containsString("Response didn't include Content-Type: body={"));
     }
 
-    public void testInvalidJsonThinksRemoveIsNotES() throws IOException {
-        Exception e = expectThrows(RuntimeException.class, () -> sourceWithMockedRemoteCall("some_text.txt").doStart(null));
+    public void testInvalidJsonThinksRemoteIsNotES() throws IOException {
+        Exception e = expectThrows(RuntimeException.class, () -> sourceWithMockedRemoteCall("some_text.txt").start());
         assertEquals("Error parsing the response, remote is likely not an Elasticsearch instance",
                 e.getCause().getCause().getCause().getMessage());
     }
 
-    public void testUnexpectedJsonThinksRemoveIsNotES() throws IOException {
+    public void testUnexpectedJsonThinksRemoteIsNotES() throws IOException {
         // Use the response from a main action instead of a proper start response to generate a parse error
-        Exception e = expectThrows(RuntimeException.class, () -> sourceWithMockedRemoteCall("main/2_3_3.json").doStart(null));
+        Exception e = expectThrows(RuntimeException.class, () -> sourceWithMockedRemoteCall("main/2_3_3.json").start());
         assertEquals("Error parsing the response, remote is likely not an Elasticsearch instance",
                 e.getCause().getCause().getCause().getMessage());
     }
@@ -486,8 +501,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
      * synchronously rather than asynchronously.
      */
     @SuppressWarnings("unchecked")
-    private RemoteScrollableHitSource sourceWithMockedRemoteCall(boolean mockRemoteVersion, ContentType contentType, String... paths)
-            throws Exception {
+    private RemoteScrollableHitSource sourceWithMockedRemoteCall(boolean mockRemoteVersion, ContentType contentType, String... paths) {
         URL[] resources = new URL[paths.length];
         for (int i = 0; i < paths.length; i++) {
             resources[i] = Thread.currentThread().getContextClassLoader().getResource("responses/" + paths[i].replace("fail:", ""));
@@ -533,8 +547,7 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
         return sourceWithMockedClient(mockRemoteVersion, httpClient);
     }
 
-    private RemoteScrollableHitSource sourceWithMockedClient(boolean mockRemoteVersion, CloseableHttpAsyncClient httpClient)
-            throws Exception {
+    private RemoteScrollableHitSource sourceWithMockedClient(boolean mockRemoteVersion, CloseableHttpAsyncClient httpClient) {
         HttpAsyncClientBuilder clientBuilder = mock(HttpAsyncClientBuilder.class);
         when(clientBuilder.build()).thenReturn(httpClient);
 
@@ -543,11 +556,11 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
 
         TestRemoteScrollableHitSource hitSource = new TestRemoteScrollableHitSource(restClient) {
             @Override
-            void lookupRemoteVersion(Consumer<Version> onVersion) {
+            void lookupRemoteVersion(RejectAwareActionListener<Version> listener) {
                 if (mockRemoteVersion) {
-                    onVersion.accept(Version.CURRENT);
+                    listener.onResponse(Version.CURRENT);
                 } else {
-                    super.lookupRemoteVersion(onVersion);
+                    super.lookupRemoteVersion(listener);
                 }
             }
         };
@@ -572,8 +585,30 @@ public class RemoteScrollableHitSourceTests extends ESTestCase {
     private class TestRemoteScrollableHitSource extends RemoteScrollableHitSource {
         TestRemoteScrollableHitSource(RestClient client) {
             super(RemoteScrollableHitSourceTests.this.logger, backoff(), RemoteScrollableHitSourceTests.this.threadPool,
-                    RemoteScrollableHitSourceTests.this::countRetry, r -> fail(), RemoteScrollableHitSourceTests.this::failRequest, client,
-                    new BytesArray("{}"), RemoteScrollableHitSourceTests.this.searchRequest);
+                RemoteScrollableHitSourceTests.this::countRetry,
+                responseQueue::add, RemoteScrollableHitSourceTests.this::failRequest,
+                client, new BytesArray("{}"), RemoteScrollableHitSourceTests.this.searchRequest);
         }
     }
+
+    /** Adapts a consumer to a listener whose failure and rejection callbacks both throw an AssertionError wrapping the cause. */
+    private <T> RejectAwareActionListener<T> wrapAsListener(Consumer<T> consumer) {
+        // a single handler instance is shared by the failure and the rejection callback
+        Consumer<Exception> rethrowAsAssertion = cause -> { throw new AssertionError(cause); };
+        return RejectAwareActionListener.wrap(consumer::accept, rethrowAsAssertion, rethrowAsAssertion);
+    }
+
+    /** Runs subject with a listener that records one onFailure of the expected type and returns it; responses and rejections fail. */
+    @SuppressWarnings("unchecked")
+    private <T extends Exception, V> T expectListenerFailure(Class<T> expectedException, Consumer<RejectAwareActionListener<V>> subject) {
+        AtomicReference<T> captured = new AtomicReference<>();
+        Consumer<Exception> recordFailure = failure -> {
+            assertThat(failure, instanceOf(expectedException));
+            assertTrue(captured.compareAndSet(null, (T) failure)); // a second failure would trip this assertion
+        };
+        subject.accept(RejectAwareActionListener.wrap(response -> fail(), recordFailure, rejection -> fail()));
+        T observed = captured.get();
+        assertNotNull(observed);
+        return observed;
+    }
 }

+ 27 - 75
server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java

@@ -35,7 +35,6 @@ import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.document.DocumentField;
 import org.elasticsearch.common.unit.TimeValue;
-import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
@@ -44,7 +43,6 @@ import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.util.ArrayList;
-import java.util.Iterator;
 import java.util.List;
 import java.util.function.Consumer;
 
@@ -69,22 +67,38 @@ public class ClientScrollableHitSource extends ScrollableHitSource {
     }
 
     @Override
-    public void doStart(Consumer<? super Response> onResponse) {
+    public void doStart(RejectAwareActionListener<Response> searchListener) {
         if (logger.isDebugEnabled()) {
             logger.debug("executing initial scroll against {}",
-                    isEmpty(firstSearchRequest.indices()) ? "all indices" : firstSearchRequest.indices());
+                isEmpty(firstSearchRequest.indices()) ? "all indices" : firstSearchRequest.indices());
         }
-        searchWithRetry(listener -> client.search(firstSearchRequest, listener), r -> consume(r, onResponse));
+        client.search(firstSearchRequest, wrapListener(searchListener));
     }
 
     @Override
-    protected void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, Consumer<? super Response> onResponse) {
-        searchWithRetry(listener -> {
-            SearchScrollRequest request = new SearchScrollRequest();
-            // Add the wait time into the scroll timeout so it won't timeout while we wait for throttling
-            request.scrollId(scrollId).scroll(timeValueNanos(firstSearchRequest.scroll().keepAlive().nanos() + extraKeepAlive.nanos()));
-            client.searchScroll(request, listener);
-        }, r -> consume(r, onResponse));
+    protected void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, RejectAwareActionListener<Response> searchListener) {
+        SearchScrollRequest request = new SearchScrollRequest();
+        // Add the wait time into the scroll timeout so it won't timeout while we wait for throttling
+        request.scrollId(scrollId).scroll(timeValueNanos(firstSearchRequest.scroll().keepAlive().nanos() + extraKeepAlive.nanos()));
+        client.searchScroll(request, wrapListener(searchListener));
+    }
+
+    /** Converts search responses for searchListener; failures that unwrap to EsRejectedExecutionException become rejections. */
+    private ActionListener<SearchResponse> wrapListener(RejectAwareActionListener<Response> searchListener) {
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(SearchResponse rawResponse) {
+                searchListener.onResponse(wrapSearchResponse(rawResponse));
+            }
+            @Override
+            public void onFailure(Exception cause) {
+                if (ExceptionsHelper.unwrap(cause, EsRejectedExecutionException.class) == null) {
+                    searchListener.onFailure(cause);
+                    return;
+                }
+                searchListener.onRejection(cause);
+            }
+        };
     }
 
     @Override
@@ -115,69 +129,7 @@ public class ClientScrollableHitSource extends ScrollableHitSource {
         onCompletion.run();
     }
 
-    /**
-     * Run a search action and call onResponse when a the response comes in, retrying if the action fails with an exception caused by
-     * rejected execution.
-     *
-     * @param action consumes a listener and starts the action. The listener it consumes is rigged to retry on failure.
-     * @param onResponse consumes the response from the action
-     */
-    private void searchWithRetry(Consumer<ActionListener<SearchResponse>> action, Consumer<SearchResponse> onResponse) {
-        /*
-         * RetryHelper is both an AbstractRunnable and an ActionListener<SearchResponse> - meaning that it both starts the search and
-         * handles reacts to the results. The complexity is all in onFailure which either adapts the failure to the "fail" listener or
-         * retries the search. Since both AbstractRunnable and ActionListener define the onFailure method it is called for either failure
-         * to run the action (either while running or before starting) and for failure on the response from the action.
-         */
-        class RetryHelper extends AbstractRunnable implements ActionListener<SearchResponse> {
-            private final Iterator<TimeValue> retries = backoffPolicy.iterator();
-            /**
-             * The runnable to run that retries in the same context as the original call.
-             */
-            private Runnable retryWithContext;
-            private volatile int retryCount = 0;
-
-            @Override
-            protected void doRun() throws Exception {
-                action.accept(this);
-            }
-
-            @Override
-            public void onResponse(SearchResponse response) {
-                onResponse.accept(response);
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                if (ExceptionsHelper.unwrap(e, EsRejectedExecutionException.class) != null) {
-                    if (retries.hasNext()) {
-                        retryCount += 1;
-                        TimeValue delay = retries.next();
-                        logger.trace(() -> new ParameterizedMessage("retrying rejected search after [{}]", delay), e);
-                        countSearchRetry.run();
-                        threadPool.schedule(retryWithContext, delay, ThreadPool.Names.SAME);
-                    } else {
-                        logger.warn(() -> new ParameterizedMessage(
-                                "giving up on search because we retried [{}] times without success", retryCount), e);
-                        fail.accept(e);
-                    }
-                } else {
-                    logger.warn("giving up on search because it failed with a non-retryable exception", e);
-                    fail.accept(e);
-                }
-            }
-        }
-        RetryHelper helper = new RetryHelper();
-        // Wrap the helper in a runnable that preserves the current context so we keep it on retry.
-        helper.retryWithContext = threadPool.getThreadContext().preserveContext(helper);
-        helper.run();
-    }
-
-    private void consume(SearchResponse response, Consumer<? super Response> onResponse) {
-        onResponse.accept(wrap(response));
-    }
-
-    private Response wrap(SearchResponse response) {
+    private Response wrapSearchResponse(SearchResponse response) {
         List<SearchFailure> failures;
         if (response.getShardFailures() == null) {
             failures = emptyList();

+ 81 - 0
server/src/main/java/org/elasticsearch/index/reindex/RejectAwareActionListener.java

@@ -0,0 +1,81 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.index.reindex;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.CheckedConsumer;
+
+import java.util.function.Consumer;
+
+/**
+ * An {@link ActionListener} that additionally receives rejections through {@link #onRejection(Exception)}. Public for testing.
+ */
+public interface RejectAwareActionListener<T> extends ActionListener<T> {
+    /** Invoked to report a rejection, as distinct from {@link #onFailure(Exception)}. */
+    void onRejection(Exception e);
+
+    /**
+     * Builds a listener that hands responses to {@code responseHandler} and passes failures and rejections on to
+     * {@code errorDelegate} unchanged. Exceptions thrown by {@code responseHandler} are not caught here.
+     */
+    static <X> RejectAwareActionListener<X> withResponseHandler(RejectAwareActionListener<?> errorDelegate, Consumer<X> responseHandler) {
+        return new RejectAwareActionListener<>() {
+            @Override
+            public void onResponse(X response) {
+                responseHandler.accept(response);
+            }
+            @Override
+            public void onFailure(Exception failure) {
+                errorDelegate.onFailure(failure);
+            }
+            @Override
+            public void onRejection(Exception rejection) {
+                errorDelegate.onRejection(rejection);
+            }
+        };
+    }
+
+    /**
+     * Similar to {@link ActionListener#wrap(CheckedConsumer, Consumer)}, with a third handler for onRejection. An
+     * {@link Exception} thrown by {@code onResponse} is passed to {@code onFailure}.
+     */
+    static <Response> RejectAwareActionListener<Response> wrap(CheckedConsumer<Response, ? extends Exception> onResponse,
+                                                    Consumer<Exception> onFailure, Consumer<Exception> onRejection) {
+        return new RejectAwareActionListener<Response>() {
+            @Override
+            public void onResponse(Response response) {
+                try {
+                    onResponse.accept(response);
+                } catch (Exception thrownByHandler) {
+                    this.onFailure(thrownByHandler); // route handler errors to the failure callback
+                }
+            }
+            @Override
+            public void onRejection(Exception rejection) {
+                onRejection.accept(rejection);
+            }
+            @Override
+            public void onFailure(Exception failure) {
+                onFailure.accept(failure);
+            }
+        };
+    }
+}
+

+ 78 - 0
server/src/main/java/org/elasticsearch/index/reindex/RetryListener.java

@@ -0,0 +1,78 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.index.reindex;
+
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.bulk.BackoffPolicy;
+import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.threadpool.ThreadPool;
+
+import java.util.Iterator;
+import java.util.function.Consumer;
+
+/**
+ * Listener that retries a rejected search after the next backoff delay; once no delays remain it fails the delegate instead.
+ */
+class RetryListener implements RejectAwareActionListener<ScrollableHitSource.Response> {
+    private final Logger logger;
+    private final ThreadPool threadPool;
+    private final Iterator<TimeValue> retries;
+    private final Consumer<RejectAwareActionListener<ScrollableHitSource.Response>> retryScrollHandler;
+    private final ActionListener<ScrollableHitSource.Response> delegate;
+    private int retryCount = 0;
+
+    RetryListener(Logger logger, ThreadPool threadPool, BackoffPolicy backoffPolicy,
+                          Consumer<RejectAwareActionListener<ScrollableHitSource.Response>> retryScrollHandler,
+                          ActionListener<ScrollableHitSource.Response> delegate) {
+        this.logger = logger;
+        this.threadPool = threadPool;
+        this.retryScrollHandler = retryScrollHandler;
+        this.delegate = delegate;
+        this.retries = backoffPolicy.iterator();
+    }
+
+    @Override
+    public void onResponse(ScrollableHitSource.Response response) {
+        this.delegate.onResponse(response);
+    }
+
+    @Override
+    public void onFailure(Exception failure) {
+        this.delegate.onFailure(failure);
+    }
+
+    @Override
+    public void onRejection(Exception rejection) {
+        if (retries.hasNext() == false) {
+            logger.warn(() -> new ParameterizedMessage(
+                "giving up on search because we retried [{}] times without success", retryCount), rejection);
+            delegate.onFailure(rejection);
+            return;
+        }
+        retryCount += 1;
+        TimeValue delay = retries.next();
+        logger.trace(() -> new ParameterizedMessage("retrying rejected search after [{}]", delay), rejection);
+        // schedule does not preserve the thread context, so wrap the retry to carry it over manually
+        Runnable retry = threadPool.preserveContext(() -> retryScrollHandler.accept(this));
+        threadPool.schedule(retry, delay, ThreadPool.Names.SAME);
+    }
+}

+ 23 - 11
server/src/main/java/org/elasticsearch/index/reindex/ScrollableHitSource.java

@@ -21,6 +21,7 @@ package org.elasticsearch.index.reindex;
 
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.bulk.BackoffPolicy;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.common.Nullable;
@@ -71,23 +72,28 @@ public abstract class ScrollableHitSource {
     }
 
     public final void start() {
-        doStart(response -> {
-           setScroll(response.getScrollId());
-           logger.debug("scroll returned [{}] documents with a scroll id of [{}]", response.getHits().size(), response.getScrollId());
-           onResponse(response);
-        });
+        doStart(createRetryListener(this::doStart));
+    }
+
+    /** Creates a RetryListener that runs countSearchRetry before each retry and delivers outcomes to onResponse or fail. */
+    private RetryListener createRetryListener(Consumer<RejectAwareActionListener<Response>> retryHandler) {
+        ActionListener<Response> outcome = ActionListener.wrap(this::onResponse, fail);
+        return new RetryListener(logger, threadPool, backoffPolicy, retryListener -> {
+            countSearchRetry.run(); // record the retry before issuing it
+            retryHandler.accept(retryListener);
+        }, outcome);
     }
-    protected abstract void doStart(Consumer<? super Response> onResponse);
 
+    // package private for tests.
     final void startNextScroll(TimeValue extraKeepAlive) {
-        doStartNextScroll(scrollId.get(), extraKeepAlive, response -> {
-            setScroll(response.getScrollId());
-            onResponse(response);
-        });
+        startNextScroll(extraKeepAlive, createRetryListener(listener -> startNextScroll(extraKeepAlive, listener)));
+    }
+    private void startNextScroll(TimeValue extraKeepAlive, RejectAwareActionListener<Response> searchListener) {
+        doStartNextScroll(scrollId.get(), extraKeepAlive, searchListener);
     }
-    protected abstract void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, Consumer<? super Response> onResponse);
 
     private void onResponse(Response response) {
+        logger.debug("scroll returned [{}] documents with a scroll id of [{}]", response.getHits().size(), response.getScrollId());
         setScroll(response.getScrollId());
         onResponse.accept(new AsyncResponse() {
             private AtomicBoolean alreadyDone = new AtomicBoolean();
@@ -113,6 +119,12 @@ public abstract class ScrollableHitSource {
         }
     }
 
+    // The methods below are the SPI that subclasses must implement.
+    protected abstract void doStart(RejectAwareActionListener<Response> searchListener);
+
+    protected abstract void doStartNextScroll(String scrollId, TimeValue extraKeepAlive,
+                                              RejectAwareActionListener<Response> searchListener);
+
     /**
      * Called to clear a scroll id.
      *