Selaa lähdekoodia

Make AsyncSearchResponse ref-counted (#104128)

It's in the title, this class references `SearchResponse`, so it itself
must become ref-counted.
Armin Braun 1 vuosi sitten
vanhempi
commit
76dc8f8fd7

+ 172 - 98
x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java

@@ -157,11 +157,12 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             request.getSearchRequest().searchType(SearchType.DFS_QUERY_THEN_FETCH);
         }
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
-
-        {
+        final String responseId;
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            responseId = response.getId();
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
             SearchResponse.Clusters clusters = response.getSearchResponse().getClusters();
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
@@ -173,14 +174,16 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
+        } finally {
+            response.decRef();
         }
 
         SearchListenerPlugin.waitSearchStarted();
         SearchListenerPlugin.allowQueryPhase();
 
         waitForSearchTasksToFinish();
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(responseId);
+        try {
             assertFalse(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -211,11 +214,13 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(remoteClusterSearchInfo.getFailedShards(), equalTo(0));
             assertThat(remoteClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(remoteClusterSearchInfo.getTook().millis(), greaterThan(0L));
+        } finally {
+            finishedResponse.decRef();
         }
 
         // check that the async_search/status response includes the same cluster details
         {
-            AsyncStatusResponse statusResponse = getAsyncStatus(response.getId());
+            AsyncStatusResponse statusResponse = getAsyncStatus(responseId);
             assertFalse(statusResponse.isPartial());
 
             SearchResponse.Clusters clusters = statusResponse.getClusters();
@@ -276,11 +281,12 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
         boolean minimizeRoundtrips = TransportSearchAction.shouldMinimizeRoundtrips(request.getSearchRequest());
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
-
-        {
+        final String responseId;
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            responseId = response.getId();
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
             SearchResponse.Clusters clusters = response.getSearchResponse().getClusters();
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
@@ -292,14 +298,16 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
+        } finally {
+            response.decRef();
         }
 
         SearchListenerPlugin.waitSearchStarted();
         SearchListenerPlugin.allowQueryPhase();
 
         waitForSearchTasksToFinish();
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(responseId);
+        try {
             assertNotNull(finishedResponse);
             assertFalse(finishedResponse.isPartial());
 
@@ -341,9 +349,11 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(remoteClusterSearchInfo.getFailedShards(), equalTo(0));
             assertThat(remoteClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(remoteClusterSearchInfo.getTook().millis(), greaterThanOrEqualTo(0L));
+        } finally {
+            finishedResponse.decRef();
         }
         {
-            AsyncStatusResponse statusResponse = getAsyncStatus(response.getId());
+            AsyncStatusResponse statusResponse = getAsyncStatus(responseId);
             assertNotNull(statusResponse);
             assertFalse(statusResponse.isPartial());
 
@@ -413,13 +423,17 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
         boolean minimizeRoundtrips = TransportSearchAction.shouldMinimizeRoundtrips(request.getSearchRequest());
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-
-        waitForSearchTasksToFinish();
-
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        final String responseId;
+        try {
+            assertNotNull(response.getSearchResponse());
+            waitForSearchTasksToFinish();
+            responseId = response.getId();
+        } finally {
+            response.decRef();
+        }
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(responseId);
+        try {
             assertTrue(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -447,6 +461,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
                 : SearchResponse.Cluster.Status.FAILED;
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(expectedStatus));
             assertAllShardsFailed(minimizeRoundtrips, remoteClusterSearchInfo, remoteNumShards);
+        } finally {
+            finishedResponse.decRef();
         }
         // check that the async_search/status response includes the same cluster details
         {
@@ -505,11 +521,10 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
         ThrowingQueryBuilder queryBuilder = new ThrowingQueryBuilder(randomLong(), new IllegalStateException("index corrupted"), 0);
         request.getSearchRequest().source(new SearchSourceBuilder().query(queryBuilder).size(10));
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
-
-        {
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
             SearchResponse.Clusters clusters = response.getSearchResponse().getClusters();
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
@@ -521,6 +536,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
+        } finally {
+            response.decRef();
         }
 
         SearchListenerPlugin.waitSearchStarted();
@@ -528,8 +545,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
         waitForSearchTasksToFinish();
 
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        try {
             assertTrue(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -563,6 +580,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(remoteClusterSearchInfo.getTook().millis(), greaterThan(0L));
             ShardSearchFailure remoteShardSearchFailure = remoteClusterSearchInfo.getFailures().get(0);
             assertTrue("should have 'index corrupted' in reason", remoteShardSearchFailure.reason().contains("index corrupted"));
+        } finally {
+            finishedResponse.decRef();
         }
         // check that the async_search/status response includes the same cluster details
         {
@@ -635,10 +654,10 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
         boolean minimizeRoundtrips = TransportSearchAction.shouldMinimizeRoundtrips(request.getSearchRequest());
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
-        {
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
             SearchResponse.Clusters clusters = response.getSearchResponse().getClusters();
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
@@ -650,6 +669,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
+        } finally {
+            response.decRef();
         }
 
         SearchListenerPlugin.waitSearchStarted();
@@ -657,8 +678,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
         waitForSearchTasksToFinish();
 
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        try {
             assertTrue(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -707,6 +728,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertFalse(remoteClusterSearchInfo.isTimedOut());
             ShardSearchFailure remoteShardSearchFailure = remoteClusterSearchInfo.getFailures().get(0);
             assertTrue("should have 'index corrupted' in reason", remoteShardSearchFailure.reason().contains("index corrupted"));
+        } finally {
+            finishedResponse.decRef();
         }
         // check that the async_search/status response includes the same cluster details
         {
@@ -780,15 +803,16 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             request.getSearchRequest().searchType(SearchType.DFS_QUERY_THEN_FETCH);
         }
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
+
+            boolean minimizeRoundtrips = TransportSearchAction.shouldMinimizeRoundtrips(request.getSearchRequest());
 
-        boolean minimizeRoundtrips = TransportSearchAction.shouldMinimizeRoundtrips(request.getSearchRequest());
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
 
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
-        {
             SearchResponse.Clusters clusters = response.getSearchResponse().getClusters();
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
@@ -800,6 +824,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
+        } finally {
+            response.decRef();
         }
 
         SearchListenerPlugin.waitSearchStarted();
@@ -807,8 +833,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
         waitForSearchTasksToFinish();
 
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        try {
             assertFalse(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -839,6 +865,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
             assertNotNull(remoteClusterSearchInfo.getTook());
             assertFalse(remoteClusterSearchInfo.isTimedOut());
+        } finally {
+            finishedResponse.decRef();
         }
         // check that the async_search/status response includes the same cluster details
         {
@@ -902,13 +930,18 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             request.getSearchRequest().searchType(SearchType.DFS_QUERY_THEN_FETCH);
         }
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-
-        waitForSearchTasksToFinish();
+        final String responseId;
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+            waitForSearchTasksToFinish();
+            responseId = response.getId();
+        } finally {
+            response.decRef();
+        }
 
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(responseId);
+        try {
             assertTrue(finishedResponse.getSearchResponse().isTimedOut());
             assertTrue(finishedResponse.isPartial());
 
@@ -943,10 +976,12 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(remoteClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(remoteClusterSearchInfo.getTook().millis(), greaterThanOrEqualTo(0L));
             assertTrue(remoteClusterSearchInfo.isTimedOut());
+        } finally {
+            finishedResponse.decRef();
         }
         // check that the async_search/status response includes the same cluster details
         {
-            AsyncStatusResponse statusResponse = getAsyncStatus(response.getId());
+            AsyncStatusResponse statusResponse = getAsyncStatus(responseId);
             assertTrue(statusResponse.isPartial());
 
             SearchResponse.Clusters clusters = statusResponse.getClusters();
@@ -1006,13 +1041,19 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
         }
         request.getSearchRequest().source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(10));
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
+        final String responseId;
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+            responseId = response.getId();
+        } finally {
+            response.decRef();
+        }
 
         waitForSearchTasksToFinish();
 
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(responseId);
+        try {
             assertFalse(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -1035,11 +1076,13 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(remoteClusterSearchInfo.getFailedShards(), equalTo(0));
             assertThat(remoteClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(remoteClusterSearchInfo.getTook().millis(), greaterThan(0L));
+        } finally {
+            finishedResponse.decRef();
         }
 
         // check that the async_search/status response includes the same cluster details
         {
-            AsyncStatusResponse statusResponse = getAsyncStatus(response.getId());
+            AsyncStatusResponse statusResponse = getAsyncStatus(responseId);
             assertFalse(statusResponse.isPartial());
 
             SearchResponse.Clusters clusters = statusResponse.getClusters();
@@ -1089,13 +1132,16 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
         ThrowingQueryBuilder queryBuilder = new ThrowingQueryBuilder(randomLong(), new IllegalStateException("index corrupted"), 0);
         request.getSearchRequest().source(new SearchSourceBuilder().query(queryBuilder).size(10));
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+        } finally {
+            response.decRef();
+        }
         waitForSearchTasksToFinish();
 
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        try {
             assertTrue(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -1119,6 +1165,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(remoteClusterSearchInfo.getTook().millis(), greaterThan(0L));
             ShardSearchFailure remoteShardSearchFailure = remoteClusterSearchInfo.getFailures().get(0);
             assertTrue("should have 'index corrupted' in reason", remoteShardSearchFailure.reason().contains("index corrupted"));
+        } finally {
+            finishedResponse.decRef();
         }
         // check that the async_search/status response includes the same cluster details
         {
@@ -1177,12 +1225,16 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
         boolean minimizeRoundtrips = TransportSearchAction.shouldMinimizeRoundtrips(request.getSearchRequest());
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+        } finally {
+            response.decRef();
+        }
 
         waitForSearchTasksToFinish();
-        {
-            AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        final AsyncSearchResponse finishedResponse = getAsyncSearch(response.getId());
+        try {
             assertTrue(finishedResponse.isPartial());
 
             SearchResponse.Clusters clusters = finishedResponse.getSearchResponse().getClusters();
@@ -1207,6 +1259,8 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
                 : SearchResponse.Cluster.Status.FAILED;
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(expectedStatus));
             assertAllShardsFailed(minimizeRoundtrips, remoteClusterSearchInfo, remoteNumShards);
+        } finally {
+            finishedResponse.decRef();
         }
         // check that the async_search/status response includes the same cluster details
         {
@@ -1254,9 +1308,13 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
         request.getSearchRequest().allowPartialSearchResults(false);
         request.getSearchRequest().source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(10));
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+        } finally {
+            response.decRef();
+            assertTrue(response.isRunning());
+        }
 
         SearchListenerPlugin.waitSearchStarted();
 
@@ -1317,17 +1375,21 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             }
 
             // check async search status before allowing query to continue but after cancellation
-            AsyncSearchResponse searchResponseAfterCancellation = getAsyncSearch(response.getId());
-            assertTrue(searchResponseAfterCancellation.isPartial());
-            assertTrue(searchResponseAfterCancellation.isRunning());
-            assertFalse(searchResponseAfterCancellation.getSearchResponse().isTimedOut());
-            assertThat(searchResponseAfterCancellation.getSearchResponse().getClusters().getTotal(), equalTo(2));
-
-            AsyncStatusResponse statusResponse = getAsyncStatus(response.getId());
-            assertTrue(statusResponse.isPartial());
-            assertTrue(statusResponse.isRunning());
-            assertThat(statusResponse.getClusters().getTotal(), equalTo(2));
-            assertNull(statusResponse.getCompletionStatus());
+            final AsyncSearchResponse searchResponseAfterCancellation = getAsyncSearch(response.getId());
+            try {
+                assertTrue(searchResponseAfterCancellation.isPartial());
+                assertTrue(searchResponseAfterCancellation.isRunning());
+                assertFalse(searchResponseAfterCancellation.getSearchResponse().isTimedOut());
+                assertThat(searchResponseAfterCancellation.getSearchResponse().getClusters().getTotal(), equalTo(2));
+
+                AsyncStatusResponse statusResponse = getAsyncStatus(response.getId());
+                assertTrue(statusResponse.isPartial());
+                assertTrue(statusResponse.isRunning());
+                assertThat(statusResponse.getClusters().getTotal(), equalTo(2));
+                assertNull(statusResponse.getCompletionStatus());
+            } finally {
+                searchResponseAfterCancellation.decRef();
+            }
 
         } finally {
             SearchListenerPlugin.allowQueryPhase();
@@ -1343,18 +1405,22 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
         assertThat(statusResponseAfterCompletion.getClusters().getTotal(), equalTo(2));
         assertThat(statusResponseAfterCompletion.getCompletionStatus(), equalTo(RestStatus.BAD_REQUEST));
 
-        AsyncSearchResponse searchResponseAfterCompletion = getAsyncSearch(response.getId());
-        assertTrue(searchResponseAfterCompletion.isPartial());
-        assertFalse(searchResponseAfterCompletion.isRunning());
-        assertFalse(searchResponseAfterCompletion.getSearchResponse().isTimedOut());
-        assertThat(searchResponseAfterCompletion.getSearchResponse().getClusters().getTotal(), equalTo(2));
-        Throwable cause = ExceptionsHelper.unwrap(searchResponseAfterCompletion.getFailure(), TaskCancelledException.class);
-        assertNotNull("TaskCancelledException should be in the causal chain", cause);
-        String json = Strings.toString(
-            ChunkedToXContent.wrapAsToXContent(searchResponseAfterCompletion)
-                .toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)
-        );
-        assertThat(json, matchesRegex(".*task (was)?\s*cancelled.*"));
+        final AsyncSearchResponse searchResponseAfterCompletion = getAsyncSearch(response.getId());
+        try {
+            assertTrue(searchResponseAfterCompletion.isPartial());
+            assertFalse(searchResponseAfterCompletion.isRunning());
+            assertFalse(searchResponseAfterCompletion.getSearchResponse().isTimedOut());
+            assertThat(searchResponseAfterCompletion.getSearchResponse().getClusters().getTotal(), equalTo(2));
+            Throwable cause = ExceptionsHelper.unwrap(searchResponseAfterCompletion.getFailure(), TaskCancelledException.class);
+            assertNotNull("TaskCancelledException should be in the causal chain", cause);
+            String json = Strings.toString(
+                ChunkedToXContent.wrapAsToXContent(searchResponseAfterCompletion)
+                    .toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)
+            );
+            assertThat(json, matchesRegex(".*task (was)?\s*cancelled.*"));
+        } finally {
+            searchResponseAfterCompletion.decRef();
+        }
     }
 
     public void testCancelViaAsyncSearchDelete() throws Exception {
@@ -1374,9 +1440,13 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
         request.getSearchRequest().allowPartialSearchResults(false);
         request.getSearchRequest().source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(10));
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
+        } finally {
+            response.decRef();
+        }
 
         SearchListenerPlugin.waitSearchStarted();
 
@@ -1477,9 +1547,13 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
         request.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1));
         request.setKeepOnCompletion(true);
 
-        AsyncSearchResponse response = submitAsyncSearch(request);
-        assertNotNull(response.getSearchResponse());
-        assertTrue(response.isRunning());
+        final AsyncSearchResponse response = submitAsyncSearch(request);
+        try {
+            assertNotNull(response.getSearchResponse());
+            assertTrue(response.isRunning());
+        } finally {
+            response.decRef();
+        }
 
         SearchListenerPlugin.waitSearchStarted();
 

+ 27 - 10
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.search;
 
 import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ExceptionsHelper;
@@ -23,6 +24,8 @@ import org.elasticsearch.action.search.SearchTask;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.action.search.TransportSearchAction;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
@@ -42,7 +45,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
@@ -52,7 +54,7 @@ import static java.util.Collections.singletonList;
 /**
  * Task that tracks the progress of a currently running {@link SearchRequest}.
  */
-final class AsyncSearchTask extends SearchTask implements AsyncTask {
+final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable {
     private final AsyncExecutionId searchId;
     private final Client client;
     private final ThreadPool threadPool;
@@ -71,7 +73,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
     private volatile long expirationTimeMillis;
     private final AtomicBoolean isCancelling = new AtomicBoolean(false);
 
-    private final AtomicReference<MutableSearchResponse> searchResponse = new AtomicReference<>();
+    private final SetOnce<MutableSearchResponse> searchResponse = new SetOnce<>();
 
     /**
      * Creates an instance of {@link AsyncSearchTask}.
@@ -220,7 +222,12 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
             }
         }
         if (executeImmediately) {
-            listener.accept(getResponseWithHeaders());
+            var response = getResponseWithHeaders();
+            try {
+                listener.accept(response);
+            } finally {
+                response.decRef();
+            }
         }
     }
 
@@ -308,10 +315,13 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
         // we don't need to restore the response headers, they should be included in the current
         // context since we are called by the search action listener.
         AsyncSearchResponse finalResponse = getResponse();
-        for (Consumer<AsyncSearchResponse> consumer : completionsListenersCopy.values()) {
-            consumer.accept(finalResponse);
+        try {
+            for (Consumer<AsyncSearchResponse> consumer : completionsListenersCopy.values()) {
+                consumer.accept(finalResponse);
+            }
+        } finally {
+            finalResponse.decRef();
         }
-
     }
 
     /**
@@ -369,6 +379,11 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
         );
     }
 
+    @Override
+    public void close() {
+        Releasables.close(searchResponse.get());
+    }
+
     class Listener extends SearchProgressActionListener {
 
         // needed when there's a single coordinator for all CCS search phases (minimize_roundtrips=false)
@@ -438,8 +453,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
                 delegate = new CCSSingleCoordinatorSearchProgressListener();
                 delegate.onListShards(shards, skipped, clusters, fetchPhase, timeProvider);
             }
-            searchResponse.compareAndSet(
-                null,
+            searchResponse.set(
                 new MutableSearchResponse(shards.size() + skipped.size(), skipped.size(), clusters, threadPool.getThreadContext())
             );
             executeInitListeners();
@@ -494,7 +508,10 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
         @Override
         public void onFailure(Exception exc) {
             // if the failure occurred before calling onListShards
-            searchResponse.compareAndSet(null, new MutableSearchResponse(-1, -1, null, threadPool.getThreadContext()));
+            var r = new MutableSearchResponse(-1, -1, null, threadPool.getThreadContext());
+            if (searchResponse.trySet(r) == false) {
+                r.close();
+            }
             searchResponse.get()
                 .updateWithFailure(new ElasticsearchStatusException("error while executing search", ExceptionsHelper.status(exc), exc));
             executeInitListeners();

+ 30 - 10
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java

@@ -15,6 +15,7 @@ import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.aggregations.InternalAggregations;
@@ -34,7 +35,7 @@ import static org.elasticsearch.xpack.core.async.AsyncTaskIndexService.restoreRe
  * creating an async response concurrently. This limits the number of final reduction that can
  * run concurrently to 1 and ensures that we pause the search progress when an {@link AsyncSearchResponse} is built.
  */
-class MutableSearchResponse {
+class MutableSearchResponse implements Releasable {
     private static final TotalHits EMPTY_TOTAL_HITS = new TotalHits(0L, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
     private final int totalShards;
     private final int skippedShards;
@@ -117,7 +118,12 @@ class MutableSearchResponse {
             : getShardsInResponseMismatchInfo(response, ccsMinimizeRoundtrips);
 
         this.responseHeaders = threadContext.getResponseHeaders();
+        response.mustIncRef();
+        var existing = this.finalResponse;
         this.finalResponse = response;
+        if (existing != null) {
+            existing.decRef();
+        }
         this.isPartial = isPartialResponse(response);
         this.frozen = true;
     }
@@ -188,6 +194,7 @@ class MutableSearchResponse {
         if (finalResponse != null) {
             // We have a final response, use it.
             searchResponse = finalResponse;
+            searchResponse.mustIncRef();
         } else if (clusters == null) {
             // An error occurred before we got the shard list
             searchResponse = null;
@@ -203,15 +210,21 @@ class MutableSearchResponse {
             reducedAggsSource = () -> reducedAggs;
             searchResponse = buildResponse(task.getStartTimeNanos(), reducedAggs);
         }
-        return new AsyncSearchResponse(
-            task.getExecutionId().getEncoded(),
-            searchResponse,
-            failure,
-            isPartial,
-            frozen == false,
-            task.getStartTime(),
-            expirationTime
-        );
+        try {
+            return new AsyncSearchResponse(
+                task.getExecutionId().getEncoded(),
+                searchResponse,
+                failure,
+                isPartial,
+                frozen == false,
+                task.getStartTime(),
+                expirationTime
+            );
+        } finally {
+            if (searchResponse != null) {
+                searchResponse.decRef();
+            }
+        }
     }
 
     /**
@@ -358,4 +371,11 @@ class MutableSearchResponse {
             throw new IllegalStateException("assert method hit unexpected case for ccsMinimizeRoundtrips=false");
         }
     }
+
+    @Override
+    public void close() {
+        if (finalResponse != null) {
+            finalResponse.decRef();
+        }
+    }
 }

+ 64 - 43
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java

@@ -103,42 +103,59 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
                             // creates the fallback response if the node crashes/restarts in the middle of the request
                             // TODO: store intermediate results ?
                             AsyncSearchResponse initialResp = searchResponse.clone(searchResponse.getId());
-                            store.createResponse(docId, searchTask.getOriginHeaders(), initialResp, new ActionListener<>() {
-                                @Override
-                                public void onResponse(DocWriteResponse r) {
-                                    if (searchResponse.isRunning()) {
-                                        try {
-                                            // store the final response on completion unless the submit is cancelled
-                                            searchTask.addCompletionListener(
-                                                finalResponse -> onFinalResponse(searchTask, finalResponse, () -> {})
-                                            );
-                                        } finally {
-                                            submitListener.onResponse(searchResponse);
+                            searchResponse.mustIncRef();
+                            try {
+                                store.createResponse(
+                                    docId,
+                                    searchTask.getOriginHeaders(),
+                                    initialResp,
+                                    ActionListener.runAfter(new ActionListener<>() {
+                                        @Override
+                                        public void onResponse(DocWriteResponse r) {
+                                            if (searchResponse.isRunning()) {
+                                                try {
+                                                    // store the final response on completion unless the submit is cancelled
+                                                    searchTask.addCompletionListener(
+                                                        finalResponse -> onFinalResponse(searchTask, finalResponse, () -> {})
+                                                    );
+                                                } finally {
+                                                    submitListener.onResponse(searchResponse);
+                                                }
+                                            } else {
+                                                searchResponse.mustIncRef();
+                                                onFinalResponse(
+                                                    searchTask,
+                                                    searchResponse,
+                                                    () -> ActionListener.respondAndRelease(submitListener, searchResponse)
+                                                );
+                                            }
                                         }
-                                    } else {
-                                        onFinalResponse(searchTask, searchResponse, () -> submitListener.onResponse(searchResponse));
-                                    }
-                                }
 
-                                @Override
-                                public void onFailure(Exception exc) {
-                                    onFatalFailure(
-                                        searchTask,
-                                        exc,
-                                        searchResponse.isRunning(),
-                                        "fatal failure: unable to store initial response",
-                                        submitListener
-                                    );
-                                }
-                            });
+                                        @Override
+                                        public void onFailure(Exception exc) {
+                                            onFatalFailure(
+                                                searchTask,
+                                                exc,
+                                                searchResponse.isRunning(),
+                                                "fatal failure: unable to store initial response",
+                                                submitListener
+                                            );
+                                        }
+                                    }, searchResponse::decRef)
+                                );
+                            } finally {
+                                initialResp.decRef();
+                            }
                         } catch (Exception exc) {
                             onFatalFailure(searchTask, exc, searchResponse.isRunning(), "fatal failure: generic error", submitListener);
                         }
                     } else {
-                        // the task completed within the timeout so the response is sent back to the user
-                        // with a null id since nothing was stored on the cluster.
-                        taskManager.unregister(searchTask);
-                        submitListener.onResponse(searchResponse.clone(null));
+                        try (searchTask) {
+                            // the task completed within the timeout so the response is sent back to the user
+                            // with a null id since nothing was stored on the cluster.
+                            taskManager.unregister(searchTask);
+                            ActionListener.respondAndRelease(submitListener, searchResponse.clone(null));
+                        }
                     }
                 }
 
@@ -192,19 +209,21 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
         ActionListener<AsyncSearchResponse> listener
     ) {
         if (shouldCancel && task.isCancelled() == false) {
-            task.cancelTask(() -> {
-                try {
-                    task.addCompletionListener(finalResponse -> taskManager.unregister(task));
-                } finally {
-                    listener.onFailure(error);
-                }
-            }, cancelReason);
+            task.cancelTask(() -> closeTaskAndFail(task, error, listener), cancelReason);
         } else {
-            try {
-                task.addCompletionListener(finalResponse -> taskManager.unregister(task));
-            } finally {
-                listener.onFailure(error);
-            }
+            closeTaskAndFail(task, error, listener);
+        }
+    }
+
+    private void closeTaskAndFail(AsyncSearchTask task, Exception error, ActionListener<AsyncSearchResponse> listener) {
+        try {
+            task.addCompletionListener(finalResponse -> {
+                try (task) {
+                    taskManager.unregister(task);
+                }
+            });
+        } finally {
+            listener.onFailure(error);
         }
     }
 
@@ -214,7 +233,9 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
             threadContext.getResponseHeaders(),
             response,
             ActionListener.running(() -> {
-                taskManager.unregister(searchTask);
+                try (searchTask) {
+                    taskManager.unregister(searchTask);
+                }
                 nextAction.run();
             })
         );

+ 47 - 45
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchSingleNodeTests.java

@@ -26,7 +26,6 @@ import org.elasticsearch.search.fetch.FetchSubPhase;
 import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
 import org.elasticsearch.search.fetch.StoredFieldsSpec;
 import org.elasticsearch.test.ESSingleNodeTestCase;
-import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
 import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchAction;
 import org.elasticsearch.xpack.core.search.action.SubmitAsyncSearchRequest;
 import org.hamcrest.CoreMatchers;
@@ -35,6 +34,8 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
+
 public class AsyncSearchSingleNodeTests extends ESSingleNodeTestCase {
 
     @Override
@@ -53,32 +54,33 @@ public class AsyncSearchSingleNodeTests extends ESSingleNodeTestCase {
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().aggregation(agg);
         SubmitAsyncSearchRequest submitAsyncSearchRequest = new SubmitAsyncSearchRequest(sourceBuilder);
         submitAsyncSearchRequest.setWaitForCompletionTimeout(TimeValue.timeValueSeconds(10));
-        AsyncSearchResponse asyncSearchResponse = client().execute(SubmitAsyncSearchAction.INSTANCE, submitAsyncSearchRequest).actionGet();
-
-        assertFalse(asyncSearchResponse.isRunning());
-        assertTrue(asyncSearchResponse.isPartial());
-        SearchResponse searchResponse = asyncSearchResponse.getSearchResponse();
-        assertEquals(10, searchResponse.getTotalShards());
-        assertEquals(10, searchResponse.getSuccessfulShards());
-        assertEquals(0, searchResponse.getFailedShards());
-        assertEquals(0, searchResponse.getShardFailures().length);
-        assertEquals(10, searchResponse.getHits().getTotalHits().value);
-        assertEquals(0, searchResponse.getHits().getHits().length);
-        StringTerms terms = searchResponse.getAggregations().get("text");
-        assertEquals(1, terms.getBuckets().size());
-        assertEquals(10, terms.getBucketByKey("value").getDocCount());
-        assertNotNull(asyncSearchResponse.getFailure());
-        assertThat(asyncSearchResponse.getFailure(), CoreMatchers.instanceOf(ElasticsearchStatusException.class));
-        ElasticsearchStatusException statusException = (ElasticsearchStatusException) asyncSearchResponse.getFailure();
-        assertEquals(RestStatus.INTERNAL_SERVER_ERROR, statusException.status());
-        assertThat(asyncSearchResponse.getFailure().getCause(), CoreMatchers.instanceOf(SearchPhaseExecutionException.class));
-        SearchPhaseExecutionException phaseExecutionException = (SearchPhaseExecutionException) asyncSearchResponse.getFailure().getCause();
-        assertEquals("fetch", phaseExecutionException.getPhaseName());
-        assertEquals("boom", phaseExecutionException.getCause().getMessage());
-        assertEquals(10, phaseExecutionException.shardFailures().length);
-        for (ShardSearchFailure shardSearchFailure : phaseExecutionException.shardFailures()) {
-            assertEquals("boom", shardSearchFailure.getCause().getMessage());
-        }
+        assertResponse(client().execute(SubmitAsyncSearchAction.INSTANCE, submitAsyncSearchRequest), asyncSearchResponse -> {
+            assertFalse(asyncSearchResponse.isRunning());
+            assertTrue(asyncSearchResponse.isPartial());
+            SearchResponse searchResponse = asyncSearchResponse.getSearchResponse();
+            assertEquals(10, searchResponse.getTotalShards());
+            assertEquals(10, searchResponse.getSuccessfulShards());
+            assertEquals(0, searchResponse.getFailedShards());
+            assertEquals(0, searchResponse.getShardFailures().length);
+            assertEquals(10, searchResponse.getHits().getTotalHits().value);
+            assertEquals(0, searchResponse.getHits().getHits().length);
+            StringTerms terms = searchResponse.getAggregations().get("text");
+            assertEquals(1, terms.getBuckets().size());
+            assertEquals(10, terms.getBucketByKey("value").getDocCount());
+            assertNotNull(asyncSearchResponse.getFailure());
+            assertThat(asyncSearchResponse.getFailure(), CoreMatchers.instanceOf(ElasticsearchStatusException.class));
+            ElasticsearchStatusException statusException = (ElasticsearchStatusException) asyncSearchResponse.getFailure();
+            assertEquals(RestStatus.INTERNAL_SERVER_ERROR, statusException.status());
+            assertThat(asyncSearchResponse.getFailure().getCause(), CoreMatchers.instanceOf(SearchPhaseExecutionException.class));
+            SearchPhaseExecutionException phaseExecutionException = (SearchPhaseExecutionException) asyncSearchResponse.getFailure()
+                .getCause();
+            assertEquals("fetch", phaseExecutionException.getPhaseName());
+            assertEquals("boom", phaseExecutionException.getCause().getMessage());
+            assertEquals(10, phaseExecutionException.shardFailures().length);
+            for (ShardSearchFailure shardSearchFailure : phaseExecutionException.shardFailures()) {
+                assertEquals("boom", shardSearchFailure.getCause().getMessage());
+            }
+        });
     }
 
     public void testFetchFailuresOnlySomeShards() throws Exception {
@@ -96,24 +98,24 @@ public class AsyncSearchSingleNodeTests extends ESSingleNodeTestCase {
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().aggregation(agg);
         SubmitAsyncSearchRequest submitAsyncSearchRequest = new SubmitAsyncSearchRequest(sourceBuilder);
         submitAsyncSearchRequest.setWaitForCompletionTimeout(TimeValue.timeValueSeconds(10));
-        AsyncSearchResponse asyncSearchResponse = client().execute(SubmitAsyncSearchAction.INSTANCE, submitAsyncSearchRequest).actionGet();
-
-        assertFalse(asyncSearchResponse.isRunning());
-        assertFalse(asyncSearchResponse.isPartial());
-        assertNull(asyncSearchResponse.getFailure());
-        SearchResponse searchResponse = asyncSearchResponse.getSearchResponse();
-        assertEquals(10, searchResponse.getTotalShards());
-        assertEquals(5, searchResponse.getSuccessfulShards());
-        assertEquals(5, searchResponse.getFailedShards());
-        assertEquals(10, searchResponse.getHits().getTotalHits().value);
-        assertEquals(5, searchResponse.getHits().getHits().length);
-        StringTerms terms = searchResponse.getAggregations().get("text");
-        assertEquals(1, terms.getBuckets().size());
-        assertEquals(10, terms.getBucketByKey("value").getDocCount());
-        assertEquals(5, searchResponse.getShardFailures().length);
-        for (ShardSearchFailure shardFailure : searchResponse.getShardFailures()) {
-            assertEquals("boom", shardFailure.getCause().getMessage());
-        }
+        assertResponse(client().execute(SubmitAsyncSearchAction.INSTANCE, submitAsyncSearchRequest), asyncSearchResponse -> {
+            assertFalse(asyncSearchResponse.isRunning());
+            assertFalse(asyncSearchResponse.isPartial());
+            assertNull(asyncSearchResponse.getFailure());
+            SearchResponse searchResponse = asyncSearchResponse.getSearchResponse();
+            assertEquals(10, searchResponse.getTotalShards());
+            assertEquals(5, searchResponse.getSuccessfulShards());
+            assertEquals(5, searchResponse.getFailedShards());
+            assertEquals(10, searchResponse.getHits().getTotalHits().value);
+            assertEquals(5, searchResponse.getHits().getHits().length);
+            StringTerms terms = searchResponse.getAggregations().get("text");
+            assertEquals(1, terms.getBuckets().size());
+            assertEquals(10, terms.getBucketByKey("value").getDocCount());
+            assertEquals(5, searchResponse.getShardFailures().length);
+            for (ShardSearchFailure shardFailure : searchResponse.getShardFailures()) {
+                assertEquals("boom", shardFailure.getCause().getMessage());
+            }
+        });
     }
 
     public static final class SubFetchPhasePlugin extends Plugin implements SearchPlugin {

+ 81 - 54
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncStatusResponseTests.java

@@ -268,30 +268,37 @@ public class AsyncStatusResponseTests extends AbstractWireSerializingTestCase<As
             searchId,
             AsyncSearchResponseTests.randomSearchResponse(ccs)
         );
-
-        if (asyncSearchResponse.getSearchResponse() == null
-            && asyncSearchResponse.getFailure() == null
-            && asyncSearchResponse.isRunning() == false) {
-            // if no longer running, the search should have recorded either a failure or a search response
-            // if not an Exception should be thrown
-            expectThrows(
-                IllegalStateException.class,
-                () -> AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId)
-            );
-        } else {
-            AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
-            assertNotNull(statusFromStoredSearch);
-            if (statusFromStoredSearch.isRunning()) {
-                assertNull(
-                    "completion_status should only be present if search is no longer running",
-                    statusFromStoredSearch.getCompletionStatus()
+        try {
+            if (asyncSearchResponse.getSearchResponse() == null
+                && asyncSearchResponse.getFailure() == null
+                && asyncSearchResponse.isRunning() == false) {
+                // if no longer running, the search should have recorded either a failure or a search response
+                // if not an Exception should be thrown
+                expectThrows(
+                    IllegalStateException.class,
+                    () -> AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId)
                 );
             } else {
-                assertNotNull(
-                    "completion_status should be present if search is no longer running",
-                    statusFromStoredSearch.getCompletionStatus()
+                AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(
+                    asyncSearchResponse,
+                    100,
+                    searchId
                 );
+                assertNotNull(statusFromStoredSearch);
+                if (statusFromStoredSearch.isRunning()) {
+                    assertNull(
+                        "completion_status should only be present if search is no longer running",
+                        statusFromStoredSearch.getCompletionStatus()
+                    );
+                } else {
+                    assertNotNull(
+                        "completion_status should be present if search is no longer running",
+                        statusFromStoredSearch.getCompletionStatus()
+                    );
+                }
             }
+        } finally {
+            asyncSearchResponse.decRef();
         }
     }
 
@@ -299,14 +306,18 @@ public class AsyncStatusResponseTests extends AbstractWireSerializingTestCase<As
         String searchId = randomSearchId();
         Exception error = new IllegalArgumentException("dummy");
         AsyncSearchResponse asyncSearchResponse = new AsyncSearchResponse(searchId, null, error, true, false, 100, 200);
-        AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
-        assertNotNull(statusFromStoredSearch);
-        assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.BAD_REQUEST);
-        assertTrue(statusFromStoredSearch.isPartial());
-        assertNull(statusFromStoredSearch.getClusters());
-        assertEquals(0, statusFromStoredSearch.getTotalShards());
-        assertEquals(0, statusFromStoredSearch.getSuccessfulShards());
-        assertEquals(0, statusFromStoredSearch.getSkippedShards());
+        try {
+            AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
+            assertNotNull(statusFromStoredSearch);
+            assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.BAD_REQUEST);
+            assertTrue(statusFromStoredSearch.isPartial());
+            assertNull(statusFromStoredSearch.getClusters());
+            assertEquals(0, statusFromStoredSearch.getTotalShards());
+            assertEquals(0, statusFromStoredSearch.getSuccessfulShards());
+            assertEquals(0, statusFromStoredSearch.getSkippedShards());
+        } finally {
+            asyncSearchResponse.decRef();
+        }
     }
 
     public void testGetStatusFromStoredSearchFailedShardsScenario() {
@@ -328,10 +339,14 @@ public class AsyncStatusResponseTests extends AbstractWireSerializingTestCase<As
         );
 
         AsyncSearchResponse asyncSearchResponse = new AsyncSearchResponse(searchId, searchResponse, null, false, false, 100, 200);
-        AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
-        assertNotNull(statusFromStoredSearch);
-        assertEquals(1, statusFromStoredSearch.getFailedShards());
-        assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.OK);
+        try {
+            AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
+            assertNotNull(statusFromStoredSearch);
+            assertEquals(1, statusFromStoredSearch.getFailedShards());
+            assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.OK);
+        } finally {
+            asyncSearchResponse.decRef();
+        }
     }
 
     public void testGetStatusFromStoredSearchWithEmptyClustersSuccessfullyCompleted() {
@@ -353,10 +368,14 @@ public class AsyncStatusResponseTests extends AbstractWireSerializingTestCase<As
         );
 
         AsyncSearchResponse asyncSearchResponse = new AsyncSearchResponse(searchId, searchResponse, null, false, false, 100, 200);
-        AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
-        assertNotNull(statusFromStoredSearch);
-        assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.OK);
-        assertNull(statusFromStoredSearch.getClusters());
+        try {
+            AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
+            assertNotNull(statusFromStoredSearch);
+            assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.OK);
+            assertNull(statusFromStoredSearch.getClusters());
+        } finally {
+            asyncSearchResponse.decRef();
+        }
     }
 
     public void testGetStatusFromStoredSearchWithNonEmptyClustersSuccessfullyCompleted() {
@@ -396,16 +415,20 @@ public class AsyncStatusResponseTests extends AbstractWireSerializingTestCase<As
         );
 
         AsyncSearchResponse asyncSearchResponse = new AsyncSearchResponse(searchId, searchResponse, null, false, false, 100, 200);
-        AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
-        assertNotNull(statusFromStoredSearch);
-        assertEquals(0, statusFromStoredSearch.getFailedShards());
-        assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.OK);
-        assertEquals(totalClusters, statusFromStoredSearch.getClusters().getTotal());
-        assertEquals(skippedClusters, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SKIPPED));
-        assertEquals(
-            successfulClusters,
-            statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SUCCESSFUL)
-        );
+        try {
+            AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
+            assertNotNull(statusFromStoredSearch);
+            assertEquals(0, statusFromStoredSearch.getFailedShards());
+            assertEquals(statusFromStoredSearch.getCompletionStatus(), RestStatus.OK);
+            assertEquals(totalClusters, statusFromStoredSearch.getClusters().getTotal());
+            assertEquals(skippedClusters, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SKIPPED));
+            assertEquals(
+                successfulClusters,
+                statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SUCCESSFUL)
+            );
+        } finally {
+            asyncSearchResponse.decRef();
+        }
     }
 
     public void testGetStatusFromStoredSearchWithNonEmptyClustersStillRunning() {
@@ -442,13 +465,17 @@ public class AsyncStatusResponseTests extends AbstractWireSerializingTestCase<As
 
         boolean isRunning = true;
         AsyncSearchResponse asyncSearchResponse = new AsyncSearchResponse(searchId, searchResponse, null, false, isRunning, 100, 200);
-        AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
-        assertNotNull(statusFromStoredSearch);
-        assertEquals(0, statusFromStoredSearch.getFailedShards());
-        assertNull("completion_status should not be present if still running", statusFromStoredSearch.getCompletionStatus());
-        assertEquals(100, statusFromStoredSearch.getClusters().getTotal());
-        assertEquals(successful, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SUCCESSFUL));
-        assertEquals(partial, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL));
-        assertEquals(skipped, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SKIPPED));
+        try {
+            AsyncStatusResponse statusFromStoredSearch = AsyncStatusResponse.getStatusFromStoredSearch(asyncSearchResponse, 100, searchId);
+            assertNotNull(statusFromStoredSearch);
+            assertEquals(0, statusFromStoredSearch.getFailedShards());
+            assertNull("completion_status should not be present if still running", statusFromStoredSearch.getCompletionStatus());
+            assertEquals(100, statusFromStoredSearch.getClusters().getTotal());
+            assertEquals(successful, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SUCCESSFUL));
+            assertEquals(partial, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL));
+            assertEquals(skipped, statusFromStoredSearch.getClusters().getClusterStateCount(SearchResponse.Cluster.Status.SKIPPED));
+        } finally {
+            asyncSearchResponse.decRef();
+        }
     }
 }

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/AsyncResponse.java

@@ -8,8 +8,9 @@
 package org.elasticsearch.xpack.core.async;
 
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.RefCounted;
 
-public interface AsyncResponse<T extends AsyncResponse<?>> extends Writeable {
+public interface AsyncResponse<T extends AsyncResponse<?>> extends Writeable, RefCounted {
     /**
      * When this response will expire as a timestamp in milliseconds since epoch.
      */

+ 6 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/AsyncTaskIndexService.java

@@ -441,7 +441,7 @@ public final class AsyncTaskIndexService<R extends AsyncResponse<R>> {
                 listener.onFailure(e);
                 return;
             }
-            listener.onResponse(resp);
+            ActionListener.respondAndRelease(listener, resp);
         }));
     }
 
@@ -490,7 +490,11 @@ public final class AsyncTaskIndexService<R extends AsyncResponse<R>> {
             }
             Objects.requireNonNull(resp, "Get result doesn't include [" + RESULT_FIELD + "] field");
             Objects.requireNonNull(expirationTime, "Get result doesn't include [" + EXPIRATION_TIME_FIELD + "] field");
-            return resp.withExpirationTime(expirationTime);
+            try {
+                return resp.withExpirationTime(expirationTime);
+            } finally {
+                resp.decRef();
+            }
         } catch (IOException e) {
             throw new ElasticsearchParseException("Failed to parse the get result", e);
         }

+ 35 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/AsyncSearchResponse.java

@@ -15,9 +15,12 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
+import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.transport.LeakTracker;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xpack.core.async.AsyncResponse;
 
@@ -43,6 +46,15 @@ public class AsyncSearchResponse extends ActionResponse implements ChunkedToXCon
     private final long startTimeMillis;
     private final long expirationTimeMillis;
 
+    private final RefCounted refCounted = LeakTracker.wrap(new AbstractRefCounted() {
+        @Override
+        protected void closeInternal() {
+            if (searchResponse != null) {
+                searchResponse.decRef();
+            }
+        }
+    });
+
     /**
      * Creates an {@link AsyncSearchResponse} with meta-information only (not-modified).
      */
@@ -72,6 +84,9 @@ public class AsyncSearchResponse extends ActionResponse implements ChunkedToXCon
     ) {
         this.id = id;
         this.error = error;
+        if (searchResponse != null) {
+            searchResponse.mustIncRef();
+        }
         this.searchResponse = searchResponse;
         this.isPartial = isPartial;
         this.isRunning = isRunning;
@@ -105,6 +120,26 @@ public class AsyncSearchResponse extends ActionResponse implements ChunkedToXCon
         out.writeLong(expirationTimeMillis);
     }
 
+    @Override
+    public void incRef() {
+        refCounted.incRef();
+    }
+
+    @Override
+    public boolean tryIncRef() {
+        return refCounted.tryIncRef();
+    }
+
+    @Override
+    public boolean decRef() {
+        return refCounted.decRef();
+    }
+
+    @Override
+    public boolean hasReferences() {
+        return refCounted.hasReferences();
+    }
+
     public AsyncSearchResponse clone(String searchId) {
         return new AsyncSearchResponse(searchId, searchResponse, error, isPartial, isRunning, startTimeMillis, expirationTimeMillis);
     }

+ 18 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/async/AsyncSearchIndexServiceTests.java

@@ -111,6 +111,24 @@ public class AsyncSearchIndexServiceTests extends ESSingleNodeTestCase {
         public TestAsyncResponse convertToFailure(Exception exc) {
             return new TestAsyncResponse(test, expirationTimeMillis, exc.getMessage());
         }
+
+        @Override
+        public void incRef() {}
+
+        @Override
+        public boolean tryIncRef() {
+            return true;
+        }
+
+        @Override
+        public boolean decRef() {
+            return false;
+        }
+
+        @Override
+        public boolean hasReferences() {
+            return true;
+        }
     }
 
     @Before