Browse Source

Integrate CCS with new search_shards API (#95894)

This PR integrates CCS with the new search_shards API. With this change, 
we will be able to skip shards on the coordinator on remote clusters
using the timestamps stored in the cluster state.

Relates #94534
Closes #93730
Nhat Nguyen 2 years ago
parent
commit
7da55d00b4
19 changed files with 738 additions and 305 deletions
  1. 6 0
      docs/changelog/95894.yaml
  2. 6 13
      qa/ccs-unavailable-clusters/src/javaRestTest/java/org/elasticsearch/search/CrossClusterSearchUnavailableClusterIT.java
  3. 193 0
      server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSCanMatchIT.java
  4. 14 4
      server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java
  5. 35 12
      server/src/main/java/org/elasticsearch/action/search/SearchShardIterator.java
  6. 33 2
      server/src/main/java/org/elasticsearch/action/search/SearchShardsGroup.java
  7. 30 1
      server/src/main/java/org/elasticsearch/action/search/SearchShardsResponse.java
  8. 114 75
      server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java
  9. 31 0
      server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java
  10. 2 1
      server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java
  11. 3 2
      server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java
  12. 15 5
      server/src/test/java/org/elasticsearch/action/search/SearchShardIteratorTests.java
  13. 64 0
      server/src/test/java/org/elasticsearch/action/search/SearchShardsResponseTests.java
  14. 133 116
      server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
  15. 38 48
      server/src/test/java/org/elasticsearch/transport/RemoteClusterAwareClientTests.java
  16. 6 13
      server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java
  17. 5 1
      test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java
  18. 2 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java
  19. 8 12
      x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/crossclusteraccess/CrossClusterAccessHeadersForCcsRestIT.java

+ 6 - 0
docs/changelog/95894.yaml

@@ -0,0 +1,6 @@
+pr: 95894
+summary: Integrate CCS with new `search_shards` API
+area: Search
+type: enhancement
+issues:
+ - 93730

+ 6 - 13
qa/ccs-unavailable-clusters/src/javaRestTest/java/org/elasticsearch/search/CrossClusterSearchUnavailableClusterIT.java

@@ -15,10 +15,6 @@ import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.Version;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsAction;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
@@ -27,6 +23,9 @@ import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.search.SearchScrollRequest;
+import org.elasticsearch.action.search.SearchShardsAction;
+import org.elasticsearch.action.search.SearchShardsRequest;
+import org.elasticsearch.action.search.SearchShardsResponse;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.RequestOptions;
@@ -103,17 +102,11 @@ public class CrossClusterSearchUnavailableClusterIT extends ESRestTestCase {
         MockTransportService newService = MockTransportService.createNewService(s, version, transportVersion, threadPool, null);
         try {
             newService.registerRequestHandler(
-                ClusterSearchShardsAction.NAME,
+                SearchShardsAction.NAME,
                 ThreadPool.Names.SAME,
-                ClusterSearchShardsRequest::new,
+                SearchShardsRequest::new,
                 (request, channel, task) -> {
-                    channel.sendResponse(
-                        new ClusterSearchShardsResponse(
-                            new ClusterSearchShardsGroup[0],
-                            knownNodes.toArray(new DiscoveryNode[0]),
-                            Collections.emptyMap()
-                        )
-                    );
+                    channel.sendResponse(new SearchShardsResponse(List.of(), List.of(), Collections.emptyMap()));
                 }
             );
             newService.registerRequestHandler(SearchAction.NAME, ThreadPool.Names.SAME, SearchRequest::new, (request, channel, task) -> {

+ 193 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSCanMatchIT.java

@@ -0,0 +1,193 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.ccs;
+
+import org.apache.lucene.document.LongPoint;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.PointValues;
+import org.elasticsearch.action.search.CanMatchNodeRequest;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchTransportService;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.CollectionUtils;
+import org.elasticsearch.index.IndexSettings;
+import org.elasticsearch.index.engine.EngineConfig;
+import org.elasticsearch.index.engine.EngineFactory;
+import org.elasticsearch.index.engine.InternalEngine;
+import org.elasticsearch.index.engine.InternalEngineFactory;
+import org.elasticsearch.index.query.RangeQueryBuilder;
+import org.elasticsearch.index.shard.IndexLongFieldRange;
+import org.elasticsearch.index.shard.ShardLongFieldRange;
+import org.elasticsearch.plugins.EnginePlugin;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.test.AbstractMultiClustersTestCase;
+import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
+import org.elasticsearch.test.transport.MockTransportService;
+import org.elasticsearch.transport.TransportService;
+import org.hamcrest.Matchers;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Optional;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.in;
+
+public class CCSCanMatchIT extends AbstractMultiClustersTestCase {
+    static final String REMOTE_CLUSTER = "cluster_a";
+
+    @Override
+    protected Collection<String> remoteClusterAlias() {
+        return List.of("cluster_a");
+    }
+
+    private static class EngineWithExposingTimestamp extends InternalEngine {
+        EngineWithExposingTimestamp(EngineConfig engineConfig) {
+            super(engineConfig);
+            assert IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.get(config().getIndexSettings().getSettings()) : "require read-only index";
+        }
+
+        @Override
+        public ShardLongFieldRange getRawFieldRange(String field) {
+            try (Searcher searcher = acquireSearcher("test")) {
+                final DirectoryReader directoryReader = searcher.getDirectoryReader();
+
+                final byte[] minPackedValue = PointValues.getMinPackedValue(directoryReader, field);
+                final byte[] maxPackedValue = PointValues.getMaxPackedValue(directoryReader, field);
+                if (minPackedValue == null || maxPackedValue == null) {
+                    assert minPackedValue == null && maxPackedValue == null
+                        : Arrays.toString(minPackedValue) + "-" + Arrays.toString(maxPackedValue);
+                    return ShardLongFieldRange.EMPTY;
+                }
+
+                return ShardLongFieldRange.of(LongPoint.decodeDimension(minPackedValue, 0), LongPoint.decodeDimension(maxPackedValue, 0));
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+        }
+    }
+
+    public static class ExposingTimestampEnginePlugin extends Plugin implements EnginePlugin {
+
+        @Override
+        public Optional<EngineFactory> getEngineFactory(IndexSettings indexSettings) {
+            if (IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.get(indexSettings.getSettings())) {
+                return Optional.of(EngineWithExposingTimestamp::new);
+            } else {
+                return Optional.of(new InternalEngineFactory());
+            }
+        }
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins(String clusterAlias) {
+        return CollectionUtils.appendToCopy(super.nodePlugins(clusterAlias), ExposingTimestampEnginePlugin.class);
+    }
+
+    int createIndexAndIndexDocs(String cluster, String index, int numberOfShards, long timestamp, boolean exposeTimestamp)
+        throws Exception {
+        Client client = client(cluster);
+        ElasticsearchAssertions.assertAcked(
+            client.admin()
+                .indices()
+                .prepareCreate(index)
+                .setSettings(
+                    Settings.builder()
+                        .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
+                        .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+                )
+                .setMapping("@timestamp", "type=date", "position", "type=long")
+        );
+        int numDocs = between(100, 500);
+        for (int i = 0; i < numDocs; i++) {
+            client.prepareIndex(index).setSource("position", i, "@timestamp", timestamp + i).get();
+        }
+        if (exposeTimestamp) {
+            client.admin().indices().prepareClose(index).get();
+            client.admin()
+                .indices()
+                .prepareUpdateSettings(index)
+                .setSettings(Settings.builder().put(IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.getKey(), true).build())
+                .get();
+            client.admin().indices().prepareOpen(index).get();
+            assertBusy(() -> {
+                IndexLongFieldRange timestampRange = cluster(cluster).clusterService().state().metadata().index(index).getTimestampRange();
+                assertTrue(Strings.toString(timestampRange), timestampRange.containsAllShardRanges());
+            });
+        } else {
+            client.admin().indices().prepareRefresh(index).get();
+        }
+        return numDocs;
+    }
+
+    public void testCanMatchOnTimeRange() throws Exception {
+        long timestamp = randomLongBetween(10_000_000, 50_000_000);
+        int oldLocalNumShards = randomIntBetween(1, 5);
+        createIndexAndIndexDocs(LOCAL_CLUSTER, "local_old_index", oldLocalNumShards, timestamp - 10_000, true);
+        int oldRemoteNumShards = randomIntBetween(1, 5);
+        createIndexAndIndexDocs(REMOTE_CLUSTER, "remote_old_index", oldRemoteNumShards, timestamp - 10_000, true);
+
+        int newLocalNumShards = randomIntBetween(1, 5);
+        int localDocs = createIndexAndIndexDocs(LOCAL_CLUSTER, "local_new_index", newLocalNumShards, timestamp, randomBoolean());
+        int newRemoteNumShards = randomIntBetween(1, 5);
+        int remoteDocs = createIndexAndIndexDocs(REMOTE_CLUSTER, "remote_new_index", newRemoteNumShards, timestamp, randomBoolean());
+
+        for (String cluster : List.of(LOCAL_CLUSTER, REMOTE_CLUSTER)) {
+            for (TransportService ts : cluster(cluster).getInstances(TransportService.class)) {
+                MockTransportService mockTransportService = (MockTransportService) ts;
+                mockTransportService.addSendBehavior((connection, requestId, action, request, options) -> {
+                    if (action.equals(SearchTransportService.QUERY_CAN_MATCH_NODE_NAME)) {
+                        CanMatchNodeRequest canMatchNodeRequest = (CanMatchNodeRequest) request;
+                        List<String> indices = canMatchNodeRequest.getShardLevelRequests()
+                            .stream()
+                            .map(r -> r.shardId().getIndexName())
+                            .toList();
+                        assertThat("old indices should be prefiltered on coordinator node", "local_old_index", Matchers.not(in(indices)));
+                        assertThat("old indices should be prefiltered on coordinator node", "remote_old_index", Matchers.not(in(indices)));
+                        if (cluster.equals(LOCAL_CLUSTER)) {
+                            DiscoveryNode targetNode = connection.getNode();
+                            DiscoveryNodes remoteNodes = cluster(REMOTE_CLUSTER).clusterService().state().nodes();
+                            assertNull("No can_match requests sent across clusters", remoteNodes.get(targetNode.getId()));
+                        }
+                    }
+                    connection.sendRequest(requestId, action, request, options);
+                });
+            }
+        }
+        try {
+            for (boolean minimizeRoundTrips : List.of(true, false)) {
+                SearchSourceBuilder source = new SearchSourceBuilder().query(new RangeQueryBuilder("@timestamp").from(timestamp));
+                SearchRequest request = new SearchRequest("local_*", "*:remote_*");
+                request.source(source).setCcsMinimizeRoundtrips(minimizeRoundTrips);
+                SearchResponse searchResp = client().search(request).actionGet();
+                ElasticsearchAssertions.assertHitCount(searchResp, localDocs + remoteDocs);
+                int totalShards = oldLocalNumShards + newLocalNumShards + oldRemoteNumShards + newRemoteNumShards;
+                assertThat(searchResp.getTotalShards(), equalTo(totalShards));
+                assertThat(searchResp.getSkippedShards(), equalTo(oldLocalNumShards + oldRemoteNumShards));
+            }
+        } finally {
+            for (String cluster : List.of(LOCAL_CLUSTER, REMOTE_CLUSTER)) {
+                for (TransportService ts : cluster(cluster).getInstances(TransportService.class)) {
+                    MockTransportService mockTransportService = (MockTransportService) ts;
+                    mockTransportService.clearAllRules();
+                }
+            }
+        }
+    }
+}

+ 14 - 4
server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java

@@ -158,6 +158,12 @@ final class CanMatchPreFilterSearchPhase extends SearchPhase {
                 searchShardIterator.getClusterAlias()
             );
             final ShardSearchRequest request = canMatchNodeRequest.createShardSearchRequest(buildShardLevelRequest(searchShardIterator));
+            if (searchShardIterator.prefiltered()) {
+                CanMatchShardResponse result = new CanMatchShardResponse(searchShardIterator.skip() == false, null);
+                result.setShardIndex(request.shardRequestIndex());
+                results.consumeResult(result, () -> {});
+                continue;
+            }
             boolean canMatch = true;
             CoordinatorRewriteContext coordinatorRewriteContext = coordinatorRewriteContextProvider.getCoordinatorRewriteContext(
                 request.shardId().getIndex()
@@ -510,8 +516,10 @@ final class CanMatchPreFilterSearchPhase extends SearchPhase {
             // shards available in order to produce a valid search result.
             int shardIndexToQuery = 0;
             for (int i = 0; i < shardsIts.size(); i++) {
-                if (shardsIts.get(i).size() > 0) {
+                SearchShardIterator it = shardsIts.get(i);
+                if (it.size() > 0) {
                     shardIndexToQuery = i;
+                    it.skip(false); // un-skip which is needed when all the remote shards were skipped by the remote can_match
                     break;
                 }
             }
@@ -520,10 +528,12 @@ final class CanMatchPreFilterSearchPhase extends SearchPhase {
         SearchSourceBuilder source = request.source();
         int i = 0;
         for (SearchShardIterator iter : shardsIts) {
-            if (possibleMatches.get(i++)) {
-                iter.reset();
+            iter.reset();
+            boolean match = possibleMatches.get(i++);
+            if (match) {
+                assert iter.skip() == false;
             } else {
-                iter.resetAndSkip();
+                iter.skip(true);
             }
         }
         if (shouldSortShards(results.minAndMaxes) == false) {

+ 35 - 12
server/src/main/java/org/elasticsearch/action/search/SearchShardIterator.java

@@ -34,7 +34,8 @@ public final class SearchShardIterator implements Comparable<SearchShardIterator
     private final OriginalIndices originalIndices;
     private final String clusterAlias;
     private final ShardId shardId;
-    private boolean skip = false;
+    private boolean skip;
+    private final boolean prefiltered;
 
     private final ShardSearchContextId searchContextId;
     private final TimeValue searchContextKeepAlive;
@@ -50,16 +51,30 @@ public final class SearchShardIterator implements Comparable<SearchShardIterator
      * @param originalIndices the indices that the search request originally related to (before any rewriting happened)
      */
     public SearchShardIterator(@Nullable String clusterAlias, ShardId shardId, List<ShardRouting> shards, OriginalIndices originalIndices) {
-        this(clusterAlias, shardId, shards.stream().map(ShardRouting::currentNodeId).toList(), originalIndices, null, null);
+        this(clusterAlias, shardId, shards.stream().map(ShardRouting::currentNodeId).toList(), originalIndices, null, null, false, false);
     }
 
+    /**
+     * Creates a {@link PlainShardIterator} instance that iterates over a subset of the given shards
+     *
+     * @param clusterAlias           the alias of the cluster where the shard is located
+     * @param shardId                shard id of the group
+     * @param targetNodeIds          the list of nodes hosting shard copies
+     * @param originalIndices        the indices that the search request originally related to (before any rewriting happened)
+     * @param searchContextId        the point-in-time specified for this group if exists
+     * @param searchContextKeepAlive the time interval that data nodes should extend the keep alive of the point-in-time
+     * @param prefiltered            if true, then this group already executed the can_match phase
+     * @param skip                   if true, then this group won't have matches, and it can be safely skipped from the search
+     */
     public SearchShardIterator(
         @Nullable String clusterAlias,
         ShardId shardId,
         List<String> targetNodeIds,
         OriginalIndices originalIndices,
         ShardSearchContextId searchContextId,
-        TimeValue searchContextKeepAlive
+        TimeValue searchContextKeepAlive,
+        boolean prefiltered,
+        boolean skip
     ) {
         this.shardId = shardId;
         this.targetNodesIterator = new PlainIterator<>(targetNodeIds);
@@ -68,6 +83,9 @@ public final class SearchShardIterator implements Comparable<SearchShardIterator
         this.searchContextId = searchContextId;
         this.searchContextKeepAlive = searchContextKeepAlive;
         assert searchContextKeepAlive == null || searchContextId != null;
+        this.prefiltered = prefiltered;
+        this.skip = skip;
+        assert skip == false || prefiltered : "only prefiltered shards are skip-able";
     }
 
     /**
@@ -112,15 +130,6 @@ public final class SearchShardIterator implements Comparable<SearchShardIterator
         return targetNodesIterator.asList();
     }
 
-    /**
-     * Reset the iterator and mark it as skippable
-     * @see #skip()
-     */
-    void resetAndSkip() {
-        reset();
-        skip = true;
-    }
-
     void reset() {
         targetNodesIterator.reset();
     }
@@ -132,6 +141,20 @@ public final class SearchShardIterator implements Comparable<SearchShardIterator
         return skip;
     }
 
+    /**
+     * Specifies if the search execution should skip this shard copies
+     */
+    void skip(boolean skip) {
+        this.skip = skip;
+    }
+
+    /**
+     * Returns {@code true} if this iterator was applied pre-filtered
+     */
+    boolean prefiltered() {
+        return prefiltered;
+    }
+
     @Override
     public int size() {
         return targetNodesIterator.size();

+ 33 - 2
server/src/main/java/org/elasticsearch/action/search/SearchShardsGroup.java

@@ -8,12 +8,15 @@
 
 package org.elasticsearch.action.search;
 
+import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
+import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.index.shard.ShardId;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
@@ -25,21 +28,38 @@ public class SearchShardsGroup implements Writeable {
     private final ShardId shardId;
     private final List<String> allocatedNodes;
     private final boolean skipped;
+    private final transient boolean preFiltered;
 
     public SearchShardsGroup(ShardId shardId, List<String> allocatedNodes, boolean skipped) {
         this.shardId = shardId;
         this.allocatedNodes = allocatedNodes;
         this.skipped = skipped;
+        this.preFiltered = true;
+    }
+
+    /**
+     * Create a new response from a legacy response from the cluster_search_shards API
+     */
+    SearchShardsGroup(ClusterSearchShardsGroup oldGroup) {
+        this.shardId = oldGroup.getShardId();
+        this.allocatedNodes = Arrays.stream(oldGroup.getShards()).map(ShardRouting::currentNodeId).toList();
+        this.skipped = false;
+        this.preFiltered = false;
     }
 
     public SearchShardsGroup(StreamInput in) throws IOException {
         this.shardId = new ShardId(in);
         this.allocatedNodes = in.readStringList();
         this.skipped = in.readBoolean();
+        this.preFiltered = true;
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        if (preFiltered == false) {
+            assert false : "Serializing a response created from a legacy response is not allowed";
+            throw new IllegalStateException("Serializing a response created from a legacy response is not allowed");
+        }
         shardId.writeTo(out);
         out.writeStringCollection(allocatedNodes);
         out.writeBoolean(skipped);
@@ -56,6 +76,14 @@ public class SearchShardsGroup implements Writeable {
         return skipped;
     }
 
+    /**
+     * Returns true if the can_match was performed against this group. This flag is for BWC purpose. It's always
+     * true for a response from the new search_shards API; but always false for a response from the old API.
+     */
+    boolean preFiltered() {
+        return preFiltered;
+    }
+
     /**
      * The list of node ids that shard copies on this group are allocated on.
      */
@@ -68,11 +96,14 @@ public class SearchShardsGroup implements Writeable {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         SearchShardsGroup group = (SearchShardsGroup) o;
-        return skipped == group.skipped && shardId.equals(group.shardId) && allocatedNodes.equals(group.allocatedNodes);
+        return skipped == group.skipped
+            && preFiltered == group.preFiltered
+            && shardId.equals(group.shardId)
+            && allocatedNodes.equals(group.allocatedNodes);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(shardId, allocatedNodes, skipped);
+        return Objects.hash(shardId, allocatedNodes, skipped, preFiltered);
     }
 }

+ 30 - 1
server/src/main/java/org/elasticsearch/action/search/SearchShardsResponse.java

@@ -9,13 +9,21 @@
 package org.elasticsearch.action.search;
 
 import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
+import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.Maps;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.internal.AliasFilter;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
@@ -27,7 +35,11 @@ public final class SearchShardsResponse extends ActionResponse {
     private final Collection<DiscoveryNode> nodes;
     private final Map<String, AliasFilter> aliasFilters;
 
-    SearchShardsResponse(Collection<SearchShardsGroup> groups, Collection<DiscoveryNode> nodes, Map<String, AliasFilter> aliasFilters) {
+    public SearchShardsResponse(
+        Collection<SearchShardsGroup> groups,
+        Collection<DiscoveryNode> nodes,
+        Map<String, AliasFilter> aliasFilters
+    ) {
         this.groups = groups;
         this.nodes = nodes;
         this.aliasFilters = aliasFilters;
@@ -80,4 +92,21 @@ public final class SearchShardsResponse extends ActionResponse {
     public int hashCode() {
         return Objects.hash(groups, nodes, aliasFilters);
     }
+
+    static SearchShardsResponse fromLegacyResponse(ClusterSearchShardsResponse oldResp) {
+        Map<String, Index> indexByNames = new HashMap<>();
+        for (ClusterSearchShardsGroup oldGroup : oldResp.getGroups()) {
+            ShardId shardId = oldGroup.getShardId();
+            indexByNames.put(shardId.getIndexName(), shardId.getIndex());
+        }
+        // convert index_name -> alias_filters to index_uuid -> alias_filters
+        Map<String, AliasFilter> aliasFilters = Maps.newMapWithExpectedSize(oldResp.getIndicesAndFilters().size());
+        for (Map.Entry<String, AliasFilter> e : oldResp.getIndicesAndFilters().entrySet()) {
+            Index index = indexByNames.get(e.getKey());
+            aliasFilters.put(index.getUUID(), e.getValue());
+        }
+        List<SearchShardsGroup> groups = Arrays.stream(oldResp.getGroups()).map(SearchShardsGroup::new).toList();
+        assert groups.stream().noneMatch(SearchShardsGroup::preFiltered) : "legacy responses must not have preFiltered set";
+        return new SearchShardsResponse(groups, Arrays.asList(oldResp.getNodes()), aliasFilters);
+    }
 }

+ 114 - 75
server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

@@ -8,10 +8,12 @@
 
 package org.elasticsearch.action.search;
 
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
 import org.elasticsearch.action.IndicesRequest;
 import org.elasticsearch.action.OriginalIndices;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
+import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsAction;
 import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
 import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
 import org.elasticsearch.action.support.ActionFilters;
@@ -46,6 +48,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.IndexNotFoundException;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardNotFoundException;
@@ -67,6 +70,7 @@ import org.elasticsearch.transport.RemoteClusterAware;
 import org.elasticsearch.transport.RemoteClusterService;
 import org.elasticsearch.transport.RemoteTransportException;
 import org.elasticsearch.transport.Transport;
+import org.elasticsearch.transport.TransportRequestOptions;
 import org.elasticsearch.transport.TransportService;
 
 import java.util.ArrayList;
@@ -122,6 +126,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
 
     private final ThreadPool threadPool;
     private final ClusterService clusterService;
+    private final TransportService transportService;
     private final SearchTransportService searchTransportService;
     private final RemoteClusterService remoteClusterService;
     private final SearchPhaseController searchPhaseController;
@@ -155,6 +160,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         this.remoteClusterService = searchTransportService.getRemoteClusterService();
         SearchTransportService.registerRequestHandler(transportService, searchService);
         this.clusterService = clusterService;
+        this.transportService = transportService;
         this.searchService = searchService;
         this.indexNameExpressionResolver = indexNameExpressionResolver;
         this.namedWriteableRegistry = namedWriteableRegistry;
@@ -308,8 +314,8 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     searchPhaseProvider.apply(listener)
                 );
             } else {
+                final TaskId parentTaskId = task.taskInfo(clusterService.localNode().getId(), false).taskId();
                 if (shouldMinimizeRoundtrips(rewritten)) {
-                    final TaskId parentTaskId = task.taskInfo(clusterService.localNode().getId(), false).taskId();
                     ccsRemoteReduce(
                         parentTaskId,
                         rewritten,
@@ -332,14 +338,17 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     );
                 } else {
                     AtomicInteger skippedClusters = new AtomicInteger(0);
+                    // TODO: pass parentTaskId
                     collectSearchShards(
                         rewritten.indicesOptions(),
                         rewritten.preference(),
                         rewritten.routing(),
+                        rewritten.source() != null ? rewritten.source().query() : null,
+                        Objects.requireNonNullElse(rewritten.allowPartialSearchResults(), searchService.defaultAllowPartialSearchResults()),
+                        searchContext,
                         skippedClusters,
                         remoteClusterIndices,
-                        remoteClusterService,
-                        threadPool,
+                        transportService,
                         ActionListener.wrap(searchShardsResponses -> {
                             final BiFunction<String, String, DiscoveryNode> clusterNodeLookup = getRemoteClusterNodeLookup(
                                 searchShardsResponses
@@ -355,7 +364,10 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                                     remoteClusterIndices
                                 );
                             } else {
-                                remoteAliasFilters = getRemoteAliasFilters(searchShardsResponses);
+                                remoteAliasFilters = new HashMap<>();
+                                for (SearchShardsResponse searchShardsResponse : searchShardsResponses.values()) {
+                                    remoteAliasFilters.putAll(searchShardsResponse.getAliasFilters());
+                                }
                                 remoteShardIterators = getRemoteShardsIterator(
                                     searchShardsResponses,
                                     remoteClusterIndices,
@@ -585,47 +597,80 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         IndicesOptions indicesOptions,
         String preference,
         String routing,
+        QueryBuilder query,
+        boolean allowPartialResults,
+        SearchContextId searchContext,
         AtomicInteger skippedClusters,
         Map<String, OriginalIndices> remoteIndicesByCluster,
-        RemoteClusterService remoteClusterService,
-        ThreadPool threadPool,
-        ActionListener<Map<String, ClusterSearchShardsResponse>> listener
+        TransportService transportService,
+        ActionListener<Map<String, SearchShardsResponse>> listener
     ) {
+        RemoteClusterService remoteClusterService = transportService.getRemoteClusterService();
         final CountDown responsesCountDown = new CountDown(remoteIndicesByCluster.size());
-        final Map<String, ClusterSearchShardsResponse> searchShardsResponses = new ConcurrentHashMap<>();
+        final Map<String, SearchShardsResponse> searchShardsResponses = new ConcurrentHashMap<>();
         final AtomicReference<Exception> exceptions = new AtomicReference<>();
         for (Map.Entry<String, OriginalIndices> entry : remoteIndicesByCluster.entrySet()) {
             final String clusterAlias = entry.getKey();
             boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
-            Client clusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
-            final String[] indices = entry.getValue().indices();
-            ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(indices).indicesOptions(indicesOptions)
-                .local(true)
-                .preference(preference)
-                .routing(routing);
-            clusterClient.admin()
-                .cluster()
-                .searchShards(
-                    searchShardsRequest,
-                    new CCSActionListener<ClusterSearchShardsResponse, Map<String, ClusterSearchShardsResponse>>(
-                        clusterAlias,
-                        skipUnavailable,
-                        responsesCountDown,
-                        skippedClusters,
-                        exceptions,
-                        listener
-                    ) {
-                        @Override
-                        void innerOnResponse(ClusterSearchShardsResponse clusterSearchShardsResponse) {
-                            searchShardsResponses.put(clusterAlias, clusterSearchShardsResponse);
-                        }
+            TransportSearchAction.CCSActionListener<SearchShardsResponse, Map<String, SearchShardsResponse>> singleListener =
+                new TransportSearchAction.CCSActionListener<>(
+                    clusterAlias,
+                    skipUnavailable,
+                    responsesCountDown,
+                    skippedClusters,
+                    exceptions,
+                    listener
+                ) {
+                    @Override
+                    void innerOnResponse(SearchShardsResponse searchShardsResponse) {
+                        searchShardsResponses.put(clusterAlias, searchShardsResponse);
+                    }
 
-                        @Override
-                        Map<String, ClusterSearchShardsResponse> createFinalResponse() {
-                            return searchShardsResponses;
-                        }
+                    @Override
+                    Map<String, SearchShardsResponse> createFinalResponse() {
+                        return searchShardsResponses;
                     }
-                );
+                };
+            remoteClusterService.maybeEnsureConnectedAndGetConnection(
+                clusterAlias,
+                skipUnavailable == false,
+                ActionListener.wrap(connection -> {
+                    final String[] indices = entry.getValue().indices();
+                    // TODO: support point-in-time
+                    if (searchContext == null && connection.getTransportVersion().onOrAfter(TransportVersion.V_8_500_000)) {
+                        SearchShardsRequest searchShardsRequest = new SearchShardsRequest(
+                            indices,
+                            indicesOptions,
+                            query,
+                            routing,
+                            preference,
+                            allowPartialResults,
+                            clusterAlias
+                        );
+                        transportService.sendRequest(
+                            connection,
+                            SearchShardsAction.NAME,
+                            searchShardsRequest,
+                            TransportRequestOptions.EMPTY,
+                            new ActionListenerResponseHandler<>(singleListener, SearchShardsResponse::new)
+                        );
+                    } else {
+                        ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(indices).indicesOptions(
+                            indicesOptions
+                        ).local(true).preference(preference).routing(routing);
+                        transportService.sendRequest(
+                            connection,
+                            ClusterSearchShardsAction.NAME,
+                            searchShardsRequest,
+                            TransportRequestOptions.EMPTY,
+                            new ActionListenerResponseHandler<>(
+                                singleListener.map(SearchShardsResponse::fromLegacyResponse),
+                                ClusterSearchShardsResponse::new
+                            )
+                        );
+                    }
+                }, singleListener::onFailure)
+            );
         }
     }
 
@@ -688,9 +733,9 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         );
     }
 
-    static BiFunction<String, String, DiscoveryNode> getRemoteClusterNodeLookup(Map<String, ClusterSearchShardsResponse> searchShardsResp) {
+    static BiFunction<String, String, DiscoveryNode> getRemoteClusterNodeLookup(Map<String, SearchShardsResponse> searchShardsResp) {
         Map<String, Map<String, DiscoveryNode>> clusterToNode = new HashMap<>();
-        for (Map.Entry<String, ClusterSearchShardsResponse> entry : searchShardsResp.entrySet()) {
+        for (Map.Entry<String, SearchShardsResponse> entry : searchShardsResp.entrySet()) {
             String clusterAlias = entry.getKey();
             for (DiscoveryNode remoteNode : entry.getValue().getNodes()) {
                 clusterToNode.computeIfAbsent(clusterAlias, k -> new HashMap<>()).put(remoteNode.getId(), remoteNode);
@@ -705,38 +750,17 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         };
     }
 
-    static Map<String, AliasFilter> getRemoteAliasFilters(Map<String, ClusterSearchShardsResponse> searchShardsResp) {
-        final Map<String, AliasFilter> aliasFilterMap = new HashMap<>();
-        for (Map.Entry<String, ClusterSearchShardsResponse> entry : searchShardsResp.entrySet()) {
-            ClusterSearchShardsResponse searchShardsResponse = entry.getValue();
-            final Map<String, AliasFilter> indicesAndFilters = searchShardsResponse.getIndicesAndFilters();
-            for (ClusterSearchShardsGroup clusterSearchShardsGroup : searchShardsResponse.getGroups()) {
-                ShardId shardId = clusterSearchShardsGroup.getShardId();
-                final AliasFilter aliasFilter;
-                if (indicesAndFilters == null) {
-                    aliasFilter = AliasFilter.EMPTY;
-                } else {
-                    aliasFilter = indicesAndFilters.get(shardId.getIndexName());
-                    assert aliasFilter != null : "alias filter must not be null for index: " + shardId.getIndex();
-                }
-                // here we have to map the filters to the UUID since from now on we use the uuid for the lookup
-                aliasFilterMap.put(shardId.getIndex().getUUID(), aliasFilter);
-            }
-        }
-        return aliasFilterMap;
-    }
-
     static List<SearchShardIterator> getRemoteShardsIterator(
-        Map<String, ClusterSearchShardsResponse> searchShardsResponses,
+        Map<String, SearchShardsResponse> searchShardsResponses,
         Map<String, OriginalIndices> remoteIndicesByCluster,
         Map<String, AliasFilter> aliasFilterMap
     ) {
         final List<SearchShardIterator> remoteShardIterators = new ArrayList<>();
-        for (Map.Entry<String, ClusterSearchShardsResponse> entry : searchShardsResponses.entrySet()) {
-            for (ClusterSearchShardsGroup clusterSearchShardsGroup : entry.getValue().getGroups()) {
+        for (Map.Entry<String, SearchShardsResponse> entry : searchShardsResponses.entrySet()) {
+            for (SearchShardsGroup searchShardsGroup : entry.getValue().getGroups()) {
                 // add the cluster name to the remote index names for indices disambiguation
                 // this ends up in the hits returned with the search response
-                ShardId shardId = clusterSearchShardsGroup.getShardId();
+                ShardId shardId = searchShardsGroup.shardId();
                 AliasFilter aliasFilter = aliasFilterMap.get(shardId.getIndex().getUUID());
                 String[] aliases = aliasFilter.getAliases();
                 String clusterAlias = entry.getKey();
@@ -746,8 +770,12 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 SearchShardIterator shardIterator = new SearchShardIterator(
                     clusterAlias,
                     shardId,
-                    Arrays.asList(clusterSearchShardsGroup.getShards()),
-                    new OriginalIndices(finalIndices, originalIndices.indicesOptions())
+                    searchShardsGroup.allocatedNodes(),
+                    new OriginalIndices(finalIndices, originalIndices.indicesOptions()),
+                    null,
+                    null,
+                    searchShardsGroup.preFiltered(),
+                    searchShardsGroup.skipped()
                 );
                 remoteShardIterators.add(shardIterator);
             }
@@ -756,24 +784,24 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
     }
 
     static List<SearchShardIterator> getRemoteShardsIteratorFromPointInTime(
-        Map<String, ClusterSearchShardsResponse> searchShardsResponses,
+        Map<String, SearchShardsResponse> searchShardsResponses,
         SearchContextId searchContextId,
         TimeValue searchContextKeepAlive,
         Map<String, OriginalIndices> remoteClusterIndices
     ) {
         final List<SearchShardIterator> remoteShardIterators = new ArrayList<>();
-        for (Map.Entry<String, ClusterSearchShardsResponse> entry : searchShardsResponses.entrySet()) {
-            for (ClusterSearchShardsGroup group : entry.getValue().getGroups()) {
-                final ShardId shardId = group.getShardId();
+        for (Map.Entry<String, SearchShardsResponse> entry : searchShardsResponses.entrySet()) {
+            for (SearchShardsGroup group : entry.getValue().getGroups()) {
+                final ShardId shardId = group.shardId();
                 final String clusterAlias = entry.getKey();
                 final SearchContextIdForNode perNode = searchContextId.shards().get(shardId);
                 assert clusterAlias.equals(perNode.getClusterAlias()) : clusterAlias + " != " + perNode.getClusterAlias();
-                final List<String> targetNodes = new ArrayList<>(group.getShards().length);
+                final List<String> targetNodes = new ArrayList<>(group.allocatedNodes().size());
                 targetNodes.add(perNode.getNode());
                 if (perNode.getSearchContextId().getSearcherId() != null) {
-                    for (ShardRouting shard : group.getShards()) {
-                        if (shard.currentNodeId().equals(perNode.getNode()) == false) {
-                            targetNodes.add(shard.currentNodeId());
+                    for (String node : group.allocatedNodes()) {
+                        if (node.equals(perNode.getNode()) == false) {
+                            targetNodes.add(node);
                         }
                     }
                 }
@@ -788,7 +816,9 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     targetNodes,
                     finalIndices,
                     perNode.getSearchContextId(),
-                    searchContextKeepAlive
+                    searchContextKeepAlive,
+                    false,
+                    false
                 );
                 remoteShardIterators.add(shardIterator);
             }
@@ -1331,7 +1361,16 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     originalIndices.indicesOptions()
                 );
                 iterators.add(
-                    new SearchShardIterator(localClusterAlias, shardId, targetNodes, finalIndices, perNode.getSearchContextId(), keepAlive)
+                    new SearchShardIterator(
+                        localClusterAlias,
+                        shardId,
+                        targetNodes,
+                        finalIndices,
+                        perNode.getSearchContextId(),
+                        keepAlive,
+                        false,
+                        false
+                    )
                 );
             }
         }

+ 31 - 0
server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java

@@ -240,6 +240,37 @@ public final class RemoteClusterService extends RemoteClusterAware implements Cl
         return getRemoteClusterConnection(cluster).getConnection();
     }
 
+    /**
+     * Unlike {@link #getConnection(String)} this method might attempt to re-establish a remote connection if there is no connection
+     * available before returning a connection to the remote cluster.
+     *
+     * @param clusterAlias    the remote cluster
+     * @param ensureConnected whether requests should wait for a connection attempt when there isn't available connection
+     * @param listener        a listener that will be notified the connection or failure
+     */
+    public void maybeEnsureConnectedAndGetConnection(
+        String clusterAlias,
+        boolean ensureConnected,
+        ActionListener<Transport.Connection> listener
+    ) {
+        ActionListener<Void> ensureConnectedListener = ActionListener.wrap(nullValue -> ActionListener.completeWith(listener, () -> {
+            try {
+                return getConnection(clusterAlias);
+            } catch (NoSuchRemoteClusterException e) {
+                if (ensureConnected == false) {
+                    // trigger another connection attempt, but don't wait for it to complete
+                    ensureConnected(clusterAlias, ActionListener.noop());
+                }
+                throw e;
+            }
+        }), listener::onFailure);
+        if (ensureConnected) {
+            ensureConnected(clusterAlias, ensureConnectedListener);
+        } else {
+            ensureConnectedListener.onResponse(null);
+        }
+    }
+
     RemoteClusterConnection getRemoteClusterConnection(String cluster) {
         if (enabled == false) {
             throw new IllegalArgumentException(

+ 2 - 1
server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java

@@ -241,7 +241,8 @@ public class AbstractSearchAsyncActionTests extends ESTestCase {
         AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest, phaseResults, listener, false, new AtomicLong());
         // skip one to avoid the "all shards failed" failure.
         SearchShardIterator skipIterator = new SearchShardIterator(null, null, Collections.emptyList(), null);
-        skipIterator.resetAndSkip();
+        skipIterator.skip(true);
+        skipIterator.reset();
         action.skipShard(skipIterator);
         assertThat(exception.get(), instanceOf(SearchPhaseExecutionException.class));
         SearchPhaseExecutionException searchPhaseExecutionException = (SearchPhaseExecutionException) exception.get();

+ 3 - 2
server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java

@@ -81,7 +81,7 @@ public class SearchAsyncActionTests extends ESTestCase {
         int numSkipped = 0;
         for (SearchShardIterator iter : shardsIter) {
             if (iter.shardId().id() % 2 == 0) {
-                iter.resetAndSkip();
+                iter.skip(true);
                 numSkipped++;
             }
         }
@@ -633,7 +633,8 @@ public class SearchAsyncActionTests extends ESTestCase {
                 originalIndices
             );
             // Skip all the shards
-            searchShardIterator.resetAndSkip();
+            searchShardIterator.skip(true);
+            searchShardIterator.reset();
             searchShardIterators.add(searchShardIterator);
         }
         GroupShardsIterator<SearchShardIterator> shardsIter = new GroupShardsIterator<>(searchShardIterators);

+ 15 - 5
server/src/test/java/org/elasticsearch/action/search/SearchShardIteratorTests.java

@@ -70,7 +70,9 @@ public class SearchShardIteratorTests extends ESTestCase {
             List.of(nodeId),
             originalIndices,
             null,
-            null
+            null,
+            false,
+            false
         );
         final SearchShardTarget searchShardTarget = searchShardIterator.nextOrNull();
         assertNotNull(searchShardTarget);
@@ -89,7 +91,9 @@ public class SearchShardIteratorTests extends ESTestCase {
                 s.getTargetNodeIds(),
                 s.getOriginalIndices(),
                 s.getSearchContextId(),
-                s.getSearchContextKeepAlive()
+                s.getSearchContextKeepAlive(),
+                s.prefiltered(),
+                s.skip()
             ),
             s -> {
                 if (randomBoolean()) {
@@ -105,7 +109,9 @@ public class SearchShardIteratorTests extends ESTestCase {
                         s.getTargetNodeIds(),
                         s.getOriginalIndices(),
                         s.getSearchContextId(),
-                        s.getSearchContextKeepAlive()
+                        s.getSearchContextKeepAlive(),
+                        s.prefiltered(),
+                        s.skip()
                     );
                 } else {
                     ShardId shardId = new ShardId(
@@ -119,7 +125,9 @@ public class SearchShardIteratorTests extends ESTestCase {
                         s.getTargetNodeIds(),
                         s.getOriginalIndices(),
                         s.getSearchContextId(),
-                        s.getSearchContextKeepAlive()
+                        s.getSearchContextKeepAlive(),
+                        s.prefiltered(),
+                        s.skip()
                     );
                 }
             }
@@ -186,7 +194,9 @@ public class SearchShardIteratorTests extends ESTestCase {
             shardIterator1.getTargetNodeIds(),
             shardIterator1.getOriginalIndices(),
             shardIterator1.getSearchContextId(),
-            shardIterator1.getSearchContextKeepAlive()
+            shardIterator1.getSearchContextKeepAlive(),
+            shardIterator1.prefiltered(),
+            shardIterator1.skip()
         );
         assertEquals(shardIterator1, shardIterator2);
         assertEquals(0, shardIterator1.compareTo(shardIterator2));

+ 64 - 0
server/src/test/java/org/elasticsearch/action/search/SearchShardsResponseTests.java

@@ -8,17 +8,29 @@
 
 package org.elasticsearch.action.search;
 
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
+import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.TestDiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.util.iterable.Iterables;
 import org.elasticsearch.index.query.RandomQueryBuilder;
+import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.test.TransportVersionUtils;
+import org.elasticsearch.test.VersionUtils;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -27,6 +39,9 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
 public class SearchShardsResponseTests extends AbstractWireSerializingTestCase<SearchShardsResponse> {
 
     @Override
@@ -92,4 +107,53 @@ public class SearchShardsResponseTests extends AbstractWireSerializingTestCase<S
             }
         }
     }
+
+    public void testLegacyResponse() {
+        DiscoveryNode node1 = TestDiscoveryNode.create(
+            "node-1",
+            new TransportAddress(TransportAddress.META_ADDRESS, randomInt(0xFFFF)),
+            VersionUtils.randomVersion(random())
+        );
+        DiscoveryNode node2 = TestDiscoveryNode.create(
+            "node-2",
+            new TransportAddress(TransportAddress.META_ADDRESS, randomInt(0xFFFF)),
+            VersionUtils.randomVersion(random())
+        );
+        final ClusterSearchShardsGroup[] groups = new ClusterSearchShardsGroup[2];
+        {
+            ShardId shardId = new ShardId("index-1", "uuid-1", 0);
+            var shard1 = TestShardRouting.newShardRouting(shardId, node1.getId(), randomBoolean(), ShardRoutingState.STARTED);
+            var shard2 = TestShardRouting.newShardRouting(shardId, node2.getId(), randomBoolean(), ShardRoutingState.STARTED);
+            groups[0] = new ClusterSearchShardsGroup(shardId, new ShardRouting[] { shard1, shard2 });
+        }
+        {
+            ShardId shardId = new ShardId("index-2", "uuid-2", 7);
+            var shard1 = TestShardRouting.newShardRouting(shardId, node1.getId(), randomBoolean(), ShardRoutingState.STARTED);
+            groups[1] = new ClusterSearchShardsGroup(shardId, new ShardRouting[] { shard1 });
+        }
+        AliasFilter aliasFilter = AliasFilter.of(new TermQueryBuilder("t", "v"), "alias-1");
+        var legacyResponse = new ClusterSearchShardsResponse(groups, new DiscoveryNode[] { node1, node2 }, Map.of("index-1", aliasFilter));
+        SearchShardsResponse newResponse = SearchShardsResponse.fromLegacyResponse(legacyResponse);
+        assertThat(newResponse.getNodes(), equalTo(List.of(node1, node2)));
+        assertThat(newResponse.getAliasFilters(), equalTo(Map.of("uuid-1", aliasFilter)));
+        assertThat(newResponse.getGroups(), hasSize(2));
+        SearchShardsGroup group1 = Iterables.get(newResponse.getGroups(), 0);
+        assertThat(group1.shardId(), equalTo(new ShardId("index-1", "uuid-1", 0)));
+        assertThat(group1.allocatedNodes(), equalTo(List.of("node-1", "node-2")));
+        assertFalse(group1.skipped());
+        assertFalse(group1.preFiltered());
+
+        SearchShardsGroup group2 = Iterables.get(newResponse.getGroups(), 1);
+        assertThat(group2.shardId(), equalTo(new ShardId("index-2", "uuid-2", 7)));
+        assertThat(group2.allocatedNodes(), equalTo(List.of("node-1")));
+        assertFalse(group2.skipped());
+        assertFalse(group2.preFiltered());
+
+        TransportVersion version = TransportVersionUtils.randomCompatibleVersion(random());
+        try (BytesStreamOutput out = new BytesStreamOutput()) {
+            out.setTransportVersion(version);
+            AssertionError error = expectThrows(AssertionError.class, () -> newResponse.writeTo(out));
+            assertThat(error.getMessage(), equalTo("Serializing a response created from a legacy response is not allowed"));
+        }
+    }
 }

+ 133 - 116
server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java

@@ -99,6 +99,7 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -245,111 +246,116 @@ public class TransportSearchActionTests extends ESTestCase {
     }
 
     public void testProcessRemoteShards() {
-        try (
-            TransportService transportService = MockTransportService.createNewService(
-                Settings.EMPTY,
-                Version.CURRENT,
-                TransportVersion.CURRENT,
-                threadPool,
-                null
-            )
-        ) {
-            RemoteClusterService service = transportService.getRemoteClusterService();
-            assertFalse(service.isCrossClusterSearchEnabled());
-            Map<String, ClusterSearchShardsResponse> searchShardsResponseMap = new HashMap<>();
-            DiscoveryNode[] nodes = new DiscoveryNode[] { TestDiscoveryNode.create("node1"), TestDiscoveryNode.create("node2") };
-            Map<String, AliasFilter> indicesAndAliases = new HashMap<>();
-            indicesAndAliases.put("foo", AliasFilter.of(new TermsQueryBuilder("foo", "bar"), "some_alias_for_foo", "some_other_foo_alias"));
-            indicesAndAliases.put("bar", AliasFilter.of(new MatchAllQueryBuilder(), Strings.EMPTY_ARRAY));
-            ClusterSearchShardsGroup[] groups = new ClusterSearchShardsGroup[] {
-                new ClusterSearchShardsGroup(
-                    new ShardId("foo", "foo_id", 0),
-                    new ShardRouting[] {
-                        TestShardRouting.newShardRouting("foo", 0, "node1", true, ShardRoutingState.STARTED),
-                        TestShardRouting.newShardRouting("foo", 0, "node2", false, ShardRoutingState.STARTED) }
-                ),
-                new ClusterSearchShardsGroup(
-                    new ShardId("foo", "foo_id", 1),
-                    new ShardRouting[] {
-                        TestShardRouting.newShardRouting("foo", 0, "node1", true, ShardRoutingState.STARTED),
-                        TestShardRouting.newShardRouting("foo", 1, "node2", false, ShardRoutingState.STARTED) }
-                ),
-                new ClusterSearchShardsGroup(
-                    new ShardId("bar", "bar_id", 0),
-                    new ShardRouting[] {
-                        TestShardRouting.newShardRouting("bar", 0, "node2", true, ShardRoutingState.STARTED),
-                        TestShardRouting.newShardRouting("bar", 0, "node1", false, ShardRoutingState.STARTED) }
-                ) };
-            searchShardsResponseMap.put("test_cluster_1", new ClusterSearchShardsResponse(groups, nodes, indicesAndAliases));
+        Map<String, SearchShardsResponse> searchShardsResponseMap = new LinkedHashMap<>();
+        // first cluster - new response
+        {
+            List<DiscoveryNode> nodes = List.of(TestDiscoveryNode.create("node1"), TestDiscoveryNode.create("node2"));
+            Map<String, AliasFilter> aliasFilters1 = Map.of(
+                "foo_id",
+                AliasFilter.of(new TermsQueryBuilder("foo", "bar"), "some_alias_for_foo", "some_other_foo_alias"),
+                "bar_id",
+                AliasFilter.of(new MatchAllQueryBuilder(), Strings.EMPTY_ARRAY)
+            );
+            List<SearchShardsGroup> groups = List.of(
+                new SearchShardsGroup(new ShardId("foo", "foo_id", 0), List.of("node1", "node2"), false),
+                new SearchShardsGroup(new ShardId("foo", "foo_id", 1), List.of("node2", "node1"), true),
+                new SearchShardsGroup(new ShardId("bar", "bar_id", 0), List.of("node2", "node1"), false)
+            );
+            searchShardsResponseMap.put("test_cluster_1", new SearchShardsResponse(groups, nodes, aliasFilters1));
+        }
+        // second cluster - legacy response
+        {
             DiscoveryNode[] nodes2 = new DiscoveryNode[] { TestDiscoveryNode.create("node3") };
             ClusterSearchShardsGroup[] groups2 = new ClusterSearchShardsGroup[] {
                 new ClusterSearchShardsGroup(
                     new ShardId("xyz", "xyz_id", 0),
                     new ShardRouting[] { TestShardRouting.newShardRouting("xyz", 0, "node3", true, ShardRoutingState.STARTED) }
                 ) };
-            Map<String, AliasFilter> filter = new HashMap<>();
-            filter.put("xyz", AliasFilter.of(null, "some_alias_for_xyz"));
-            searchShardsResponseMap.put("test_cluster_2", new ClusterSearchShardsResponse(groups2, nodes2, filter));
-
-            Map<String, OriginalIndices> remoteIndicesByCluster = new HashMap<>();
-            remoteIndicesByCluster.put(
-                "test_cluster_1",
-                new OriginalIndices(new String[] { "fo*", "ba*" }, SearchRequest.DEFAULT_INDICES_OPTIONS)
-            );
-            remoteIndicesByCluster.put("test_cluster_2", new OriginalIndices(new String[] { "x*" }, SearchRequest.DEFAULT_INDICES_OPTIONS));
-            Map<String, AliasFilter> remoteAliases = TransportSearchAction.getRemoteAliasFilters(searchShardsResponseMap);
-            List<SearchShardIterator> iteratorList = TransportSearchAction.getRemoteShardsIterator(
-                searchShardsResponseMap,
-                remoteIndicesByCluster,
-                remoteAliases
+            Map<String, AliasFilter> aliasFilters2 = Map.of("xyz", AliasFilter.of(null, "some_alias_for_xyz"));
+            searchShardsResponseMap.put(
+                "test_cluster_2",
+                SearchShardsResponse.fromLegacyResponse(new ClusterSearchShardsResponse(groups2, nodes2, aliasFilters2))
             );
-            assertEquals(4, iteratorList.size());
-            for (SearchShardIterator iterator : iteratorList) {
-                if (iterator.shardId().getIndexName().endsWith("foo")) {
-                    assertArrayEquals(
-                        new String[] { "some_alias_for_foo", "some_other_foo_alias" },
-                        iterator.getOriginalIndices().indices()
-                    );
-                    assertTrue(iterator.shardId().getId() == 0 || iterator.shardId().getId() == 1);
-                    assertEquals("test_cluster_1", iterator.getClusterAlias());
-                    assertEquals("foo", iterator.shardId().getIndexName());
-                    SearchShardTarget shard = iterator.nextOrNull();
-                    assertNotNull(shard);
-                    assertEquals(shard.getShardId().getIndexName(), "foo");
-                    shard = iterator.nextOrNull();
-                    assertNotNull(shard);
-                    assertEquals(shard.getShardId().getIndexName(), "foo");
-                    assertNull(iterator.nextOrNull());
-                } else if (iterator.shardId().getIndexName().endsWith("bar")) {
-                    assertArrayEquals(new String[] { "bar" }, iterator.getOriginalIndices().indices());
-                    assertEquals(0, iterator.shardId().getId());
-                    assertEquals("test_cluster_1", iterator.getClusterAlias());
-                    assertEquals("bar", iterator.shardId().getIndexName());
-                    SearchShardTarget shard = iterator.nextOrNull();
-                    assertNotNull(shard);
-                    assertEquals(shard.getShardId().getIndexName(), "bar");
-                    shard = iterator.nextOrNull();
-                    assertNotNull(shard);
-                    assertEquals(shard.getShardId().getIndexName(), "bar");
-                    assertNull(iterator.nextOrNull());
-                } else if (iterator.shardId().getIndexName().endsWith("xyz")) {
-                    assertArrayEquals(new String[] { "some_alias_for_xyz" }, iterator.getOriginalIndices().indices());
-                    assertEquals(0, iterator.shardId().getId());
-                    assertEquals("xyz", iterator.shardId().getIndexName());
-                    assertEquals("test_cluster_2", iterator.getClusterAlias());
-                    SearchShardTarget shard = iterator.nextOrNull();
-                    assertNotNull(shard);
-                    assertEquals(shard.getShardId().getIndexName(), "xyz");
-                    assertNull(iterator.nextOrNull());
-                }
-            }
-            assertEquals(3, remoteAliases.size());
-            assertTrue(remoteAliases.toString(), remoteAliases.containsKey("foo_id"));
-            assertTrue(remoteAliases.toString(), remoteAliases.containsKey("bar_id"));
-            assertTrue(remoteAliases.toString(), remoteAliases.containsKey("xyz_id"));
-            assertEquals(new TermsQueryBuilder("foo", "bar"), remoteAliases.get("foo_id").getQueryBuilder());
-            assertEquals(new MatchAllQueryBuilder(), remoteAliases.get("bar_id").getQueryBuilder());
-            assertNull(remoteAliases.get("xyz_id").getQueryBuilder());
+        }
+        Map<String, OriginalIndices> remoteIndicesByCluster = Map.of(
+            "test_cluster_1",
+            new OriginalIndices(new String[] { "fo*", "ba*" }, SearchRequest.DEFAULT_INDICES_OPTIONS),
+            "test_cluster_2",
+            new OriginalIndices(new String[] { "x*" }, SearchRequest.DEFAULT_INDICES_OPTIONS)
+        );
+        Map<String, AliasFilter> aliasFilters = new HashMap<>();
+        searchShardsResponseMap.values().forEach(r -> aliasFilters.putAll(r.getAliasFilters()));
+        List<SearchShardIterator> iteratorList = TransportSearchAction.getRemoteShardsIterator(
+            searchShardsResponseMap,
+            remoteIndicesByCluster,
+            aliasFilters
+        );
+        assertThat(iteratorList, hasSize(4));
+        {
+            SearchShardIterator shardIt = iteratorList.get(0);
+            assertTrue(shardIt.prefiltered());
+            assertFalse(shardIt.skip());
+            assertThat(shardIt.shardId(), equalTo(new ShardId("foo", "foo_id", 0)));
+            assertArrayEquals(new String[] { "some_alias_for_foo", "some_other_foo_alias" }, shardIt.getOriginalIndices().indices());
+            assertEquals("test_cluster_1", shardIt.getClusterAlias());
+            assertEquals("foo", shardIt.shardId().getIndexName());
+            SearchShardTarget shard = shardIt.nextOrNull();
+            assertNotNull(shard);
+            assertEquals(shard.getShardId().getIndexName(), "foo");
+            assertThat(shard.getNodeId(), equalTo("node1"));
+            shard = shardIt.nextOrNull();
+            assertNotNull(shard);
+            assertEquals(shard.getShardId().getIndexName(), "foo");
+            assertThat(shard.getNodeId(), equalTo("node2"));
+            assertNull(shardIt.nextOrNull());
+        }
+        {
+            SearchShardIterator shardIt = iteratorList.get(1);
+            assertTrue(shardIt.prefiltered());
+            assertTrue(shardIt.skip());
+            assertThat(shardIt.shardId(), equalTo(new ShardId("foo", "foo_id", 1)));
+            assertArrayEquals(new String[] { "some_alias_for_foo", "some_other_foo_alias" }, shardIt.getOriginalIndices().indices());
+            assertEquals("test_cluster_1", shardIt.getClusterAlias());
+            assertEquals("foo", shardIt.shardId().getIndexName());
+            SearchShardTarget shard = shardIt.nextOrNull();
+            assertNotNull(shard);
+            assertEquals(shard.getShardId().getIndexName(), "foo");
+            assertThat(shard.getNodeId(), equalTo("node2"));
+            shard = shardIt.nextOrNull();
+            assertNotNull(shard);
+            assertEquals(shard.getShardId().getIndexName(), "foo");
+            assertThat(shard.getNodeId(), equalTo("node1"));
+            assertNull(shardIt.nextOrNull());
+        }
+        {
+            SearchShardIterator shardIt = iteratorList.get(2);
+            assertTrue(shardIt.prefiltered());
+            assertFalse(shardIt.skip());
+            assertThat(shardIt.shardId(), equalTo(new ShardId("bar", "bar_id", 0)));
+            assertArrayEquals(new String[] { "bar" }, shardIt.getOriginalIndices().indices());
+            assertEquals("test_cluster_1", shardIt.getClusterAlias());
+            SearchShardTarget shard = shardIt.nextOrNull();
+            assertNotNull(shard);
+            assertEquals(shard.getShardId().getIndexName(), "bar");
+            assertThat(shard.getNodeId(), equalTo("node2"));
+            shard = shardIt.nextOrNull();
+            assertNotNull(shard);
+            assertEquals(shard.getShardId().getIndexName(), "bar");
+            assertThat(shard.getNodeId(), equalTo("node1"));
+            assertNull(shardIt.nextOrNull());
+        }
+        {
+            SearchShardIterator shardIt = iteratorList.get(3);
+            assertFalse(shardIt.prefiltered());
+            assertFalse(shardIt.skip());
+            assertArrayEquals(new String[] { "some_alias_for_xyz" }, shardIt.getOriginalIndices().indices());
+            assertThat(shardIt.shardId(), equalTo(new ShardId("xyz", "xyz_id", 0)));
+            assertEquals("test_cluster_2", shardIt.getClusterAlias());
+            SearchShardTarget shard = shardIt.nextOrNull();
+            assertNotNull(shard);
+            assertEquals(shard.getShardId().getIndexName(), "xyz");
+            assertThat(shard.getNodeId(), equalTo("node3"));
+            assertNull(shardIt.nextOrNull());
         }
     }
 
@@ -827,31 +833,34 @@ public class TransportSearchActionTests extends ESTestCase {
         ) {
             service.start();
             service.acceptIncomingRequests();
+
             RemoteClusterService remoteClusterService = service.getRemoteClusterService();
             {
                 final CountDownLatch latch = new CountDownLatch(1);
-                AtomicReference<Map<String, ClusterSearchShardsResponse>> response = new AtomicReference<>();
+                AtomicReference<Map<String, SearchShardsResponse>> response = new AtomicReference<>();
                 AtomicInteger skippedClusters = new AtomicInteger();
                 TransportSearchAction.collectSearchShards(
                     IndicesOptions.lenientExpandOpen(),
                     null,
                     null,
+                    new MatchAllQueryBuilder(),
+                    randomBoolean(),
+                    null,
                     skippedClusters,
                     remoteIndicesByCluster,
-                    remoteClusterService,
-                    threadPool,
+                    service,
                     new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertEquals(0, skippedClusters.get());
                 assertNotNull(response.get());
-                Map<String, ClusterSearchShardsResponse> map = response.get();
+                Map<String, SearchShardsResponse> map = response.get();
                 assertEquals(numClusters, map.size());
                 for (int i = 0; i < numClusters; i++) {
                     String clusterAlias = "remote" + i;
                     assertTrue(map.containsKey(clusterAlias));
-                    ClusterSearchShardsResponse shardsResponse = map.get(clusterAlias);
-                    assertEquals(1, shardsResponse.getNodes().length);
+                    SearchShardsResponse shardsResponse = map.get(clusterAlias);
+                    assertThat(shardsResponse.getNodes(), hasSize(1));
                 }
             }
             {
@@ -862,10 +871,12 @@ public class TransportSearchActionTests extends ESTestCase {
                     IndicesOptions.lenientExpandOpen(),
                     "index_not_found",
                     null,
+                    new MatchAllQueryBuilder(),
+                    randomBoolean(),
+                    null,
                     skippedClusters,
                     remoteIndicesByCluster,
-                    remoteClusterService,
-                    threadPool,
+                    service,
                     new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -907,10 +918,12 @@ public class TransportSearchActionTests extends ESTestCase {
                     IndicesOptions.lenientExpandOpen(),
                     null,
                     null,
+                    new MatchAllQueryBuilder(),
+                    randomBoolean(),
+                    null,
                     skippedClusters,
                     remoteIndicesByCluster,
-                    remoteClusterService,
-                    threadPool,
+                    service,
                     new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
@@ -929,20 +942,22 @@ public class TransportSearchActionTests extends ESTestCase {
             {
                 final CountDownLatch latch = new CountDownLatch(1);
                 AtomicInteger skippedClusters = new AtomicInteger(0);
-                AtomicReference<Map<String, ClusterSearchShardsResponse>> response = new AtomicReference<>();
+                AtomicReference<Map<String, SearchShardsResponse>> response = new AtomicReference<>();
                 TransportSearchAction.collectSearchShards(
                     IndicesOptions.lenientExpandOpen(),
                     null,
                     null,
+                    new MatchAllQueryBuilder(),
+                    randomBoolean(),
+                    null,
                     skippedClusters,
                     remoteIndicesByCluster,
-                    remoteClusterService,
-                    threadPool,
+                    service,
                     new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertNotNull(response.get());
-                Map<String, ClusterSearchShardsResponse> map = response.get();
+                Map<String, SearchShardsResponse> map = response.get();
                 assertEquals(numClusters - disconnectedNodesIndices.size(), map.size());
                 assertEquals(skippedClusters.get(), disconnectedNodesIndices.size());
                 for (int i = 0; i < numClusters; i++) {
@@ -972,21 +987,23 @@ public class TransportSearchActionTests extends ESTestCase {
             assertBusy(() -> {
                 final CountDownLatch latch = new CountDownLatch(1);
                 AtomicInteger skippedClusters = new AtomicInteger(0);
-                AtomicReference<Map<String, ClusterSearchShardsResponse>> response = new AtomicReference<>();
+                AtomicReference<Map<String, SearchShardsResponse>> response = new AtomicReference<>();
                 TransportSearchAction.collectSearchShards(
                     IndicesOptions.lenientExpandOpen(),
                     null,
                     null,
+                    new MatchAllQueryBuilder(),
+                    randomBoolean(),
+                    null,
                     skippedClusters,
                     remoteIndicesByCluster,
-                    remoteClusterService,
-                    threadPool,
+                    service,
                     new LatchedActionListener<>(ActionListener.wrap(response::set, e -> fail("no failures expected")), latch)
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertEquals(0, skippedClusters.get());
                 assertNotNull(response.get());
-                Map<String, ClusterSearchShardsResponse> map = response.get();
+                Map<String, SearchShardsResponse> map = response.get();
                 assertEquals(numClusters, map.size());
                 for (int i = 0; i < numClusters; i++) {
                     String clusterAlias = "remote" + i;

+ 38 - 48
server/src/test/java/org/elasticsearch/transport/RemoteClusterAwareClientTests.java

@@ -11,27 +11,28 @@ package org.elasticsearch.transport;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.LatchedActionListener;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
-import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchShardsAction;
+import org.elasticsearch.action.search.SearchShardsRequest;
+import org.elasticsearch.action.search.SearchShardsResponse;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
+
+import static org.hamcrest.Matchers.equalTo;
 
 public class RemoteClusterAwareClientTests extends ESTestCase {
 
@@ -79,25 +80,17 @@ public class RemoteClusterAwareClientTests extends ESTestCase {
                         randomBoolean()
                     )
                 ) {
-                    SearchRequest request = new SearchRequest("test-index");
-                    CountDownLatch responseLatch = new CountDownLatch(1);
-                    AtomicReference<ClusterSearchShardsResponse> reference = new AtomicReference<>();
-                    ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest("test-index").indicesOptions(
-                        request.indicesOptions()
-                    ).local(true).preference(request.preference()).routing(request.routing());
-                    client.admin()
-                        .cluster()
-                        .searchShards(
-                            searchShardsRequest,
-                            new LatchedActionListener<>(
-                                ActionListener.wrap(reference::set, e -> fail("no failures expected")),
-                                responseLatch
-                            )
-                        );
-                    responseLatch.await();
-                    assertNotNull(reference.get());
-                    ClusterSearchShardsResponse clusterSearchShardsResponse = reference.get();
-                    assertEquals(knownNodes, Arrays.asList(clusterSearchShardsResponse.getNodes()));
+                    SearchShardsRequest searchShardsRequest = new SearchShardsRequest(
+                        new String[] { "test-index" },
+                        IndicesOptions.strictExpandOpen(),
+                        new MatchAllQueryBuilder(),
+                        null,
+                        null,
+                        randomBoolean(),
+                        null
+                    );
+                    var searchShardsResponse = client.execute(SearchShardsAction.INSTANCE, searchShardsRequest).actionGet();
+                    assertThat(searchShardsResponse.getNodes(), equalTo(knownNodes));
                 }
             }
         }
@@ -135,35 +128,32 @@ public class RemoteClusterAwareClientTests extends ESTestCase {
                         randomBoolean()
                     )
                 ) {
-                    SearchRequest request = new SearchRequest("test-index");
                     int numThreads = 10;
                     ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
                     for (int i = 0; i < numThreads; i++) {
                         final String threadId = Integer.toString(i);
+                        PlainActionFuture<SearchShardsResponse> future = new PlainActionFuture<>();
                         executorService.submit(() -> {
                             ThreadContext threadContext = seedTransport.threadPool.getThreadContext();
                             threadContext.putHeader("threadId", threadId);
-                            AtomicReference<ClusterSearchShardsResponse> reference = new AtomicReference<>();
-                            final ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest("test-index")
-                                .indicesOptions(request.indicesOptions())
-                                .local(true)
-                                .preference(request.preference())
-                                .routing(request.routing());
-                            CountDownLatch responseLatch = new CountDownLatch(1);
-                            client.admin()
-                                .cluster()
-                                .searchShards(searchShardsRequest, new LatchedActionListener<>(ActionListener.wrap(resp -> {
-                                    reference.set(resp);
-                                    assertEquals(threadId, seedTransport.threadPool.getThreadContext().getHeader("threadId"));
-                                }, e -> fail("no failures expected")), responseLatch));
-                            try {
-                                responseLatch.await();
-                            } catch (InterruptedException e) {
-                                throw new RuntimeException(e);
-                            }
-                            assertNotNull(reference.get());
-                            ClusterSearchShardsResponse clusterSearchShardsResponse = reference.get();
-                            assertEquals(knownNodes, Arrays.asList(clusterSearchShardsResponse.getNodes()));
+                            var searchShardsRequest = new SearchShardsRequest(
+                                new String[] { "test-index" },
+                                IndicesOptions.strictExpandOpen(),
+                                new MatchAllQueryBuilder(),
+                                null,
+                                null,
+                                randomBoolean(),
+                                null
+                            );
+                            client.execute(
+                                SearchShardsAction.INSTANCE,
+                                searchShardsRequest,
+                                ActionListener.runBefore(
+                                    future,
+                                    () -> assertThat(seedTransport.threadPool.getThreadContext().getHeader("threadId"), equalTo(threadId))
+                                )
+                            );
+                            assertThat(future.actionGet().getNodes(), equalTo(knownNodes));
                         });
                     }
                     ThreadPool.terminate(executorService, 5, TimeUnit.SECONDS);

+ 6 - 13
server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java

@@ -13,16 +13,15 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.admin.cluster.remote.RemoteClusterNodesAction;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsAction;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchShardsAction;
+import org.elasticsearch.action.search.SearchShardsRequest;
+import org.elasticsearch.action.search.SearchShardsResponse;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterName;
@@ -125,20 +124,14 @@ public class RemoteClusterConnectionTests extends ESTestCase {
         MockTransportService newService = MockTransportService.createNewService(s, version, transportVersion, threadPool, null);
         try {
             newService.registerRequestHandler(
-                ClusterSearchShardsAction.NAME,
+                SearchShardsAction.NAME,
                 ThreadPool.Names.SAME,
-                ClusterSearchShardsRequest::new,
+                SearchShardsRequest::new,
                 (request, channel, task) -> {
                     if ("index_not_found".equals(request.preference())) {
                         channel.sendResponse(new IndexNotFoundException("index"));
                     } else {
-                        channel.sendResponse(
-                            new ClusterSearchShardsResponse(
-                                new ClusterSearchShardsGroup[0],
-                                knownNodes.toArray(new DiscoveryNode[0]),
-                                Collections.emptyMap()
-                            )
-                        );
+                        channel.sendResponse(new SearchShardsResponse(List.of(), knownNodes, Collections.emptyMap()));
                     }
                 }
             );

+ 5 - 1
test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java

@@ -24,6 +24,7 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.RunOnce;
@@ -33,6 +34,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.node.Node;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.tasks.MockTaskManager;
@@ -123,7 +125,9 @@ public class MockTransportService extends TransportService {
 
     public static TcpTransport newMockTransport(Settings settings, TransportVersion version, ThreadPool threadPool) {
         settings = Settings.builder().put(TransportSettings.PORT.getKey(), ESTestCase.getPortRange()).put(settings).build();
-        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(ClusterModule.getNamedWriteables());
+        SearchModule searchModule = new SearchModule(Settings.EMPTY, List.of());
+        var namedWriteables = CollectionUtils.concatLists(searchModule.getNamedWriteables(), ClusterModule.getNamedWriteables());
+        NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables);
         return new Netty4Transport(
             settings,
             version,

+ 2 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java

@@ -54,6 +54,7 @@ import org.elasticsearch.action.ingest.PutPipelineAction;
 import org.elasticsearch.action.ingest.SimulatePipelineAction;
 import org.elasticsearch.action.search.MultiSearchAction;
 import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchShardsAction;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.action.update.UpdateAction;
 import org.elasticsearch.cluster.metadata.AliasMetadata;
@@ -3194,6 +3195,7 @@ public class ReservedRolesStoreTests extends ESTestCase {
             GetFieldMappingsAction.NAME + "*",
             GetMappingsAction.NAME,
             ClusterSearchShardsAction.NAME,
+            SearchShardsAction.NAME,
             ValidateQueryAction.NAME + "*",
             GetSettingsAction.NAME,
             ExplainLifecycleAction.NAME,

+ 8 - 12
x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/crossclusteraccess/CrossClusterAccessHeadersForCcsRestIT.java

@@ -14,22 +14,20 @@ import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.admin.cluster.remote.RemoteClusterNodesAction;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsAction;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchShardsAction;
+import org.elasticsearch.action.search.SearchShardsRequest;
+import org.elasticsearch.action.search.SearchShardsResponse;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
@@ -965,7 +963,7 @@ public class CrossClusterAccessHeadersForCcsRestIT extends SecurityOnTrialLicens
         if (minimizeRoundtrips) {
             expectedActions.add(SearchAction.NAME);
         } else {
-            expectedActions.add(ClusterSearchShardsAction.NAME);
+            expectedActions.add(SearchShardsAction.NAME);
         }
         if (false == useProxyMode) {
             expectedActions.add(RemoteClusterNodesAction.NAME);
@@ -994,7 +992,7 @@ public class CrossClusterAccessHeadersForCcsRestIT extends SecurityOnTrialLicens
                     );
                     assertThat(actualCrossClusterAccessSubjectInfo, equalTo(expectedCrossClusterAccessSubjectInfo));
                 }
-                case SearchAction.NAME, ClusterSearchShardsAction.NAME -> {
+                case SearchAction.NAME, SearchShardsAction.NAME -> {
                     assertContainsHeadersExpectedForCrossClusterAccess(actual.headers());
                     assertContainsCrossClusterAccessCredentialsHeader(encodedCredential, actual);
                     final var actualCrossClusterAccessSubjectInfo = CrossClusterAccessSubjectInfo.decode(
@@ -1060,16 +1058,14 @@ public class CrossClusterAccessHeadersForCcsRestIT extends SecurityOnTrialLicens
                 }
             );
             service.registerRequestHandler(
-                ClusterSearchShardsAction.NAME,
+                SearchShardsAction.NAME,
                 ThreadPool.Names.SAME,
-                ClusterSearchShardsRequest::new,
+                SearchShardsRequest::new,
                 (request, channel, task) -> {
                     capturedHeaders.add(
                         new CapturedActionWithHeaders(task.getAction(), Map.copyOf(threadPool.getThreadContext().getHeaders()))
                     );
-                    channel.sendResponse(
-                        new ClusterSearchShardsResponse(new ClusterSearchShardsGroup[0], new DiscoveryNode[0], Collections.emptyMap())
-                    );
+                    channel.sendResponse(new SearchShardsResponse(List.of(), List.of(), Collections.emptyMap()));
                 }
             );
             service.registerRequestHandler(SearchAction.NAME, ThreadPool.Names.SAME, SearchRequest::new, (request, channel, task) -> {