Browse Source

Fix SearchResponse leak in MutableSearchResponse (#104524)

Fixing a leak in how a `SearchResponse` is created in
`MutableSearchResponse` as well as a couple obvious leaks in the tests
for `AsyncSearchTask`.

closes #104491 closes #104522

non-issue since this hasn't made it into any release yet

tip :) -> Review with whitespace hidden
Armin Braun 1 year ago
parent
commit
5c284ec987

+ 14 - 9
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java

@@ -398,15 +398,20 @@ class MutableSearchResponse implements Releasable {
         if (this.failure != null) {
             reduceException.addSuppressed(this.failure);
         }
-        return new AsyncSearchResponse(
-            task.getExecutionId().getEncoded(),
-            buildResponse(task.getStartTimeNanos(), null),
-            reduceException,
-            isPartial,
-            frozen == false,
-            task.getStartTime(),
-            expirationTime
-        );
+        var response = buildResponse(task.getStartTimeNanos(), null);
+        try {
+            return new AsyncSearchResponse(
+                task.getExecutionId().getEncoded(),
+                response,
+                reduceException,
+                isPartial,
+                frozen == false,
+                task.getStartTime(),
+                expirationTime
+            );
+        } finally {
+            response.decRef();
+        }
     }
 
     private void failIfFrozen() {

+ 278 - 252
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java

@@ -97,304 +97,330 @@ public class AsyncSearchTaskTests extends ESTestCase {
         SearchRequest searchRequest = new SearchRequest("index1", "index2").source(
             new SearchSourceBuilder().query(QueryBuilders.termQuery("field", "value"))
         );
-        AsyncSearchTask asyncSearchTask = new AsyncSearchTask(
-            0L,
-            "",
-            "",
-            new TaskId("node1", 0),
-            searchRequest::buildDescription,
-            TimeValue.timeValueHours(1),
-            Collections.emptyMap(),
-            Collections.emptyMap(),
-            new AsyncExecutionId("0", new TaskId("node1", 1)),
-            new NoOpClient(threadPool),
-            threadPool,
-            (t) -> () -> null
-        );
-        assertEquals("""
-            async_search{indices[index1,index2], search_type[QUERY_THEN_FETCH], source\
-            [{"query":{"term":{"field":{"value":"value"}}}}]}""", asyncSearchTask.getDescription());
+        try (
+            AsyncSearchTask asyncSearchTask = new AsyncSearchTask(
+                0L,
+                "",
+                "",
+                new TaskId("node1", 0),
+                searchRequest::buildDescription,
+                TimeValue.timeValueHours(1),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                new AsyncExecutionId("0", new TaskId("node1", 1)),
+                new NoOpClient(threadPool),
+                threadPool,
+                (t) -> () -> null
+            )
+        ) {
+            assertEquals("""
+                async_search{indices[index1,index2], search_type[QUERY_THEN_FETCH], source\
+                [{"query":{"term":{"field":{"value":"value"}}}}]}""", asyncSearchTask.getDescription());
+        }
     }
 
     public void testWaitForInit() throws InterruptedException {
-        AsyncSearchTask task = new AsyncSearchTask(
-            0L,
-            "",
-            "",
-            new TaskId("node1", 0),
-            () -> null,
-            TimeValue.timeValueHours(1),
-            Collections.emptyMap(),
-            Collections.emptyMap(),
-            new AsyncExecutionId("0", new TaskId("node1", 1)),
-            new NoOpClient(threadPool),
-            threadPool,
-            (t) -> () -> null
-        );
-        int numShards = randomIntBetween(0, 10);
-        List<SearchShard> shards = new ArrayList<>();
-        for (int i = 0; i < numShards; i++) {
-            shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        List<SearchShard> skippedShards = new ArrayList<>();
-        int numSkippedShards = randomIntBetween(0, 10);
-        for (int i = 0; i < numSkippedShards; i++) {
-            skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
+        try (
+            AsyncSearchTask task = new AsyncSearchTask(
+                0L,
+                "",
+                "",
+                new TaskId("node1", 0),
+                () -> null,
+                TimeValue.timeValueHours(1),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                new AsyncExecutionId("0", new TaskId("node1", 1)),
+                new NoOpClient(threadPool),
+                threadPool,
+                (t) -> () -> null
+            )
+        ) {
+            int numShards = randomIntBetween(0, 10);
+            List<SearchShard> shards = new ArrayList<>();
+            for (int i = 0; i < numShards; i++) {
+                shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            List<SearchShard> skippedShards = new ArrayList<>();
+            int numSkippedShards = randomIntBetween(0, 10);
+            for (int i = 0; i < numSkippedShards; i++) {
+                skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
 
-        int numThreads = randomIntBetween(1, 10);
-        CountDownLatch latch = new CountDownLatch(numThreads);
-        for (int i = 0; i < numThreads; i++) {
-            Thread thread = new Thread(() -> task.addCompletionListener(ActionTestUtils.assertNoFailureListener(resp -> {
-                assertThat(numShards + numSkippedShards, equalTo(resp.getSearchResponse().getTotalShards()));
-                assertThat(numSkippedShards, equalTo(resp.getSearchResponse().getSkippedShards()));
-                assertThat(0, equalTo(resp.getSearchResponse().getFailedShards()));
-                latch.countDown();
-            }), TimeValue.timeValueMillis(1)));
-            thread.start();
+            int numThreads = randomIntBetween(1, 10);
+            CountDownLatch latch = new CountDownLatch(numThreads);
+            for (int i = 0; i < numThreads; i++) {
+                Thread thread = new Thread(() -> task.addCompletionListener(ActionTestUtils.assertNoFailureListener(resp -> {
+                    assertThat(numShards + numSkippedShards, equalTo(resp.getSearchResponse().getTotalShards()));
+                    assertThat(numSkippedShards, equalTo(resp.getSearchResponse().getSkippedShards()));
+                    assertThat(0, equalTo(resp.getSearchResponse().getFailedShards()));
+                    latch.countDown();
+                }), TimeValue.timeValueMillis(1)));
+                thread.start();
+            }
+            assertFalse(latch.await(numThreads * 2, TimeUnit.MILLISECONDS));
+            task.getSearchProgressActionListener()
+                .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+            latch.await();
         }
-        assertFalse(latch.await(numThreads * 2, TimeUnit.MILLISECONDS));
-        task.getSearchProgressActionListener()
-            .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
-        latch.await();
     }
 
     public void testWithFailure() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        int numThreads = randomIntBetween(1, 10);
-        CountDownLatch latch = new CountDownLatch(numThreads);
-        for (int i = 0; i < numThreads; i++) {
-            Thread thread = new Thread(() -> task.addCompletionListener(ActionTestUtils.assertNoFailureListener(resp -> {
-                assertNull(resp.getSearchResponse());
-                assertNotNull(resp.getFailure());
-                assertTrue(resp.isPartial());
-                latch.countDown();
-            }), TimeValue.timeValueMillis(1)));
-            thread.start();
+        try (AsyncSearchTask task = createAsyncSearchTask()) {
+            int numThreads = randomIntBetween(1, 10);
+            CountDownLatch latch = new CountDownLatch(numThreads);
+            for (int i = 0; i < numThreads; i++) {
+                Thread thread = new Thread(() -> task.addCompletionListener(ActionTestUtils.assertNoFailureListener(resp -> {
+                    assertNull(resp.getSearchResponse());
+                    assertNotNull(resp.getFailure());
+                    assertTrue(resp.isPartial());
+                    latch.countDown();
+                }), TimeValue.timeValueMillis(1)));
+                thread.start();
+            }
+            assertFalse(latch.await(numThreads * 2, TimeUnit.MILLISECONDS));
+            task.getSearchProgressActionListener().onFailure(new Exception("boom"));
+            latch.await();
         }
-        assertFalse(latch.await(numThreads * 2, TimeUnit.MILLISECONDS));
-        task.getSearchProgressActionListener().onFailure(new Exception("boom"));
-        latch.await();
     }
 
     public void testWithFailureAndGetResponseFailureDuringReduction() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        task.getSearchProgressActionListener()
-            .onListShards(Collections.emptyList(), Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, createTimeProvider());
-        InternalAggregations aggs = InternalAggregations.from(
-            Collections.singletonList(
-                new StringTerms(
-                    "name",
-                    BucketOrder.key(true),
-                    BucketOrder.key(true),
-                    1,
-                    1,
-                    Collections.emptyMap(),
-                    DocValueFormat.RAW,
-                    1,
-                    false,
-                    1,
-                    Collections.emptyList(),
-                    0L
-                )
-            )
-        );
-        task.getSearchProgressActionListener()
-            .onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO), aggs, 1);
-        task.getSearchProgressActionListener().onFailure(new CircuitBreakingException("boom", CircuitBreaker.Durability.TRANSIENT));
         AtomicReference<AsyncSearchResponse> response = new AtomicReference<>();
-        CountDownLatch latch = new CountDownLatch(1);
-        task.addCompletionListener(new ActionListener<>() {
-            @Override
-            public void onResponse(AsyncSearchResponse asyncSearchResponse) {
-                assertTrue(response.compareAndSet(null, asyncSearchResponse));
-                latch.countDown();
-            }
+        try (AsyncSearchTask task = createAsyncSearchTask()) {
+            task.getSearchProgressActionListener()
+                .onListShards(Collections.emptyList(), Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+            InternalAggregations aggs = InternalAggregations.from(
+                Collections.singletonList(
+                    new StringTerms(
+                        "name",
+                        BucketOrder.key(true),
+                        BucketOrder.key(true),
+                        1,
+                        1,
+                        Collections.emptyMap(),
+                        DocValueFormat.RAW,
+                        1,
+                        false,
+                        1,
+                        Collections.emptyList(),
+                        0L
+                    )
+                )
+            );
+            task.getSearchProgressActionListener()
+                .onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO), aggs, 1);
+            task.getSearchProgressActionListener().onFailure(new CircuitBreakingException("boom", CircuitBreaker.Durability.TRANSIENT));
+            CountDownLatch latch = new CountDownLatch(1);
+            task.addCompletionListener(new ActionListener<>() {
+                @Override
+                public void onResponse(AsyncSearchResponse asyncSearchResponse) {
+                    assertTrue(response.compareAndSet(null, asyncSearchResponse));
+                    asyncSearchResponse.mustIncRef();
+                    latch.countDown();
+                }
 
-            @Override
-            public void onFailure(Exception e) {
-                throw new AssertionError("onFailure should not be called");
-            }
-        }, TimeValue.timeValueMillis(10L));
-        assertTrue(latch.await(1, TimeUnit.SECONDS));
+                @Override
+                public void onFailure(Exception e) {
+                    throw new AssertionError("onFailure should not be called");
+                }
+            }, TimeValue.timeValueMillis(10L));
+            assertTrue(latch.await(1, TimeUnit.SECONDS));
+        }
         AsyncSearchResponse asyncSearchResponse = response.get();
-        assertNotNull(response.get().getSearchResponse());
-        assertEquals(0, response.get().getSearchResponse().getTotalShards());
-        assertEquals(0, response.get().getSearchResponse().getSuccessfulShards());
-        assertEquals(0, response.get().getSearchResponse().getFailedShards());
-        Exception failure = asyncSearchResponse.getFailure();
-        assertThat(failure, instanceOf(ElasticsearchException.class));
-        assertEquals("Async search: error while reducing partial results", failure.getMessage());
-        assertEquals(1, failure.getSuppressed().length);
-        assertThat(failure.getSuppressed()[0], instanceOf(ElasticsearchException.class));
-        assertEquals("error while executing search", failure.getSuppressed()[0].getMessage());
-        assertThat(failure.getSuppressed()[0].getCause(), instanceOf(CircuitBreakingException.class));
-        assertEquals("boom", failure.getSuppressed()[0].getCause().getMessage());
+        try {
+            assertNotNull(response.get().getSearchResponse());
+            assertEquals(0, response.get().getSearchResponse().getTotalShards());
+            assertEquals(0, response.get().getSearchResponse().getSuccessfulShards());
+            assertEquals(0, response.get().getSearchResponse().getFailedShards());
+            Exception failure = asyncSearchResponse.getFailure();
+            assertThat(failure, instanceOf(ElasticsearchException.class));
+            assertEquals("Async search: error while reducing partial results", failure.getMessage());
+            assertEquals(1, failure.getSuppressed().length);
+            assertThat(failure.getSuppressed()[0], instanceOf(ElasticsearchException.class));
+            assertEquals("error while executing search", failure.getSuppressed()[0].getMessage());
+            assertThat(failure.getSuppressed()[0].getCause(), instanceOf(CircuitBreakingException.class));
+            assertEquals("boom", failure.getSuppressed()[0].getCause().getMessage());
+        } finally {
+            asyncSearchResponse.decRef();
+        }
     }
 
     public void testWaitForCompletion() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        int numShards = randomIntBetween(0, 10);
-        List<SearchShard> shards = new ArrayList<>();
-        for (int i = 0; i < numShards; i++) {
-            shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        List<SearchShard> skippedShards = new ArrayList<>();
-        int numSkippedShards = randomIntBetween(0, 10);
-        for (int i = 0; i < numSkippedShards; i++) {
-            skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        int totalShards = numShards + numSkippedShards;
-        task.getSearchProgressActionListener()
-            .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
-        for (int i = 0; i < numShards; i++) {
+        try (AsyncSearchTask task = createAsyncSearchTask()) {
+            int numShards = randomIntBetween(0, 10);
+            List<SearchShard> shards = new ArrayList<>();
+            for (int i = 0; i < numShards; i++) {
+                shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            List<SearchShard> skippedShards = new ArrayList<>();
+            int numSkippedShards = randomIntBetween(0, 10);
+            for (int i = 0; i < numSkippedShards; i++) {
+                skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            int totalShards = numShards + numSkippedShards;
+            task.getSearchProgressActionListener()
+                .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+            for (int i = 0; i < numShards; i++) {
+                task.getSearchProgressActionListener()
+                    .onPartialReduce(shards.subList(i, i + 1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+                assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true, false);
+            }
             task.getSearchProgressActionListener()
-                .onPartialReduce(shards.subList(i, i + 1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-            assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true, false);
+                .onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+            assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, false);
+            ActionListener.respondAndRelease(
+                (AsyncSearchTask.Listener) task.getProgressListener(),
+                newSearchResponse(totalShards, totalShards, numSkippedShards)
+            );
+            assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, false, false);
         }
-        task.getSearchProgressActionListener()
-            .onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, false);
-        ((AsyncSearchTask.Listener) task.getProgressListener()).onResponse(newSearchResponse(totalShards, totalShards, numSkippedShards));
-        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, false, false);
     }
 
     public void testWithFetchFailures() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        int numShards = randomIntBetween(2, 10);
-        List<SearchShard> shards = new ArrayList<>();
-        for (int i = 0; i < numShards; i++) {
-            shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        List<SearchShard> skippedShards = new ArrayList<>();
-        int numSkippedShards = randomIntBetween(0, 10);
-        for (int i = 0; i < numSkippedShards; i++) {
-            skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        int totalShards = numShards + numSkippedShards;
-        task.getSearchProgressActionListener()
-            .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
-        for (int i = 0; i < numShards; i++) {
+        try (AsyncSearchTask task = createAsyncSearchTask()) {
+            int numShards = randomIntBetween(2, 10);
+            List<SearchShard> shards = new ArrayList<>();
+            for (int i = 0; i < numShards; i++) {
+                shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            List<SearchShard> skippedShards = new ArrayList<>();
+            int numSkippedShards = randomIntBetween(0, 10);
+            for (int i = 0; i < numSkippedShards; i++) {
+                skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            int totalShards = numShards + numSkippedShards;
             task.getSearchProgressActionListener()
-                .onPartialReduce(shards.subList(i, i + 1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-            assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true, false);
-        }
-        task.getSearchProgressActionListener()
-            .onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-        int numFetchFailures = randomIntBetween(1, numShards - 1);
-        ShardSearchFailure[] shardSearchFailures = new ShardSearchFailure[numFetchFailures];
-        for (int i = 0; i < numFetchFailures; i++) {
-            IOException failure = new IOException("boum");
-            // fetch failures are currently ignored, they come back with onFailure or onResponse anyways
-            task.getSearchProgressActionListener().onFetchFailure(i, new SearchShardTarget("0", new ShardId("0", "0", 1), null), failure);
-            shardSearchFailures[i] = new ShardSearchFailure(failure);
+                .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+            for (int i = 0; i < numShards; i++) {
+                task.getSearchProgressActionListener()
+                    .onPartialReduce(shards.subList(i, i + 1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+                assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true, false);
+            }
+            task.getSearchProgressActionListener()
+                .onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+            int numFetchFailures = randomIntBetween(1, numShards - 1);
+            ShardSearchFailure[] shardSearchFailures = new ShardSearchFailure[numFetchFailures];
+            for (int i = 0; i < numFetchFailures; i++) {
+                IOException failure = new IOException("boum");
+                // fetch failures are currently ignored, they come back with onFailure or onResponse anyways
+                task.getSearchProgressActionListener()
+                    .onFetchFailure(i, new SearchShardTarget("0", new ShardId("0", "0", 1), null), failure);
+                shardSearchFailures[i] = new ShardSearchFailure(failure);
+            }
+            assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, false);
+            ActionListener.respondAndRelease(
+                (AsyncSearchTask.Listener) task.getProgressListener(),
+                newSearchResponse(totalShards, totalShards - numFetchFailures, numSkippedShards, shardSearchFailures)
+            );
+            assertCompletionListeners(task, totalShards, totalShards - numFetchFailures, numSkippedShards, numFetchFailures, false, false);
         }
-        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, false);
-        ((AsyncSearchTask.Listener) task.getProgressListener()).onResponse(
-            newSearchResponse(totalShards, totalShards - numFetchFailures, numSkippedShards, shardSearchFailures)
-        );
-        assertCompletionListeners(task, totalShards, totalShards - numFetchFailures, numSkippedShards, numFetchFailures, false, false);
     }
 
     public void testFatalFailureDuringFetch() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        int numShards = randomIntBetween(0, 10);
-        List<SearchShard> shards = new ArrayList<>();
-        for (int i = 0; i < numShards; i++) {
-            shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        List<SearchShard> skippedShards = new ArrayList<>();
-        int numSkippedShards = randomIntBetween(0, 10);
-        for (int i = 0; i < numSkippedShards; i++) {
-            skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        int totalShards = numShards + numSkippedShards;
-        task.getSearchProgressActionListener()
-            .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
-        for (int i = 0; i < numShards; i++) {
+        try (AsyncSearchTask task = createAsyncSearchTask()) {
+            int numShards = randomIntBetween(0, 10);
+            List<SearchShard> shards = new ArrayList<>();
+            for (int i = 0; i < numShards; i++) {
+                shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            List<SearchShard> skippedShards = new ArrayList<>();
+            int numSkippedShards = randomIntBetween(0, 10);
+            for (int i = 0; i < numSkippedShards; i++) {
+                skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            int totalShards = numShards + numSkippedShards;
             task.getSearchProgressActionListener()
-                .onPartialReduce(shards.subList(0, i + 1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-            assertCompletionListeners(task, totalShards, i + 1 + numSkippedShards, numSkippedShards, 0, true, false);
-        }
-        task.getSearchProgressActionListener()
-            .onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-        for (int i = 0; i < numShards; i++) {
-            // fetch failures are currently ignored, they come back with onFailure or onResponse anyways
+                .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+            for (int i = 0; i < numShards; i++) {
+                task.getSearchProgressActionListener()
+                    .onPartialReduce(shards.subList(0, i + 1), new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+                assertCompletionListeners(task, totalShards, i + 1 + numSkippedShards, numSkippedShards, 0, true, false);
+            }
             task.getSearchProgressActionListener()
-                .onFetchFailure(i, new SearchShardTarget("0", new ShardId("0", "0", 1), null), new IOException("boum"));
+                .onFinalReduce(shards, new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+            for (int i = 0; i < numShards; i++) {
+                // fetch failures are currently ignored, they come back with onFailure or onResponse anyways
+                task.getSearchProgressActionListener()
+                    .onFetchFailure(i, new SearchShardTarget("0", new ShardId("0", "0", 1), null), new IOException("boum"));
+            }
+            assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, false);
+            ((AsyncSearchTask.Listener) task.getProgressListener()).onFailure(new IOException("boum"));
+            assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, true);
         }
-        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, false);
-        ((AsyncSearchTask.Listener) task.getProgressListener()).onFailure(new IOException("boum"));
-        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true, true);
     }
 
     public void testFatalFailureWithNoCause() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        AsyncSearchTask.Listener listener = task.getSearchProgressActionListener();
-        int numShards = randomIntBetween(0, 10);
-        List<SearchShard> shards = new ArrayList<>();
-        for (int i = 0; i < numShards; i++) {
-            shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        List<SearchShard> skippedShards = new ArrayList<>();
-        int numSkippedShards = randomIntBetween(0, 10);
-        for (int i = 0; i < numSkippedShards; i++) {
-            skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
-        }
-        int totalShards = numShards + numSkippedShards;
-        task.getSearchProgressActionListener()
-            .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+        try (AsyncSearchTask task = createAsyncSearchTask()) {
+            AsyncSearchTask.Listener listener = task.getSearchProgressActionListener();
+            int numShards = randomIntBetween(0, 10);
+            List<SearchShard> shards = new ArrayList<>();
+            for (int i = 0; i < numShards; i++) {
+                shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            List<SearchShard> skippedShards = new ArrayList<>();
+            int numSkippedShards = randomIntBetween(0, 10);
+            for (int i = 0; i < numSkippedShards; i++) {
+                skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+            }
+            int totalShards = numShards + numSkippedShards;
+            task.getSearchProgressActionListener()
+                .onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false, createTimeProvider());
 
-        listener.onFailure(new SearchPhaseExecutionException("fetch", "boum", ShardSearchFailure.EMPTY_ARRAY));
-        assertCompletionListeners(task, totalShards, 0, numSkippedShards, 0, true, true);
+            listener.onFailure(new SearchPhaseExecutionException("fetch", "boum", ShardSearchFailure.EMPTY_ARRAY));
+            assertCompletionListeners(task, totalShards, 0, numSkippedShards, 0, true, true);
+        }
     }
 
     public void testAddCompletionListenerScheduleErrorWaitForInitListener() throws InterruptedException {
         throwOnSchedule = true;
-        AsyncSearchTask asyncSearchTask = createAsyncSearchTask();
-        AtomicReference<Exception> failure = new AtomicReference<>();
-        CountDownLatch latch = new CountDownLatch(1);
-        // onListShards has not been executed, then addCompletionListener has to wait for the
-        // onListShards call and is executed as init listener
-        asyncSearchTask.addCompletionListener(new ActionListener<>() {
-            @Override
-            public void onResponse(AsyncSearchResponse asyncSearchResponse) {
-                throw new AssertionError("onResponse should not be called");
-            }
+        AtomicReference<Exception> failure;
+        try (AsyncSearchTask asyncSearchTask = createAsyncSearchTask()) {
+            failure = new AtomicReference<>();
+            CountDownLatch latch = new CountDownLatch(1);
+            // onListShards has not been executed, then addCompletionListener has to wait for the
+            // onListShards call and is executed as init listener
+            asyncSearchTask.addCompletionListener(new ActionListener<>() {
+                @Override
+                public void onResponse(AsyncSearchResponse asyncSearchResponse) {
+                    throw new AssertionError("onResponse should not be called");
+                }
 
-            @Override
-            public void onFailure(Exception e) {
-                assertTrue(failure.compareAndSet(null, e));
-                latch.countDown();
-            }
-        }, TimeValue.timeValueMillis(500L));
-        asyncSearchTask.getSearchProgressActionListener()
-            .onListShards(Collections.emptyList(), Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, createTimeProvider());
-        assertTrue(latch.await(1000, TimeUnit.SECONDS));
+                @Override
+                public void onFailure(Exception e) {
+                    assertTrue(failure.compareAndSet(null, e));
+                    latch.countDown();
+                }
+            }, TimeValue.timeValueMillis(500L));
+            asyncSearchTask.getSearchProgressActionListener()
+                .onListShards(Collections.emptyList(), Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+            assertTrue(latch.await(1000, TimeUnit.SECONDS));
+        }
         assertThat(failure.get(), instanceOf(RuntimeException.class));
     }
 
     public void testAddCompletionListenerScheduleErrorInitListenerExecutedImmediately() throws InterruptedException {
         throwOnSchedule = true;
-        AsyncSearchTask asyncSearchTask = createAsyncSearchTask();
-        asyncSearchTask.getSearchProgressActionListener()
-            .onListShards(Collections.emptyList(), Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, createTimeProvider());
-        CountDownLatch latch = new CountDownLatch(1);
-        AtomicReference<Exception> failure = new AtomicReference<>();
-        // onListShards has already been executed, then addCompletionListener is executed immediately
-        asyncSearchTask.addCompletionListener(new ActionListener<>() {
-            @Override
-            public void onResponse(AsyncSearchResponse asyncSearchResponse) {
-                throw new AssertionError("onResponse should not be called");
-            }
+        AtomicReference<Exception> failure;
+        try (AsyncSearchTask asyncSearchTask = createAsyncSearchTask()) {
+            asyncSearchTask.getSearchProgressActionListener()
+                .onListShards(Collections.emptyList(), Collections.emptyList(), SearchResponse.Clusters.EMPTY, false, createTimeProvider());
+            CountDownLatch latch = new CountDownLatch(1);
+            failure = new AtomicReference<>();
+            // onListShards has already been executed, then addCompletionListener is executed immediately
+            asyncSearchTask.addCompletionListener(new ActionListener<>() {
+                @Override
+                public void onResponse(AsyncSearchResponse asyncSearchResponse) {
+                    throw new AssertionError("onResponse should not be called");
+                }
 
-            @Override
-            public void onFailure(Exception e) {
-                assertTrue(failure.compareAndSet(null, e));
-                latch.countDown();
-            }
-        }, TimeValue.timeValueMillis(500L));
-        assertTrue(latch.await(1000, TimeUnit.SECONDS));
+                @Override
+                public void onFailure(Exception e) {
+                    assertTrue(failure.compareAndSet(null, e));
+                    latch.countDown();
+                }
+            }, TimeValue.timeValueMillis(500L));
+            assertTrue(latch.await(1000, TimeUnit.SECONDS));
+        }
         assertThat(failure.get(), instanceOf(RuntimeException.class));
     }