Browse Source

Refactor SearchResponseClusters to use CHM (#100129)

Replaced `Map<String, AtomicReference<Cluster>>` with
`ConcurrentHashMap<String, Cluster>` in order to avoid hand-written
compare-and-swap loops.

Closes #99101
Matteo Piergiovanni 2 years ago
parent
commit
608bb02cf7

+ 6 - 0
docs/changelog/100129.yaml

@@ -0,0 +1,6 @@
+pr: 100129
+summary: Refactor `SearchResponseClusters` to use CHM
+area: Search
+type: enhancement
+issues:
+ - 99101

+ 12 - 13
server/src/internalClusterTest/java/org/elasticsearch/action/search/CCSPointInTimeIT.java

@@ -26,7 +26,6 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
@@ -112,14 +111,14 @@ public class CCSPointInTimeIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.SKIPPED), equalTo(0));
 
             if (includeLocalIndex) {
-                AtomicReference<SearchResponse.Cluster> localClusterRef = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
-                assertNotNull(localClusterRef);
-                assertOneSuccessfulShard(localClusterRef.get());
+                SearchResponse.Cluster localCluster = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
+                assertNotNull(localCluster);
+                assertOneSuccessfulShard(localCluster);
             }
 
-            AtomicReference<SearchResponse.Cluster> remoteClusterRef = clusters.getCluster(REMOTE_CLUSTER);
-            assertNotNull(remoteClusterRef);
-            assertOneSuccessfulShard(remoteClusterRef.get());
+            SearchResponse.Cluster remoteCluster = clusters.getCluster(REMOTE_CLUSTER);
+            assertNotNull(remoteCluster);
+            assertOneSuccessfulShard(remoteCluster);
 
         } finally {
             closePointInTime(pitId);
@@ -168,13 +167,13 @@ public class CCSPointInTimeIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(expectedNumClusters));
 
             if (includeLocalIndex) {
-                AtomicReference<SearchResponse.Cluster> localClusterRef = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
-                assertNotNull(localClusterRef);
-                assertOneFailedShard(localClusterRef.get(), numShards);
+                SearchResponse.Cluster localCluster = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
+                assertNotNull(localCluster);
+                assertOneFailedShard(localCluster, numShards);
             }
-            AtomicReference<SearchResponse.Cluster> remoteClusterRef = clusters.getCluster(REMOTE_CLUSTER);
-            assertNotNull(remoteClusterRef);
-            assertOneFailedShard(remoteClusterRef.get(), numShards);
+            SearchResponse.Cluster remoteCluster = clusters.getCluster(REMOTE_CLUSTER);
+            assertNotNull(remoteCluster);
+            assertOneFailedShard(remoteCluster, numShards);
 
         } finally {
             closePointInTime(pitId);

+ 13 - 13
server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java

@@ -131,7 +131,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(0));
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
         assertNotNull(localClusterSearchInfo);
         assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
         assertThat(localClusterSearchInfo.getIndexExpression(), equalTo(localIndex));
@@ -142,7 +142,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
         assertThat(localClusterSearchInfo.getFailures().size(), equalTo(0));
         assertThat(localClusterSearchInfo.getTook().millis(), greaterThan(0L));
 
-        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
         assertNotNull(remoteClusterSearchInfo);
         assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
         assertThat(remoteClusterSearchInfo.getIndexExpression(), equalTo(remoteIndex));
@@ -195,9 +195,9 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(0));
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
         assertNotNull(localClusterSearchInfo);
-        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
         assertNotNull(remoteClusterSearchInfo);
 
         assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
@@ -262,11 +262,11 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(2));
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
         assertNotNull(localClusterSearchInfo);
         assertOneFailedShard(localClusterSearchInfo, localNumShards);
 
-        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
         assertNotNull(remoteClusterSearchInfo);
         assertOneFailedShard(remoteClusterSearchInfo, remoteNumShards);
     }
@@ -334,7 +334,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
                 assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(1));
             }
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(localClusterSearchInfo.getTotalShards(), equalTo(localNumShards));
@@ -344,7 +344,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
             assertThat(localClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(localClusterSearchInfo.getTook().millis(), greaterThan(0L));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
 
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
@@ -409,7 +409,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(2));
         assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+        SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
         assertNotNull(localClusterSearchInfo);
         assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
         assertTrue(localClusterSearchInfo.isTimedOut());
@@ -421,7 +421,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
         assertThat(localClusterSearchInfo.getFailures().size(), equalTo(0));
         assertThat(localClusterSearchInfo.getTook().millis(), greaterThanOrEqualTo(0L));
 
-        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
         assertNotNull(remoteClusterSearchInfo);
         assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
         assertTrue(remoteClusterSearchInfo.isTimedOut());
@@ -467,7 +467,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
 
         assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
         assertNotNull(remoteClusterSearchInfo);
         assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
         assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -513,7 +513,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
 
         assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+        SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
         assertNotNull(remoteClusterSearchInfo);
         assertOneFailedShard(remoteClusterSearchInfo, remoteNumShards);
     }
@@ -569,7 +569,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
 
             assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
                 ? SearchResponse.Cluster.Status.SKIPPED

+ 54 - 83
server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java

@@ -20,7 +20,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.concurrent.atomic.AtomicReference;
 
 /**
  * Use this progress listener for cross-cluster searches where a single
@@ -63,17 +62,15 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
 
         for (Map.Entry<String, Integer> entry : totalByClusterAlias.entrySet()) {
             String clusterAlias = entry.getKey();
-            AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(clusterAlias);
-            assert clusterRef.get().getTotalShards() == null : "total shards should not be set on a Cluster before onListShards";
 
-            int totalCount = entry.getValue();
-            int skippedCount = skippedByClusterAlias.getOrDefault(clusterAlias, 0);
-            TimeValue took = null;
+            clusters.swapCluster(clusterAlias, (k, v) -> {
+                assert v.getTotalShards() == null : "total shards should not be set on a Cluster before onListShards";
 
-            boolean swapped;
-            do {
-                SearchResponse.Cluster curr = clusterRef.get();
-                SearchResponse.Cluster.Status status = curr.getStatus();
+                int totalCount = entry.getValue();
+                int skippedCount = skippedByClusterAlias.getOrDefault(k, 0);
+                TimeValue took = null;
+
+                SearchResponse.Cluster.Status status = v.getStatus();
                 assert status == SearchResponse.Cluster.Status.RUNNING : "should have RUNNING status during onListShards but has " + status;
 
                 // if all shards are marked as skipped, the search is done - mark as SUCCESSFUL
@@ -81,8 +78,7 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
                     took = new TimeValue(timeProvider.buildTookInMillis());
                     status = SearchResponse.Cluster.Status.SUCCESSFUL;
                 }
-
-                SearchResponse.Cluster updated = new SearchResponse.Cluster.Builder(curr).setStatus(status)
+                return new SearchResponse.Cluster.Builder(v).setStatus(status)
                     .setTotalShards(totalCount)
                     .setSuccessfulShards(skippedCount)
                     .setSkippedShards(skippedCount)
@@ -90,10 +86,7 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
                     .setTook(took)
                     .setTimedOut(false)
                     .build();
-
-                swapped = clusterRef.compareAndSet(curr, updated);
-                assert swapped : "compareAndSet in onListShards should never fail due to race condition";
-            } while (swapped == false);
+            });
         }
     }
 
@@ -115,19 +108,16 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
             if (clusterAlias == null) {
                 clusterAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
             }
-            AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(clusterAlias);
-            boolean swapped;
-            do {
-                SearchResponse.Cluster curr = clusterRef.get();
-                if (curr.isTimedOut()) {
-                    break; // cluster has already been marked as timed out on some other shard
+
+            clusters.swapCluster(clusterAlias, (k, v) -> {
+                if (v.isTimedOut()) {
+                    return v; // cluster has already been marked as timed out on some other shard
                 }
-                if (curr.getStatus() == SearchResponse.Cluster.Status.FAILED || curr.getStatus() == SearchResponse.Cluster.Status.SKIPPED) {
-                    break; // safety check to make sure it hasn't hit a terminal FAILED/SKIPPED state where timeouts don't matter
+                if (v.getStatus() == SearchResponse.Cluster.Status.FAILED || v.getStatus() == SearchResponse.Cluster.Status.SKIPPED) {
+                    return v; // safety check to make sure it hasn't hit a terminal FAILED/SKIPPED state where timeouts don't matter
                 }
-                SearchResponse.Cluster updated = new SearchResponse.Cluster.Builder(curr).setTimedOut(true).build();
-                swapped = clusterRef.compareAndSet(curr, updated);
-            } while (swapped == false);
+                return new SearchResponse.Cluster.Builder(v).setTimedOut(true).build();
+            });
         }
     }
 
@@ -147,37 +137,34 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
         if (clusterAlias == null) {
             clusterAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
         }
-        AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(clusterAlias);
-        boolean swapped;
-        do {
-            TimeValue took = null;
-            SearchResponse.Cluster curr = clusterRef.get();
-            SearchResponse.Cluster.Status status = SearchResponse.Cluster.Status.RUNNING;
-            int numFailedShards = curr.getFailedShards() == null ? 1 : curr.getFailedShards() + 1;
 
-            assert curr.getTotalShards() != null : "total shards should be set on the Cluster but not for " + clusterAlias;
-            if (curr.getTotalShards() == numFailedShards) {
-                if (curr.isSkipUnavailable()) {
+        clusters.swapCluster(clusterAlias, (k, v) -> {
+            TimeValue took;
+            SearchResponse.Cluster.Status status;
+            int numFailedShards = v.getFailedShards() == null ? 1 : v.getFailedShards() + 1;
+
+            assert v.getTotalShards() != null : "total shards should be set on the Cluster but not for " + k;
+            if (v.getTotalShards() == numFailedShards) {
+                took = null;
+                if (v.isSkipUnavailable()) {
                     status = SearchResponse.Cluster.Status.SKIPPED;
                 } else {
                     status = SearchResponse.Cluster.Status.FAILED;
                     // TODO in the fail-fast ticket, should we throw an exception here to stop the search?
                 }
-            } else if (curr.getTotalShards() == numFailedShards + curr.getSuccessfulShards()) {
+            } else if (v.getTotalShards() == numFailedShards + v.getSuccessfulShards()) {
                 status = SearchResponse.Cluster.Status.PARTIAL;
                 took = new TimeValue(timeProvider.buildTookInMillis());
+            } else {
+                took = null;
+                status = SearchResponse.Cluster.Status.RUNNING;
             }
-
-            // creates a new unmodifiable list
-            List<ShardSearchFailure> failures = CollectionUtils.appendToCopy(curr.getFailures(), new ShardSearchFailure(e, shardTarget));
-            SearchResponse.Cluster updated = new SearchResponse.Cluster.Builder(curr).setStatus(status)
+            return new SearchResponse.Cluster.Builder(v).setStatus(status)
                 .setFailedShards(numFailedShards)
-                .setFailures(failures)
+                .setFailures(CollectionUtils.appendToCopy(v.getFailures(), new ShardSearchFailure(e, shardTarget)))
                 .setTook(took)
                 .build();
-
-            swapped = clusterRef.compareAndSet(curr, updated);
-        } while (swapped == false);
+        });
     }
 
     /**
@@ -202,32 +189,23 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
             String clusterAlias = entry.getKey();
             int successfulCount = entry.getValue().intValue();
 
-            AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(clusterAlias);
-            boolean swapped;
-            do {
-                SearchResponse.Cluster curr = clusterRef.get();
-                SearchResponse.Cluster.Status status = curr.getStatus();
+            clusters.swapCluster(clusterAlias, (k, v) -> {
+                SearchResponse.Cluster.Status status = v.getStatus();
                 if (status != SearchResponse.Cluster.Status.RUNNING) {
                     // don't swap in a new Cluster if the final state has already been set
-                    break;
+                    return v;
                 }
                 TimeValue took = null;
-                int successfulShards = successfulCount + curr.getSkippedShards();
-                if (successfulShards == curr.getTotalShards()) {
-                    status = curr.isTimedOut() ? SearchResponse.Cluster.Status.PARTIAL : SearchResponse.Cluster.Status.SUCCESSFUL;
+                int successfulShards = successfulCount + v.getSkippedShards();
+                if (successfulShards == v.getTotalShards()) {
+                    status = v.isTimedOut() ? SearchResponse.Cluster.Status.PARTIAL : SearchResponse.Cluster.Status.SUCCESSFUL;
                     took = new TimeValue(timeProvider.buildTookInMillis());
-                } else if (successfulShards + curr.getFailedShards() == curr.getTotalShards()) {
+                } else if (successfulShards + v.getFailedShards() == v.getTotalShards()) {
                     status = SearchResponse.Cluster.Status.PARTIAL;
                     took = new TimeValue(timeProvider.buildTookInMillis());
                 }
-
-                SearchResponse.Cluster updated = new SearchResponse.Cluster.Builder(curr).setStatus(status)
-                    .setSuccessfulShards(successfulShards)
-                    .setTook(took)
-                    .build();
-
-                swapped = clusterRef.compareAndSet(curr, updated);
-            } while (swapped == false);
+                return new SearchResponse.Cluster.Builder(v).setStatus(status).setSuccessfulShards(successfulShards).setTook(took).build();
+            });
         }
     }
 
@@ -254,38 +232,31 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
             String clusterAlias = entry.getKey();
             int successfulCount = entry.getValue().intValue();
 
-            AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(clusterAlias);
-            boolean swapped;
-            do {
-                SearchResponse.Cluster curr = clusterRef.get();
-                SearchResponse.Cluster.Status status = curr.getStatus();
+            clusters.swapCluster(clusterAlias, (k, v) -> {
+                SearchResponse.Cluster.Status status = v.getStatus();
                 if (status != SearchResponse.Cluster.Status.RUNNING) {
                     // don't swap in a new Cluster if the final state has already been set
-                    break;
+                    return v;
                 }
                 TimeValue took = new TimeValue(timeProvider.buildTookInMillis());
-                int successfulShards = successfulCount + curr.getSkippedShards();
-                assert successfulShards + curr.getFailedShards() == curr.getTotalShards()
+                int successfulShards = successfulCount + v.getSkippedShards();
+                assert successfulShards + v.getFailedShards() == v.getTotalShards()
                     : "successfulShards("
                         + successfulShards
                         + ") + failedShards("
-                        + curr.getFailedShards()
+                        + v.getFailedShards()
                         + ") != totalShards ("
-                        + curr.getTotalShards()
+                        + v.getTotalShards()
                         + ')';
-                if (curr.isTimedOut() || successfulShards < curr.getTotalShards()) {
+                if (v.isTimedOut() || successfulShards < v.getTotalShards()) {
                     status = SearchResponse.Cluster.Status.PARTIAL;
                 } else {
-                    assert successfulShards == curr.getTotalShards()
-                        : "successful (" + successfulShards + ") should equal total(" + curr.getTotalShards() + ") if get here";
+                    assert successfulShards == v.getTotalShards()
+                        : "successful (" + successfulShards + ") should equal total(" + v.getTotalShards() + ") if get here";
                     status = SearchResponse.Cluster.Status.SUCCESSFUL;
                 }
-                SearchResponse.Cluster updated = new SearchResponse.Cluster.Builder(curr).setStatus(status)
-                    .setSuccessfulShards(successfulShards)
-                    .setTook(took)
-                    .build();
-                swapped = clusterRef.compareAndSet(curr, updated);
-            } while (swapped == false);
+                return new SearchResponse.Cluster.Builder(v).setStatus(status).setSuccessfulShards(successfulShards).setTook(took).build();
+            });
         }
     }
 

+ 44 - 21
server/src/main/java/org/elasticsearch/action/search/SearchResponse.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
 import org.elasticsearch.core.Nullable;
@@ -42,13 +43,12 @@ import org.elasticsearch.xcontent.XContentParser.Token;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiFunction;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 
@@ -474,8 +474,9 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
 
         // key to map is clusterAlias on the primary querying cluster of a CCS minimize_roundtrips=true query
         // the Map itself is immutable after construction - all Clusters will be accounted for at the start of the search
-        // updates to the Cluster occur by CAS swapping in new Cluster objects into the AtomicReference in the map.
-        private final Map<String, AtomicReference<Cluster>> clusterInfo;
+        // updates to a Cluster occur via the swapCluster method, which, given a key of the map, replaces the
+        // old Cluster object with the new Cluster object returned by the remapping function.
+        private final Map<String, Cluster> clusterInfo;
 
         // not Writeable since it is only needed on the (primary) CCS coordinator
         private transient Boolean ccsMinimizeRoundtrips;
@@ -503,19 +504,19 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
             this.successful = 0; // calculated from clusterInfo map for minimize_roundtrips
             this.skipped = 0;    // calculated from clusterInfo map for minimize_roundtrips
             this.ccsMinimizeRoundtrips = ccsMinimizeRoundtrips;
-            Map<String, AtomicReference<Cluster>> m = new HashMap<>();
+            Map<String, Cluster> m = ConcurrentCollections.newConcurrentMap();
             if (localIndices != null) {
                 String localKey = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
                 Cluster c = new Cluster(localKey, String.join(",", localIndices.indices()), false);
-                m.put(localKey, new AtomicReference<>(c));
+                m.put(localKey, c);
             }
             for (Map.Entry<String, OriginalIndices> remote : remoteClusterIndices.entrySet()) {
                 String clusterAlias = remote.getKey();
                 boolean skipUnavailable = skipUnavailablePredicate.test(clusterAlias);
                 Cluster c = new Cluster(clusterAlias, String.join(",", remote.getValue().indices()), skipUnavailable);
-                m.put(clusterAlias, new AtomicReference<>(c));
+                m.put(clusterAlias, c);
             }
-            this.clusterInfo = Collections.unmodifiableMap(m);
+            this.clusterInfo = m;
         }
 
         /**
@@ -548,9 +549,9 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
                     this.successful = successfulTemp;
                     this.skipped = skippedTemp;
                 } else {
-                    Map<String, AtomicReference<Cluster>> m = new HashMap<>();
-                    clusterList.forEach(c -> m.put(c.getClusterAlias(), new AtomicReference<>(c)));
-                    this.clusterInfo = Collections.unmodifiableMap(m);
+                    Map<String, Cluster> m = ConcurrentCollections.newConcurrentMap();
+                    clusterList.forEach(c -> m.put(c.getClusterAlias(), c));
+                    this.clusterInfo = m;
                     this.successful = getClusterStateCount(Cluster.Status.SUCCESSFUL);
                     this.skipped = getClusterStateCount(Cluster.Status.SKIPPED);
                 }
@@ -579,7 +580,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
                     + failed;
         }
 
-        private Clusters(Map<String, AtomicReference<Cluster>> clusterInfoMap) {
+        private Clusters(Map<String, Cluster> clusterInfoMap) {
             assert clusterInfoMap.size() > 0 : "this constructor should not be called with an empty Cluster info map";
             this.total = clusterInfoMap.size();
             this.clusterInfo = clusterInfoMap;
@@ -596,7 +597,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
             out.writeVInt(skipped);
             if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_500_053)) {
                 if (clusterInfo != null) {
-                    List<Cluster> clusterList = clusterInfo.values().stream().map(AtomicReference::get).toList();
+                    List<Cluster> clusterList = clusterInfo.values().stream().toList();
                     out.writeCollection(clusterList);
                 } else {
                     out.writeCollection(Collections.emptyList());
@@ -616,8 +617,8 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
                 builder.field(FAILED_FIELD.getPreferredName(), getClusterStateCount(Cluster.Status.FAILED));
                 if (clusterInfo.size() > 0) {
                     builder.startObject("details");
-                    for (AtomicReference<Cluster> cluster : clusterInfo.values()) {
-                        cluster.get().toXContent(builder, params);
+                    for (Cluster cluster : clusterInfo.values()) {
+                        cluster.toXContent(builder, params);
                     }
                     builder.endObject();
                 }
@@ -635,7 +636,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
             int running = 0;    // 0 for BWC
             int partial = 0;    // 0 for BWC
             int failed = 0;     // 0 for BWC
-            Map<String, AtomicReference<Cluster>> clusterInfoMap = new HashMap<>();
+            Map<String, Cluster> clusterInfoMap = ConcurrentCollections.newConcurrentMap();
             String currentFieldName = null;
             while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                 if (token == XContentParser.Token.FIELD_NAME) {
@@ -664,7 +665,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
                                 currentDetailsFieldName = parser.currentName();  // cluster alias
                             } else if (token == Token.START_OBJECT) {
                                 Cluster c = Cluster.fromXContent(currentDetailsFieldName, parser);
-                                clusterInfoMap.put(currentDetailsFieldName, new AtomicReference<>(c));
+                                clusterInfoMap.put(currentDetailsFieldName, c);
                             } else {
                                 parser.skipChildren();
                             }
@@ -716,7 +717,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
          * @return count of clusters matching the predicate
          */
         private int determineCountFromClusterInfo(Predicate<Cluster> predicate) {
-            return (int) clusterInfo.values().stream().filter(c -> predicate.test(c.get())).count();
+            return (int) clusterInfo.values().stream().filter(predicate).count();
         }
 
         /**
@@ -730,10 +731,33 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
          * @param clusterAlias The cluster alias as specified in the cluster collection
          * @return Cluster object associated with the clusterAlias or null if not present
          */
-        public AtomicReference<Cluster> getCluster(String clusterAlias) {
+        public Cluster getCluster(String clusterAlias) {
             return clusterInfo.get(clusterAlias);
         }
 
+        /**
+         * Utility to swap a Cluster object. Guidelines for the remapping function:
+         * <ul>
+         * <li> The remapping function should return a new Cluster object to swap it for
+         * the existing one.</li>
+         * <li> If in the remapping function you decide to abort the swap you must return
+         * the original Cluster object to keep the map unchanged.</li>
+         * <li> Do not return {@code null}. If the remapping function returns {@code null},
+         * the mapping is removed (or remains absent if initially absent).</li>
+         * <li> If the remapping function itself throws an (unchecked) exception, the exception
+         * is rethrown, and the current mapping is left unchanged. Throwing an exception is
+         * therefore OK, but it is generally discouraged.</li>
+         * <li> The remapping function may be called multiple times in a CAS fashion underneath,
+         * so make sure that it is safe to do so.</li>
+         * </ul>
+         * @param clusterAlias key with which the specified value is associated
+         * @param remappingFunction function to swap the oldCluster to a newCluster
+         * @return the new Cluster object
+         */
+        public Cluster swapCluster(String clusterAlias, BiFunction<String, Cluster, Cluster> remappingFunction) {
+            return clusterInfo.compute(clusterAlias, remappingFunction);
+        }
+
         @Override
         public boolean equals(Object o) {
             if (this == o) {
@@ -785,8 +809,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
          *              or any Cluster is marked as timedOut.
          */
         public boolean hasPartialResults() {
-            for (AtomicReference<Cluster> clusterRef : clusterInfo.values()) {
-                Cluster cluster = clusterRef.get();
+            for (Cluster cluster : clusterInfo.values()) {
                 switch (cluster.getStatus()) {
                     case PARTIAL, SKIPPED, FAILED, RUNNING -> {
                         return true;

+ 52 - 52
server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

@@ -512,7 +512,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 public void onResponse(SearchResponse searchResponse) {
                     // TODO: in CCS fail fast ticket we may need to fail the query if the cluster is marked as FAILED
                     // overwrite the existing cluster entry with the updated one
-                    ccsClusterInfoUpdate(searchResponse, clusters.getCluster(clusterAlias), skipUnavailable);
+                    ccsClusterInfoUpdate(searchResponse, clusters, clusterAlias, skipUnavailable);
                     Map<String, SearchProfileShardResult> profileResults = searchResponse.getProfileResults();
                     SearchProfileResults profile = profileResults == null || profileResults.isEmpty()
                         ? null
@@ -547,7 +547,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 public void onFailure(Exception e) {
                     ShardSearchFailure failure = new ShardSearchFailure(e);
                     logCCSError(failure, clusterAlias, skipUnavailable);
-                    ccsClusterInfoUpdate(failure, clusters.getCluster(clusterAlias), skipUnavailable);
+                    ccsClusterInfoUpdate(failure, clusters, clusterAlias, skipUnavailable);
                     if (skipUnavailable) {
                         listener.onResponse(SearchResponse.empty(timeProvider::buildTookInMillis, clusters));
                     } else {
@@ -669,13 +669,13 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     skipUnavailable,
                     responsesCountDown,
                     exceptions,
-                    clusters.getCluster(clusterAlias),
+                    clusters,
                     listener
                 ) {
                     @Override
                     void innerOnResponse(SearchShardsResponse searchShardsResponse) {
                         assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION);
-                        ccsClusterInfoUpdate(searchShardsResponse, cluster, timeProvider);
+                        ccsClusterInfoUpdate(searchShardsResponse, clusters, clusterAlias, timeProvider);
                         searchShardsResponses.put(clusterAlias, searchShardsResponse);
                     }
 
@@ -747,13 +747,13 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
             skipUnavailable,
             countDown,
             exceptions,
-            clusters.getCluster(clusterAlias),
+            clusters,
             originalListener
         ) {
             @Override
             void innerOnResponse(SearchResponse searchResponse) {
                 // TODO: in CCS fail fast ticket we may need to fail the query if the cluster gets marked as FAILED
-                ccsClusterInfoUpdate(searchResponse, cluster, skipUnavailable);
+                ccsClusterInfoUpdate(searchResponse, clusters, clusterAlias, skipUnavailable);
                 searchResponseMerger.add(searchResponse);
             }
 
@@ -766,27 +766,25 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
 
     /**
      * Creates a new Cluster object using the {@link ShardSearchFailure} info and skip_unavailable
-     * flag to set Status. The new Cluster object is swapped into the clusterRef {@link AtomicReference}.
+     * flag to set Status. The new Cluster object is then swapped into the clusters CHM under the key clusterAlias.
      */
     static void ccsClusterInfoUpdate(
         ShardSearchFailure failure,
-        AtomicReference<SearchResponse.Cluster> clusterRef,
+        SearchResponse.Clusters clusters,
+        String clusterAlias,
         boolean skipUnavailable
     ) {
-        SearchResponse.Cluster.Status status;
-        if (skipUnavailable) {
-            status = SearchResponse.Cluster.Status.SKIPPED;
-        } else {
-            status = SearchResponse.Cluster.Status.FAILED;
-        }
-        boolean swapped;
-        do {
-            SearchResponse.Cluster orig = clusterRef.get();
-            // returns unmodifiable list based on the original one passed plus the appended failure
-            List<ShardSearchFailure> failures = CollectionUtils.appendToCopy(orig.getFailures(), failure);
-            SearchResponse.Cluster updated = new SearchResponse.Cluster.Builder(orig).setStatus(status).setFailures(failures).build();
-            swapped = clusterRef.compareAndSet(orig, updated);
-        } while (swapped == false);
+        clusters.swapCluster(clusterAlias, (k, v) -> {
+            SearchResponse.Cluster.Status status;
+            if (skipUnavailable) {
+                status = SearchResponse.Cluster.Status.SKIPPED;
+            } else {
+                status = SearchResponse.Cluster.Status.FAILED;
+            }
+            return new SearchResponse.Cluster.Builder(v).setStatus(status)
+                .setFailures(CollectionUtils.appendToCopy(v.getFailures(), failure))
+                .build();
+        });
     }
 
     /**
@@ -794,11 +792,13 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
      * Used to update a specific SearchResponse.Cluster object state based upon
      * the SearchResponse coming from the cluster coordinator the search was performed on.
      * @param searchResponse SearchResponse from cluster sub-search
-     * @param clusterRef AtomicReference of the Cluster object to be updated
+     * @param clusters Clusters that the search was executed on
+     * @param clusterAlias Alias of the cluster to be updated
      */
     private static void ccsClusterInfoUpdate(
         SearchResponse searchResponse,
-        AtomicReference<SearchResponse.Cluster> clusterRef,
+        SearchResponse.Clusters clusters,
+        String clusterAlias,
         boolean skipUnavailable
     ) {
         /*
@@ -809,25 +809,22 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
          * 4) PARTIAL if it at least one of the shards succeeded but not all
          * 5) SUCCESSFUL if no shards failed (and did not time out)
          */
-        SearchResponse.Cluster.Status status;
-        if (searchResponse.getFailedShards() >= searchResponse.getTotalShards()) {
-            if (skipUnavailable) {
-                status = SearchResponse.Cluster.Status.SKIPPED;
+        clusters.swapCluster(clusterAlias, (k, v) -> {
+            SearchResponse.Cluster.Status status;
+            if (searchResponse.getFailedShards() >= searchResponse.getTotalShards()) {
+                if (skipUnavailable) {
+                    status = SearchResponse.Cluster.Status.SKIPPED;
+                } else {
+                    status = SearchResponse.Cluster.Status.FAILED;
+                }
+            } else if (searchResponse.isTimedOut()) {
+                status = SearchResponse.Cluster.Status.PARTIAL;
+            } else if (searchResponse.getFailedShards() > 0) {
+                status = SearchResponse.Cluster.Status.PARTIAL;
             } else {
-                status = SearchResponse.Cluster.Status.FAILED;
+                status = SearchResponse.Cluster.Status.SUCCESSFUL;
             }
-        } else if (searchResponse.isTimedOut()) {
-            status = SearchResponse.Cluster.Status.PARTIAL;
-        } else if (searchResponse.getFailedShards() > 0) {
-            status = SearchResponse.Cluster.Status.PARTIAL;
-        } else {
-            status = SearchResponse.Cluster.Status.SUCCESSFUL;
-        }
-
-        boolean swapped;
-        do {
-            SearchResponse.Cluster orig = clusterRef.get();
-            SearchResponse.Cluster updated = new SearchResponse.Cluster.Builder(orig).setStatus(status)
+            return new SearchResponse.Cluster.Builder(v).setStatus(status)
                 .setTotalShards(searchResponse.getTotalShards())
                 .setSuccessfulShards(searchResponse.getSuccessfulShards())
                 .setSkippedShards(searchResponse.getSkippedShards())
@@ -836,8 +833,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 .setTook(searchResponse.getTook())
                 .setTimedOut(searchResponse.isTimedOut())
                 .build();
-            swapped = clusterRef.compareAndSet(orig, updated);
-        } while (swapped == false);
+        });
     }
 
     /**
@@ -849,17 +845,20 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
      * the Cluster object to SUCCESSFUL status with shard counts of 0 and a filled in 'took' value.
      *
      * @param response from SearchShards API call to remote cluster
-     * @param clusterRef Reference Cluster to be updated
+     * @param clusters Clusters that the search was executed on
+     * @param clusterAlias Alias of the cluster to be updated
      * @param timeProvider search time provider (for setting took value)
      */
     private static void ccsClusterInfoUpdate(
         SearchShardsResponse response,
-        AtomicReference<SearchResponse.Cluster> clusterRef,
+        SearchResponse.Clusters clusters,
+        String clusterAlias,
         SearchTimeProvider timeProvider
     ) {
         if (response.getGroups().isEmpty()) {
-            clusterRef.updateAndGet(
-                orig -> new SearchResponse.Cluster.Builder(orig).setStatus(SearchResponse.Cluster.Status.SUCCESSFUL)
+            clusters.swapCluster(
+                clusterAlias,
+                (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SUCCESSFUL)
                     .setTotalShards(0)
                     .setSuccessfulShards(0)
                     .setSkippedShards(0)
@@ -1408,7 +1407,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         protected final boolean skipUnavailable;
         private final CountDown countDown;
         private final AtomicReference<Exception> exceptions;
-        protected final AtomicReference<SearchResponse.Cluster> cluster;
+        protected final SearchResponse.Clusters clusters;
         private final ActionListener<FinalResponse> originalListener;
         protected final long startTime;
 
@@ -1420,14 +1419,14 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
             boolean skipUnavailable,
             CountDown countDown,
             AtomicReference<Exception> exceptions,
-            @Nullable AtomicReference<SearchResponse.Cluster> cluster, // null for ccs_minimize_roundtrips=false
+            SearchResponse.Clusters clusters,
             ActionListener<FinalResponse> originalListener
         ) {
             this.clusterAlias = clusterAlias;
             this.skipUnavailable = skipUnavailable;
             this.countDown = countDown;
             this.exceptions = exceptions;
-            this.cluster = cluster;
+            this.clusters = clusters;
             this.originalListener = originalListener;
             this.startTime = System.currentTimeMillis();
         }
@@ -1444,14 +1443,15 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         public final void onFailure(Exception e) {
             ShardSearchFailure f = new ShardSearchFailure(e);
             logCCSError(f, clusterAlias, skipUnavailable);
+            SearchResponse.Cluster cluster = clusters.getCluster(clusterAlias);
             if (skipUnavailable) {
                 if (cluster != null) {
-                    ccsClusterInfoUpdate(f, cluster, skipUnavailable);
+                    ccsClusterInfoUpdate(f, clusters, clusterAlias, skipUnavailable);
                 }
                 // skippedClusters.incrementAndGet();
             } else {
                 if (cluster != null) {
-                    ccsClusterInfoUpdate(f, cluster, skipUnavailable);
+                    ccsClusterInfoUpdate(f, clusters, clusterAlias, skipUnavailable);
                 }
                 Exception exception = e;
                 if (RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY.equals(clusterAlias) == false) {

+ 15 - 16
server/src/test/java/org/elasticsearch/action/search/SearchResponseTests.java

@@ -49,7 +49,6 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.emptyList;
 import static java.util.Collections.singletonMap;
@@ -213,10 +212,11 @@ public class SearchResponseTests extends ESTestCase {
             int totalShards = 5;
             int successfulShards;
             int skippedShards;
-            int failedShards = 0;
+            int failedShards;
             List<ShardSearchFailure> failureList = Arrays.asList(failures);
             TimeValue took = new TimeValue(1000L);
             if (successful > 0) {
+                failedShards = 0;
                 status = SearchResponse.Cluster.Status.SUCCESSFUL;
                 successfulShards = 5;
                 skippedShards = 1;
@@ -241,28 +241,27 @@ public class SearchResponseTests extends ESTestCase {
                 failedShards = 5;
                 failed--;
             } else {
+                failedShards = 0;
                 throw new IllegalStateException("Test setup coding error - should not get here");
             }
             String clusterAlias = "";
             if (i >= 0) {
                 clusterAlias = "cluster_" + i;
             }
-            AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(clusterAlias);
-            SearchResponse.Cluster cluster = clusterRef.get();
-            SearchResponse.Cluster update = new SearchResponse.Cluster(
+            SearchResponse.Cluster cluster = clusters.getCluster(clusterAlias);
+            List<ShardSearchFailure> finalFailureList = failureList;
+            clusters.swapCluster(
                 cluster.getClusterAlias(),
-                cluster.getIndexExpression(),
-                false,
-                status,
-                totalShards,
-                successfulShards,
-                skippedShards,
-                failedShards,
-                failureList,
-                took,
-                false
+                (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(status)
+                    .setTotalShards(totalShards)
+                    .setSuccessfulShards(successfulShards)
+                    .setSkippedShards(skippedShards)
+                    .setFailedShards(failedShards)
+                    .setFailures(finalFailureList)
+                    .setTook(took)
+                    .setTimedOut(false)
+                    .build()
             );
-            assertTrue(clusterRef.compareAndSet(cluster, update));
         }
         return clusters;
     }

+ 38 - 38
x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java

@@ -167,11 +167,11 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
         }
@@ -193,7 +193,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(0));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(localClusterSearchInfo.getTotalShards(), equalTo(localNumShards));
@@ -203,7 +203,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(localClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(localClusterSearchInfo.getTook().millis(), greaterThan(0L));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -228,7 +228,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(0));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(localClusterSearchInfo.getTotalShards(), equalTo(localNumShards));
@@ -238,7 +238,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(localClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(localClusterSearchInfo.getTook().millis(), greaterThan(0L));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -286,11 +286,11 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
         }
@@ -313,9 +313,9 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(0));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
 
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
@@ -357,9 +357,9 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(0));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
 
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
@@ -436,12 +436,12 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
                 assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(2));
             }
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.FAILED));
             assertAllShardsFailed(minimizeRoundtrips, localClusterSearchInfo, localNumShards);
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
                 ? SearchResponse.Cluster.Status.SKIPPED
@@ -467,12 +467,12 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
                 assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(2));
             }
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.FAILED));
             assertAllShardsFailed(minimizeRoundtrips, localClusterSearchInfo, localNumShards);
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
                 ? SearchResponse.Cluster.Status.SKIPPED
@@ -515,11 +515,11 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
         }
@@ -541,7 +541,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(2));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
             assertThat(localClusterSearchInfo.getTotalShards(), equalTo(localNumShards));
@@ -553,7 +553,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             ShardSearchFailure localShardSearchFailure = localClusterSearchInfo.getFailures().get(0);
             assertTrue("should have 'index corrupted' in reason", localShardSearchFailure.reason().contains("index corrupted"));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -578,7 +578,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(2));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
             assertThat(localClusterSearchInfo.getTotalShards(), equalTo(localNumShards));
@@ -590,7 +590,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             ShardSearchFailure localShardSearchFailure = localClusterSearchInfo.getFailures().get(0);
             assertTrue("should have 'index corrupted' in reason", localShardSearchFailure.reason().contains("index corrupted"));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -644,11 +644,11 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getTotal(), equalTo(2));
             assertTrue("search cluster results should be marked as partial", clusters.hasPartialResults());
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.RUNNING));
         }
@@ -675,7 +675,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
                 assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(1));
             }
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(localClusterSearchInfo.getTotalShards(), equalTo(localNumShards));
@@ -685,7 +685,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(localClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(localClusterSearchInfo.getTook().millis(), greaterThan(0L));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
                 ? SearchResponse.Cluster.Status.SKIPPED
@@ -727,7 +727,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
                 assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(1));
             }
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(localClusterSearchInfo.getTotalShards(), equalTo(localNumShards));
@@ -737,7 +737,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(localClusterSearchInfo.getFailures().size(), equalTo(0));
             assertThat(localClusterSearchInfo.getTook().millis(), greaterThan(0L));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
                 ? SearchResponse.Cluster.Status.SKIPPED
@@ -803,7 +803,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(2));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             // PARTIAL expected since timedOut=true
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
@@ -815,7 +815,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(localClusterSearchInfo.getTook().millis(), greaterThanOrEqualTo(0L));
             assertTrue(localClusterSearchInfo.isTimedOut());
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             // PARTIAL expected since timedOut=true
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
@@ -840,7 +840,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.PARTIAL), equalTo(2));
             assertThat(clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED), equalTo(0));
 
-            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY).get();
+            SearchResponse.Cluster localClusterSearchInfo = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
             assertNotNull(localClusterSearchInfo);
             // PARTIAL expected since timedOut=true
             assertThat(localClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
@@ -852,7 +852,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             assertThat(localClusterSearchInfo.getTook().millis(), greaterThanOrEqualTo(0L));
             assertTrue(localClusterSearchInfo.isTimedOut());
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             // PARTIAL expected since timedOut=true
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
@@ -909,7 +909,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
             assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -936,7 +936,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
             assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.SUCCESSFUL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -991,7 +991,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
             assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -1018,7 +1018,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
             assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             assertThat(remoteClusterSearchInfo.getStatus(), equalTo(SearchResponse.Cluster.Status.PARTIAL));
             assertThat(remoteClusterSearchInfo.getTotalShards(), equalTo(remoteNumShards));
@@ -1083,7 +1083,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
 
             assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
                 ? SearchResponse.Cluster.Status.SKIPPED
@@ -1110,7 +1110,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase {
             }
             assertNull(clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY));
 
-            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER).get();
+            SearchResponse.Cluster remoteClusterSearchInfo = clusters.getCluster(REMOTE_CLUSTER);
             assertNotNull(remoteClusterSearchInfo);
             SearchResponse.Cluster.Status expectedStatus = skipUnavailable
                 ? SearchResponse.Cluster.Status.SKIPPED

+ 111 - 136
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchResponseTests.java

@@ -44,7 +44,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
-import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.emptyList;
 import static org.elasticsearch.xpack.core.async.GetAsyncResultRequestTests.randomSearchId;
@@ -466,44 +465,36 @@ public class AsyncSearchResponseTests extends ESTestCase {
         SearchResponseSections sections = new SearchResponseSections(hits, null, null, true, null, null, 2);
         SearchResponse.Clusters clusters = createCCSClusterObjects(4, 3, true);
 
-        AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
-        SearchResponse.Cluster localCluster = clusterRef.get();
-        SearchResponse.Cluster updated = new SearchResponse.Cluster(
-            localCluster.getClusterAlias(),
-            localCluster.getIndexExpression(),
-            false,
-            SearchResponse.Cluster.Status.SUCCESSFUL,
-            10,
-            10,
-            3,
-            0,
-            Collections.emptyList(),
-            new TimeValue(11111),
-            false
+        SearchResponse.Cluster updated = clusters.swapCluster(
+            RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
+            (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SUCCESSFUL)
+                .setTotalShards(10)
+                .setSuccessfulShards(10)
+                .setSkippedShards(3)
+                .setFailedShards(0)
+                .setFailures(Collections.emptyList())
+                .setTook(new TimeValue(11111))
+                .setTimedOut(false)
+                .build()
         );
-        boolean swapped = clusterRef.compareAndSet(localCluster, updated);
-        assertTrue("CAS swap failed for cluster " + updated, swapped);
+        assertNotNull("Set cluster failed for cluster " + RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, updated);
 
-        clusterRef = clusters.getCluster("cluster_0");
-        SearchResponse.Cluster cluster0 = clusterRef.get();
-        updated = new SearchResponse.Cluster(
+        SearchResponse.Cluster cluster0 = clusters.getCluster("cluster_0");
+        updated = clusters.swapCluster(
             cluster0.getClusterAlias(),
-            cluster0.getIndexExpression(),
-            false,
-            SearchResponse.Cluster.Status.SUCCESSFUL,
-            8,
-            8,
-            1,
-            0,
-            Collections.emptyList(),
-            new TimeValue(7777),
-            false
+            (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SUCCESSFUL)
+                .setTotalShards(8)
+                .setSuccessfulShards(8)
+                .setSkippedShards(1)
+                .setFailedShards(0)
+                .setFailures(Collections.emptyList())
+                .setTook(new TimeValue(7777))
+                .setTimedOut(false)
+                .build()
         );
-        swapped = clusterRef.compareAndSet(cluster0, updated);
-        assertTrue("CAS swap failed for cluster " + updated, swapped);
+        assertNotNull("Set cluster failed for cluster " + cluster0.getClusterAlias(), updated);
 
-        clusterRef = clusters.getCluster("cluster_1");
-        SearchResponse.Cluster cluster1 = clusterRef.get();
+        SearchResponse.Cluster cluster1 = clusters.getCluster("cluster_1");
         ShardSearchFailure failure1 = new ShardSearchFailure(
             new NullPointerException("NPE details"),
             new SearchShardTarget("nodeId0", new ShardId("foo", UUID.randomUUID().toString(), 0), "cluster_1")
@@ -512,39 +503,34 @@ public class AsyncSearchResponseTests extends ESTestCase {
             new CorruptIndexException("abc", "123"),
             new SearchShardTarget("nodeId0", new ShardId("bar1", UUID.randomUUID().toString(), 0), "cluster_1")
         );
-        updated = new SearchResponse.Cluster(
+        updated = clusters.swapCluster(
             cluster1.getClusterAlias(),
-            cluster1.getIndexExpression(),
-            false,
-            SearchResponse.Cluster.Status.SKIPPED,
-            2,
-            0,
-            0,
-            2,
-            List.of(failure1, failure2),
-            null,
-            false
+            (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SKIPPED)
+                .setTotalShards(2)
+                .setSuccessfulShards(0)
+                .setSkippedShards(0)
+                .setFailedShards(2)
+                .setFailures(List.of(failure1, failure2))
+                .setTook(null)
+                .setTimedOut(false)
+                .build()
         );
-        swapped = clusterRef.compareAndSet(cluster1, updated);
-        assertTrue("CAS swap failed for cluster " + updated, swapped);
+        assertNotNull("Set cluster failed for cluster " + cluster1.getClusterAlias(), updated);
 
-        clusterRef = clusters.getCluster("cluster_2");
-        SearchResponse.Cluster cluster2 = clusterRef.get();
-        updated = new SearchResponse.Cluster(
+        SearchResponse.Cluster cluster2 = clusters.getCluster("cluster_2");
+        updated = clusters.swapCluster(
             cluster2.getClusterAlias(),
-            cluster2.getIndexExpression(),
-            false,
-            SearchResponse.Cluster.Status.PARTIAL,
-            8,
-            8,
-            0,
-            0,
-            Collections.emptyList(),
-            new TimeValue(17322),
-            true
+            (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.PARTIAL)
+                .setTotalShards(8)
+                .setSuccessfulShards(8)
+                .setSkippedShards(0)
+                .setFailedShards(0)
+                .setFailures(Collections.emptyList())
+                .setTook(new TimeValue(17322))
+                .setTimedOut(true)
+                .build()
         );
-        swapped = clusterRef.compareAndSet(cluster2, updated);
-        assertTrue("CAS swap failed for cluster " + updated, swapped);
+        assertNotNull("Set cluster failed for cluster " + cluster2.getClusterAlias(), updated);
 
         SearchResponse searchResponse = new SearchResponse(sections, null, 10, 9, 1, took, new ShardSearchFailure[0], clusters);
 
@@ -800,114 +786,103 @@ public class AsyncSearchResponseTests extends ESTestCase {
         int partial = partialClusters;
         if (totalClusters > remoteClusters) {
             String localAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
-            AtomicReference<SearchResponse.Cluster> localRef = clusters.getCluster(localAlias);
-            SearchResponse.Cluster orig = localRef.get();
             SearchResponse.Cluster updated;
             if (successful > 0) {
-                updated = new SearchResponse.Cluster(
+                updated = clusters.swapCluster(
                     localAlias,
-                    localRef.get().getIndexExpression(),
-                    false,
-                    SearchResponse.Cluster.Status.SUCCESSFUL,
-                    5,
-                    5,
-                    0,
-                    0,
-                    Collections.emptyList(),
-                    new TimeValue(1000),
-                    false
+                    (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SUCCESSFUL)
+                        .setTotalShards(5)
+                        .setSuccessfulShards(5)
+                        .setSkippedShards(0)
+                        .setFailedShards(0)
+                        .setFailures(Collections.emptyList())
+                        .setTook(new TimeValue(1000))
+                        .setTimedOut(false)
+                        .build()
                 );
                 successful--;
             } else if (skipped > 0) {
-                updated = new SearchResponse.Cluster(
+                updated = clusters.swapCluster(
                     localAlias,
-                    localRef.get().getIndexExpression(),
-                    false,
-                    SearchResponse.Cluster.Status.SKIPPED,
-                    5,
-                    0,
-                    0,
-                    5,
-                    Collections.emptyList(),
-                    new TimeValue(1000),
-                    false
+                    (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SKIPPED)
+                        .setTotalShards(5)
+                        .setSuccessfulShards(0)
+                        .setSkippedShards(0)
+                        .setFailedShards(5)
+                        .setFailures(Collections.emptyList())
+                        .setTook(new TimeValue(1000))
+                        .setTimedOut(false)
+                        .build()
                 );
                 skipped--;
             } else {
-                updated = new SearchResponse.Cluster(
+                updated = clusters.swapCluster(
                     localAlias,
-                    localRef.get().getIndexExpression(),
-                    false,
-                    SearchResponse.Cluster.Status.PARTIAL,
-                    5,
-                    2,
-                    1,
-                    3,
-                    Collections.emptyList(),
-                    new TimeValue(1000),
-                    false
+                    (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.PARTIAL)
+                        .setTotalShards(5)
+                        .setSuccessfulShards(2)
+                        .setSkippedShards(1)
+                        .setFailedShards(3)
+                        .setFailures(Collections.emptyList())
+                        .setTook(new TimeValue(1000))
+                        .setTimedOut(false)
+                        .build()
                 );
                 partial--;
             }
-            boolean swapped = localRef.compareAndSet(orig, updated);
-            assertTrue("CAS swap failed for cluster " + updated, swapped);
+            assertNotNull("Set cluster failed for cluster " + localAlias, updated);
         }
 
         int numClusters = successful + skipped + partial;
 
         for (int i = 0; i < numClusters; i++) {
             String clusterAlias = "cluster_" + i;
-            AtomicReference<SearchResponse.Cluster> clusterRef = clusters.getCluster(clusterAlias);
-            SearchResponse.Cluster orig = clusterRef.get();
+            SearchResponse.Cluster remote = clusters.getCluster(clusterAlias);
             SearchResponse.Cluster updated;
             if (successful > 0) {
-                updated = new SearchResponse.Cluster(
+                updated = clusters.swapCluster(
                     clusterAlias,
-                    clusterRef.get().getIndexExpression(),
-                    false,
-                    SearchResponse.Cluster.Status.SUCCESSFUL,
-                    5,
-                    5,
-                    0,
-                    0,
-                    Collections.emptyList(),
-                    new TimeValue(1000),
-                    false
+                    (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SUCCESSFUL)
+                        .setTotalShards(5)
+                        .setSuccessfulShards(5)
+                        .setSkippedShards(0)
+                        .setFailedShards(0)
+                        .setFailures(Collections.emptyList())
+                        .setTook(new TimeValue(1000))
+                        .setTimedOut(false)
+                        .build()
                 );
                 successful--;
             } else if (skipped > 0) {
-                updated = new SearchResponse.Cluster(
+                updated = clusters.swapCluster(
                     clusterAlias,
-                    clusterRef.get().getIndexExpression(),
-                    false,
-                    SearchResponse.Cluster.Status.SKIPPED,
-                    5,
-                    0,
-                    0,
-                    5,
-                    Collections.emptyList(),
-                    new TimeValue(1000),
-                    false
+                    (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.SKIPPED)
+                        .setTotalShards(5)
+                        .setSuccessfulShards(0)
+                        .setSkippedShards(0)
+                        .setFailedShards(5)
+                        .setFailures(Collections.emptyList())
+                        .setTook(new TimeValue(1000))
+                        .setTimedOut(false)
+                        .build()
                 );
                 skipped--;
             } else {
-                updated = new SearchResponse.Cluster(
+                updated = clusters.swapCluster(
                     clusterAlias,
-                    clusterRef.get().getIndexExpression(),
-                    false,
-                    SearchResponse.Cluster.Status.PARTIAL,
-                    5,
-                    2,
-                    1,
-                    3,
-                    Collections.emptyList(),
-                    new TimeValue(1000),
-                    false
+                    (k, v) -> new SearchResponse.Cluster.Builder(v).setStatus(SearchResponse.Cluster.Status.PARTIAL)
+                        .setTotalShards(5)
+                        .setSuccessfulShards(2)
+                        .setSkippedShards(1)
+                        .setFailedShards(3)
+                        .setFailures(Collections.emptyList())
+                        .setTook(new TimeValue(1000))
+                        .setTimedOut(false)
+                        .build()
                 );
                 partial--;
             }
-            boolean swapped = clusterRef.compareAndSet(orig, updated);
-            assertTrue("CAS swap failed for cluster " + updated, swapped);
+            assertNotNull("Set cluster failed for cluster " + clusterAlias, updated);
         }
         return clusters;
     }