Parcourir la source

Fix spurious failures in AsyncSearchIntegTestCase (#56026)

Async search integration tests are subject to random failures when:
  * The test index has more than one replica.
  * The request cache is used.
  * Some shards are empty.
  * The maintenance service starts a garbage collection when a node is closing.

They are also slow because the test index is created/populated on each
test method.

This change refactors these integration tests in order to:
  * Create the index once for the entire test suite.
  * Fix the usage of the request cache and replicas.
  * Ensure that all shards have at least one document.
  * Increase the delay of the maintenance service garbage collection.

Closes #55895
Closes #55988
Jim Ferenczi il y a 5 ans
Parent
commit
cb70ce7468

+ 28 - 26
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchActionIT.java

@@ -16,9 +16,9 @@ import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
 import org.elasticsearch.search.aggregations.metrics.InternalMax;
 import org.elasticsearch.search.aggregations.metrics.InternalMin;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
 import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchRequest;
-import org.junit.Before;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -37,21 +37,24 @@ import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
+@ESIntegTestCase.SuiteScopeTestCase
 public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
-    private String indexName;
-    private int numShards;
+    private static String indexName;
+    private static int numShards;
 
-    private int numKeywords;
-    private Map<String, AtomicInteger> keywordFreqs;
-    private float maxMetric = Float.NEGATIVE_INFINITY;
-    private float minMetric = Float.POSITIVE_INFINITY;
+    private static int numKeywords;
+    private static Map<String, AtomicInteger> keywordFreqs;
+    private static float maxMetric = Float.NEGATIVE_INFINITY;
+    private static float minMetric = Float.POSITIVE_INFINITY;
 
-    @Before
-    public void indexDocuments() throws InterruptedException {
+    @Override
+    public void setupSuiteScopeCluster() throws InterruptedException {
         indexName = "test-async";
-        numShards = randomIntBetween(internalCluster().numDataNodes(), internalCluster().numDataNodes()*10);
-        int numDocs = randomIntBetween(numShards, numShards*3);
-        createIndex(indexName, Settings.builder().put("index.number_of_shards", numShards).build());
+        numShards = randomIntBetween(1, 20);
+        int numDocs = randomIntBetween(100, 1000);
+        createIndex(indexName, Settings.builder()
+            .put("index.number_of_shards", numShards)
+            .build());
         numKeywords = randomIntBetween(1, 100);
         keywordFreqs = new HashMap<>();
         Set<String> keywordSet = new HashSet<>();
@@ -77,7 +80,6 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
             reqs.add(client().prepareIndex(indexName).setSource("terms", keyword, "metric", metric));
         }
         indexRandom(true, true, reqs);
-        ensureGreen("test-async");
     }
 
     public void testMaxMinAggregation() throws Exception {
@@ -87,7 +89,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
             .aggregation(AggregationBuilders.min("min").field("metric"))
             .aggregation(AggregationBuilders.max("max").field("metric"));
         try (SearchResponseIterator it =
-                 assertBlockingIterator(indexName, source, numFailures, step)) {
+                 assertBlockingIterator(indexName, numShards, source, numFailures, step)) {
             AsyncSearchResponse response = it.next();
             while (it.hasNext()) {
                 response = it.next();
@@ -130,7 +132,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
         SearchSourceBuilder source = new SearchSourceBuilder()
             .aggregation(AggregationBuilders.terms("terms").field("terms.keyword").size(numKeywords));
         try (SearchResponseIterator it =
-                 assertBlockingIterator(indexName, source, numFailures, step)) {
+                 assertBlockingIterator(indexName, numShards, source, numFailures, step)) {
             AsyncSearchResponse response = it.next();
             while (it.hasNext()) {
                 response = it.next();
@@ -173,11 +175,11 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
     public void testRestartAfterCompletion() throws Exception {
         final AsyncSearchResponse initial;
         try (SearchResponseIterator it =
-                 assertBlockingIterator(indexName, new SearchSourceBuilder(), 0, 2)) {
+                 assertBlockingIterator(indexName, numShards, new SearchSourceBuilder(), 0, 2)) {
             initial = it.next();
         }
         ensureTaskCompletion(initial.getId());
-        restartTaskNode(initial.getId());
+        restartTaskNode(initial.getId(), indexName);
         AsyncSearchResponse response = getAsyncSearch(initial.getId());
         assertNotNull(response.getSearchResponse());
         assertFalse(response.isRunning());
@@ -189,7 +191,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
     public void testDeleteCancelRunningTask() throws Exception {
         final AsyncSearchResponse initial;
         SearchResponseIterator it =
-            assertBlockingIterator(indexName, new SearchSourceBuilder(), randomBoolean() ? 1 : 0, 2);
+            assertBlockingIterator(indexName, numShards, new SearchSourceBuilder(), randomBoolean() ? 1 : 0, 2);
         initial = it.next();
         deleteAsyncSearch(initial.getId());
         it.close();
@@ -199,7 +201,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
 
     public void testDeleteCleanupIndex() throws Exception {
         SearchResponseIterator it =
-            assertBlockingIterator(indexName, new SearchSourceBuilder(), randomBoolean() ? 1 : 0, 2);
+            assertBlockingIterator(indexName, numShards, new SearchSourceBuilder(), randomBoolean() ? 1 : 0, 2);
         AsyncSearchResponse response = it.next();
         deleteAsyncSearch(response.getId());
         it.close();
@@ -210,7 +212,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
     public void testCleanupOnFailure() throws Exception {
         final AsyncSearchResponse initial;
         try (SearchResponseIterator it =
-                 assertBlockingIterator(indexName, new SearchSourceBuilder(), numShards, 2)) {
+                 assertBlockingIterator(indexName, numShards, new SearchSourceBuilder(), numShards, 2)) {
             initial = it.next();
         }
         ensureTaskCompletion(initial.getId());
@@ -226,7 +228,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
 
     public void testInvalidId() throws Exception {
         SearchResponseIterator it =
-            assertBlockingIterator(indexName, new SearchSourceBuilder(), randomBoolean() ? 1 : 0, 2);
+            assertBlockingIterator(indexName, numShards, new SearchSourceBuilder(), randomBoolean() ? 1 : 0, 2);
         AsyncSearchResponse response = it.next();
         ExecutionException exc = expectThrows(ExecutionException.class, () -> getAsyncSearch("invalid"));
         assertThat(exc.getCause(), instanceOf(IllegalArgumentException.class));
@@ -258,7 +260,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
     public void testCancellation() throws Exception {
         SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName);
         request.getSearchRequest().source(
-            new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test"))
+            new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test", randomLong()))
         );
         request.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1));
         AsyncSearchResponse response = submitAsyncSearch(request);
@@ -281,9 +283,8 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
 
     public void testUpdateRunningKeepAlive() throws Exception {
         SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(indexName);
-        request.getSearchRequest().source(
-            new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test"))
-        );
+        request.getSearchRequest()
+            .source(new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test", randomLong())));
         long now = System.currentTimeMillis();
         request.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1));
         AsyncSearchResponse response = submitAsyncSearch(request);
@@ -356,6 +357,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
         request.setWaitForCompletionTimeout(TimeValue.timeValueMinutes(10));
         request.setKeepOnCompletion(true);
         long now = System.currentTimeMillis();
+
         AsyncSearchResponse response = submitAsyncSearch(request);
         assertNotNull(response.getSearchResponse());
         assertFalse(response.isRunning());
@@ -374,7 +376,7 @@ public class AsyncSearchActionIT extends AsyncSearchIntegTestCase {
 
         SubmitAsyncSearchRequest newReq = new SubmitAsyncSearchRequest(indexName);
         newReq.getSearchRequest().source(
-            new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test"))
+            new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test", randomLong()))
         );
         newReq.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1));
         AsyncSearchResponse newResp = submitAsyncSearch(newReq);

+ 53 - 60
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchIntegTestCase.java

@@ -9,19 +9,18 @@ import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskResponse;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.get.GetResponse;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.ContextParser;
 import org.elasticsearch.index.reindex.ReindexPlugin;
-import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.plugins.Plugin;
-import org.elasticsearch.plugins.PluginsService;
+import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESIntegTestCase;
@@ -34,18 +33,15 @@ import org.elasticsearch.xpack.core.search.action.GetAsyncSearchAction;
 import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction;
 import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchRequest;
 import org.elasticsearch.xpack.ilm.IndexLifecycle;
+import org.junit.After;
 
 import java.io.Closeable;
 import java.util.Arrays;
 import java.util.Collection;
-import java.util.Comparator;
+import java.util.Collections;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.List;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Function;
-import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.search.AsyncSearch.INDEX;
 import static org.elasticsearch.xpack.search.AsyncSearchMaintenanceService.ASYNC_SEARCH_CLEANUP_INTERVAL_SETTING;
@@ -55,6 +51,31 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
 public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
     interface SearchResponseIterator extends Iterator<AsyncSearchResponse>, Closeable {}
 
+    public static class SearchTestPlugin extends Plugin implements SearchPlugin {
+        public SearchTestPlugin() {}
+
+        @Override
+        public List<QuerySpec<?>> getQueries() {
+            return Collections.singletonList(new QuerySpec<>(BlockingQueryBuilder.NAME, in -> new BlockingQueryBuilder(in),
+                p -> {
+                    throw new IllegalStateException("not implemented");
+                }));
+        }
+
+        @Override
+        public List<AggregationSpec> getAggregations() {
+            return Collections.singletonList(new AggregationSpec(CancellingAggregationBuilder.NAME, CancellingAggregationBuilder::new,
+                (ContextParser<String, CancellingAggregationBuilder>) (p, c) -> {
+                    throw new IllegalStateException("not implemented");
+                }).addResultReader(InternalFilter::new));
+        }
+    }
+
+    @After
+    public void releaseQueryLatch() {
+        BlockingQueryBuilder.releaseQueryLatch();
+    }
+
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins() {
         return Arrays.asList(LocalStateCompositeXPackPlugin.class, AsyncSearch.class, IndexLifecycle.class,
@@ -65,14 +86,14 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
     protected Settings nodeSettings(int nodeOrdinal) {
         return Settings.builder()
             .put(super.nodeSettings(0))
-            .put(ASYNC_SEARCH_CLEANUP_INTERVAL_SETTING.getKey(), TimeValue.timeValueMillis(1))
+            .put(ASYNC_SEARCH_CLEANUP_INTERVAL_SETTING.getKey(), TimeValue.timeValueMillis(100))
             .build();
     }
 
     /**
      * Restart the node that runs the {@link TaskId} decoded from the provided {@link AsyncExecutionId}.
      */
-    protected void restartTaskNode(String id) throws Exception {
+    protected void restartTaskNode(String id, String indexName) throws Exception {
         AsyncExecutionId searchId = AsyncExecutionId.decode(id);
         final ClusterStateResponse clusterState = client().admin().cluster()
             .prepareState().clear().setNodes(true).get();
@@ -83,7 +104,7 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
                 return super.onNodeStopped(nodeName);
             }
         });
-        ensureYellow(INDEX);
+        ensureYellow(INDEX, indexName);
     }
 
     protected AsyncSearchResponse submitAsyncSearch(SubmitAsyncSearchRequest request) throws ExecutionException, InterruptedException {
@@ -147,41 +168,31 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
         });
     }
 
+    /**
+     * Returns a {@link SearchResponseIterator} that blocks query shard executions
+     * until {@link SearchResponseIterator#next()} is called. This makes it possible to
+     * randomly generate partial results that can be consumed in order.
+     */
     protected SearchResponseIterator assertBlockingIterator(String indexName,
+                                                            int numShards,
                                                             SearchSourceBuilder source,
                                                             int numFailures,
                                                             int progressStep) throws Exception {
         SubmitAsyncSearchRequest request = new SubmitAsyncSearchRequest(source, indexName);
         request.setBatchedReduceSize(progressStep);
         request.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1));
-        ClusterSearchShardsResponse response = dataNodeClient().admin().cluster()
-            .prepareSearchShards(request.getSearchRequest().indices()).get();
-        AtomicInteger failures = new AtomicInteger(numFailures);
-        Map<ShardId, ShardIdLatch> shardLatchMap = Arrays.stream(response.getGroups())
-            .map(ClusterSearchShardsGroup::getShardId)
-            .collect(
-                Collectors.toMap(
-                    Function.identity(),
-                    id -> new ShardIdLatch(id, failures.decrementAndGet() >= 0)
-                )
-            );
-        ShardIdLatch[] shardLatchArray = shardLatchMap.values().stream()
-            .sorted(Comparator.comparing(ShardIdLatch::shardId))
-            .toArray(ShardIdLatch[]::new);
-        resetPluginsLatch(shardLatchMap);
-        request.getSearchRequest().source().query(new BlockingQueryBuilder(shardLatchMap));
+        BlockingQueryBuilder.QueryLatch queryLatch = BlockingQueryBuilder.acquireQueryLatch(numFailures);
+        request.getSearchRequest().source().query(new BlockingQueryBuilder(random().nextLong()));
 
         final AsyncSearchResponse initial = client().execute(SubmitAsyncSearchAction.INSTANCE, request).get();
-
         assertTrue(initial.isPartial());
         assertThat(initial.status(), equalTo(RestStatus.OK));
-        assertThat(initial.getSearchResponse().getTotalShards(), equalTo(shardLatchArray.length));
+        assertThat(initial.getSearchResponse().getTotalShards(), equalTo(numShards));
         assertThat(initial.getSearchResponse().getSuccessfulShards(), equalTo(0));
         assertThat(initial.getSearchResponse().getShardFailures().length, equalTo(0));
 
         return new SearchResponseIterator() {
             private AsyncSearchResponse response = initial;
-            private int shardIndex = 0;
             private boolean isFirst = true;
 
             @Override
@@ -203,32 +214,24 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
                     isFirst = false;
                     return response;
                 }
-                AtomicReference<AsyncSearchResponse> atomic = new AtomicReference<>();
-                int step = shardIndex == 0 ? progressStep+1 : progressStep-1;
-                int index = 0;
-                while (index < step && shardIndex < shardLatchArray.length) {
-                    if (shardLatchArray[shardIndex].shouldFail() == false) {
-                        ++index;
-                    }
-                    shardLatchArray[shardIndex++].countDown();
-                }
+                queryLatch.countDownAndReset();
                 AsyncSearchResponse newResponse = client().execute(GetAsyncSearchAction.INSTANCE,
                     new GetAsyncSearchAction.Request(response.getId())
                         .setWaitForCompletion(TimeValue.timeValueMillis(10))).get();
 
                 if (newResponse.isRunning()) {
-                    assertThat(newResponse.status(),  equalTo(RestStatus.OK));
+                    assertThat(newResponse.status(), equalTo(RestStatus.OK));
                     assertTrue(newResponse.isPartial());
                     assertNull(newResponse.getFailure());
                     assertNotNull(newResponse.getSearchResponse());
-                    assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(shardLatchArray.length));
+                    assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(numShards));
                     assertThat(newResponse.getSearchResponse().getShardFailures().length, lessThanOrEqualTo(numFailures));
-                } else if (numFailures == shardLatchArray.length) {
-                    assertThat(newResponse.status(),  equalTo(RestStatus.INTERNAL_SERVER_ERROR));
+                } else if (numFailures == numShards) {
+                    assertThat(newResponse.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
                     assertNotNull(newResponse.getFailure());
                     assertTrue(newResponse.isPartial());
                     assertNotNull(newResponse.getSearchResponse());
-                    assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(shardLatchArray.length));
+                    assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(numShards));
                     assertThat(newResponse.getSearchResponse().getSuccessfulShards(), equalTo(0));
                     assertThat(newResponse.getSearchResponse().getShardFailures().length, equalTo(numFailures));
                     assertNull(newResponse.getSearchResponse().getAggregations());
@@ -237,32 +240,22 @@ public abstract class AsyncSearchIntegTestCase extends ESIntegTestCase {
                     assertThat(newResponse.getSearchResponse().getHits().getTotalHits().relation,
                         equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO));
                 } else {
-                    assertThat(newResponse.status(),  equalTo(RestStatus.OK));
+                    assertThat(newResponse.status(), equalTo(RestStatus.OK));
                     assertNotNull(newResponse.getSearchResponse());
                     assertFalse(newResponse.isPartial());
                     assertThat(newResponse.status(), equalTo(RestStatus.OK));
-                    assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(shardLatchArray.length));
+                    assertThat(newResponse.getSearchResponse().getTotalShards(), equalTo(numShards));
                     assertThat(newResponse.getSearchResponse().getShardFailures().length, equalTo(numFailures));
                     assertThat(newResponse.getSearchResponse().getSuccessfulShards(),
-                        equalTo(shardLatchArray.length-newResponse.getSearchResponse().getShardFailures().length));
+                        equalTo(numShards - newResponse.getSearchResponse().getShardFailures().length));
                 }
                 return response = newResponse;
             }
 
             @Override
             public void close() {
-                Arrays.stream(shardLatchArray).forEach(shard -> {
-                    if (shard.getCount() == 1) {
-                        shard.countDown();
-                    }
-                });
+                queryLatch.close();
             }
         };
     }
-
-    private void resetPluginsLatch(Map<ShardId, ShardIdLatch> newLatch) {
-        for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) {
-            pluginsService.filterPlugins(SearchTestPlugin.class).forEach(p -> p.resetQueryLatch(newLatch));
-        }
-    }
 }

+ 97 - 37
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/BlockingQueryBuilder.java

@@ -9,43 +9,65 @@ import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
-import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.search.Queries;
-import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.QueryShardContext;
-import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.xpack.search.AsyncSearchIntegTestCase.SearchResponseIterator;
 
+import java.io.Closeable;
 import java.io.IOException;
-import java.util.Map;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
 
 /**
- * A query builder that blocks shard execution based on the provided {@link ShardIdLatch}.
+ * A query builder that blocks shard execution based on a {@link QueryLatch}
+ * that is shared inside a single jvm (static).
  */
 class BlockingQueryBuilder extends AbstractQueryBuilder<BlockingQueryBuilder> {
     public static final String NAME = "block";
-    private final Map<ShardId, ShardIdLatch> shardsLatch;
+    private static QueryLatch queryLatch;
 
-    BlockingQueryBuilder(Map<ShardId, ShardIdLatch> shardsLatch) {
-        super();
-        this.shardsLatch = shardsLatch;
+    private final long randomUID;
+
+    /**
+     * Creates a new query latch with an expected number of <code>numShardFailures</code>.
+     */
+    public static synchronized QueryLatch acquireQueryLatch(int numShardFailures) {
+        assert queryLatch == null;
+        return queryLatch = new QueryLatch(numShardFailures);
     }
 
-    BlockingQueryBuilder(StreamInput in, Map<ShardId, ShardIdLatch> shardsLatch) throws IOException {
-        super(in);
-        this.shardsLatch = shardsLatch;
+    /**
+     * Releases the current query latch.
+     */
+    public static synchronized void releaseQueryLatch() {
+        if (queryLatch != null) {
+            queryLatch.close();
+            queryLatch = null;
+        }
+    }
+
+    /**
+     * Creates a {@link BlockingQueryBuilder} with the provided <code>randomUID</code>.
+     */
+    BlockingQueryBuilder(long randomUID) {
+        super();
+        this.randomUID = randomUID;
     }
 
-    BlockingQueryBuilder() {
-        this.shardsLatch = null;
+    BlockingQueryBuilder(StreamInput in) throws IOException {
+        super(in);
+        this.randomUID = in.readLong();
     }
 
     @Override
-    protected void doWriteTo(StreamOutput out) {}
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeLong(randomUID);
+    }
 
     @Override
     protected void doXContent(XContentBuilder builder, Params params) throws IOException {
@@ -53,33 +75,16 @@ class BlockingQueryBuilder extends AbstractQueryBuilder<BlockingQueryBuilder> {
         builder.endObject();
     }
 
-    private static final ObjectParser<BlockingQueryBuilder, Void> PARSER = new ObjectParser<>(NAME, BlockingQueryBuilder::new);
-
-    public static BlockingQueryBuilder fromXContent(XContentParser parser, Map<ShardId, ShardIdLatch> shardsLatch) {
-        try {
-            PARSER.apply(parser, null);
-            return new BlockingQueryBuilder(shardsLatch);
-        } catch (IllegalArgumentException e) {
-            throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
-        }
-    }
-
     @Override
     protected Query doToQuery(QueryShardContext context) {
         final Query delegate = Queries.newMatchAllQuery();
         return new Query() {
             @Override
             public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
-                if (shardsLatch != null) {
-                    try {
-                        final ShardIdLatch latch = shardsLatch.get(new ShardId(context.index(), context.getShardId()));
-                        latch.await();
-                        if (latch.shouldFail()) {
-                            throw new IOException("boum");
-                        }
-                    } catch (InterruptedException e) {
-                        throw new RuntimeException(e);
-                    }
+                try {
+                    queryLatch.await(context.getShardId());
+                } catch (InterruptedException e) {
+                    throw new RuntimeException(e);
                 }
                 return delegate.createWeight(searcher, scoreMode, boost);
             }
@@ -115,4 +120,59 @@ class BlockingQueryBuilder extends AbstractQueryBuilder<BlockingQueryBuilder> {
     public String getWriteableName() {
         return NAME;
     }
+
+    /**
+     *  A synchronization aid that is used by {@link BlockingQueryBuilder} to block shards executions until
+     *  the consumer calls {@link QueryLatch#countDownAndReset()}.
+     *  The static {@link QueryLatch} is shared in {@link AsyncSearchIntegTestCase#assertBlockingIterator} to provide
+     *  a {@link SearchResponseIterator} that unblocks shards executions whenever {@link SearchResponseIterator#next()}
+     *  is called.
+     */
+    static class QueryLatch implements Closeable {
+        private volatile CountDownLatch countDownLatch;
+        private final Set<Integer> failedShards = new HashSet<>();
+        private int numShardFailures;
+
+        QueryLatch(int numShardFailures) {
+            this.countDownLatch = new CountDownLatch(1);
+            this.numShardFailures = numShardFailures;
+        }
+
+        private void await(int shardId) throws IOException, InterruptedException {
+            CountDownLatch last = countDownLatch;
+            if (last != null) {
+                last.await();
+            }
+            synchronized (this) {
+                // ensure that we fail on replicas too
+                if (failedShards.contains(shardId)) {
+                    throw new IOException("boom");
+                } else if (numShardFailures > 0) {
+                    numShardFailures--;
+                    failedShards.add(shardId);
+                    throw new IOException("boom");
+                }
+            }
+        }
+
+        public synchronized void countDownAndReset() {
+            if (countDownLatch != null) {
+                CountDownLatch last = countDownLatch;
+                countDownLatch = new CountDownLatch(1);
+                if (last != null) {
+                    assert last.getCount() == 1;
+                    last.countDown();
+                }
+            }
+        }
+
+        @Override
+        public synchronized void close() {
+            if (countDownLatch != null) {
+                assert countDownLatch.getCount() == 1;
+                countDownLatch.countDown();
+            }
+            countDownLatch = null;
+        }
+    }
 }

+ 11 - 3
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/CancellingAggregationBuilder.java

@@ -31,17 +31,24 @@ public class CancellingAggregationBuilder extends AbstractAggregationBuilder<Can
     static final String NAME = "cancel";
     static final int SLEEP_TIME = 10;
 
-    public CancellingAggregationBuilder(String name) {
+    private final long randomUID;
+
+    /**
+     * Creates a {@link CancellingAggregationBuilder} with the provided <code>randomUID</code>.
+     */
+    public CancellingAggregationBuilder(String name, long randomUID) {
         super(name);
+        this.randomUID = randomUID;
     }
 
     public CancellingAggregationBuilder(StreamInput in) throws IOException {
         super(in);
+        this.randomUID = in.readLong();
     }
 
     @Override
     protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map<String, Object> metadata) {
-        return new CancellingAggregationBuilder(name);
+        return new CancellingAggregationBuilder(name, randomUID);
     }
 
     @Override
@@ -51,6 +58,7 @@ public class CancellingAggregationBuilder extends AbstractAggregationBuilder<Can
 
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeLong(randomUID);
     }
 
     @Override
@@ -61,7 +69,7 @@ public class CancellingAggregationBuilder extends AbstractAggregationBuilder<Can
     }
 
     static final ConstructingObjectParser<CancellingAggregationBuilder, String> PARSER =
-        new ConstructingObjectParser<>(NAME, false, (args, name) -> new CancellingAggregationBuilder(name));
+        new ConstructingObjectParser<>(NAME, false, (args, name) -> new CancellingAggregationBuilder(name, 0L));
 
 
     static CancellingAggregationBuilder fromXContent(String aggName, XContentParser parser) {

+ 0 - 42
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/SearchTestPlugin.java

@@ -1,42 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-package org.elasticsearch.xpack.search;
-
-import org.elasticsearch.index.shard.ShardId;
-import org.elasticsearch.plugins.Plugin;
-import org.elasticsearch.plugins.SearchPlugin;
-import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
-
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-
-public class SearchTestPlugin extends Plugin implements SearchPlugin {
-    private Map<ShardId, ShardIdLatch> shardsLatch;
-
-    public SearchTestPlugin() {
-        this.shardsLatch = null;
-    }
-
-    public void resetQueryLatch(Map<ShardId, ShardIdLatch> newLatch) {
-        shardsLatch = newLatch;
-    }
-
-    @Override
-    public List<QuerySpec<?>> getQueries() {
-        return Collections.singletonList(
-            new QuerySpec<>(BlockingQueryBuilder.NAME,
-                in -> new BlockingQueryBuilder(in, shardsLatch),
-                p -> BlockingQueryBuilder.fromXContent(p, shardsLatch))
-        );
-    }
-
-    @Override
-    public List<AggregationSpec> getAggregations() {
-        return Collections.singletonList(new AggregationSpec(CancellingAggregationBuilder.NAME, CancellingAggregationBuilder::new,
-            CancellingAggregationBuilder.PARSER).addResultReader(InternalFilter::new));
-    }
-}

+ 0 - 29
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/ShardIdLatch.java

@@ -1,29 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-package org.elasticsearch.xpack.search;
-
-import org.elasticsearch.index.shard.ShardId;
-
-import java.util.concurrent.CountDownLatch;
-
-class ShardIdLatch extends CountDownLatch {
-    private final ShardId shard;
-    private final boolean shouldFail;
-
-    ShardIdLatch(ShardId shard, boolean shouldFail) {
-        super(1);
-        this.shard = shard;
-        this.shouldFail = shouldFail;
-    }
-
-    ShardId shardId() {
-        return shard;
-    }
-
-    boolean shouldFail() {
-        return shouldFail;
-    }
-}

+ 3 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/AsyncTaskMaintenanceService.java

@@ -69,6 +69,7 @@ public abstract class AsyncTaskMaintenanceService extends AbstractLifecycleCompo
 
     @Override
     protected void doStop() {
+        clusterService.removeListener(this);
         stopCleanup();
     }
 
@@ -107,7 +108,7 @@ public abstract class AsyncTaskMaintenanceService extends AbstractLifecycleCompo
     }
 
     synchronized void executeNextCleanup() {
-        if (lifecycle.stoppedOrClosed() == false && isCleanupRunning) {
+        if (isCleanupRunning) {
             long nowInMillis = System.currentTimeMillis();
             DeleteByQueryRequest toDelete = new DeleteByQueryRequest(index)
                 .setQuery(QueryBuilders.rangeQuery(EXPIRATION_TIME_FIELD).lte(nowInMillis));
@@ -117,7 +118,7 @@ public abstract class AsyncTaskMaintenanceService extends AbstractLifecycleCompo
     }
 
     synchronized void scheduleNextCleanup() {
-        if (lifecycle.stoppedOrClosed() == false && isCleanupRunning) {
+        if (isCleanupRunning) {
             try {
                 cancellable = threadPool.schedule(this::executeNextCleanup, delay, ThreadPool.Names.GENERIC);
             } catch (EsRejectedExecutionException e) {