Przeglądaj źródła

Add Cross Cluster Search support for scroll searches (#25094)

To complete the cross cluster search capabilities for all search types and
function this change adds cross cluster search support for scroll searches.
Simon Willnauer 8 lat temu
rodzic
commit
bc7ec68e76

+ 16 - 14
core/src/main/java/org/elasticsearch/action/search/ClearScrollController.java

@@ -103,21 +103,23 @@ final class ClearScrollController implements Runnable {
     }
 
     void cleanScrollIds(List<ScrollIdForNode> parsedScrollIds) {
-        for (ScrollIdForNode target : parsedScrollIds) {
-            final DiscoveryNode node = nodes.get(target.getNode());
-            if (node == null) {
-                onFreedContext(false);
-            } else {
-                try {
-                    Transport.Connection connection = searchTransportService.getConnection(null, node);
-                    searchTransportService.sendFreeContext(connection, target.getScrollId(),
-                        ActionListener.wrap(freed -> onFreedContext(freed.isFreed()),
-                            e -> onFailedFreedContext(e, node)));
-                } catch (Exception e) {
-                    onFailedFreedContext(e, node);
+        SearchScrollAsyncAction.collectNodesAndRun(parsedScrollIds, nodes, searchTransportService, ActionListener.wrap(
+            lookup -> {
+                for (ScrollIdForNode target : parsedScrollIds) {
+                    final DiscoveryNode node = lookup.apply(target.getClusterAlias(), target.getNode());
+                    if (node == null) {
+                        onFreedContext(false);
+                    } else {
+                        try {
+                            Transport.Connection connection = searchTransportService.getConnection(target.getClusterAlias(), node);
+                            searchTransportService.sendFreeContext(connection, target.getScrollId(),
+                                ActionListener.wrap(freed -> onFreedContext(freed.isFreed()), e -> onFailedFreedContext(e, node)));
+                        } catch (Exception e) {
+                            onFailedFreedContext(e, node);
+                        }
+                    }
                 }
-            }
-        }
+            }, listener::onFailure));
     }
 
     private void onFreedContext(boolean freed) {

+ 18 - 1
core/src/main/java/org/elasticsearch/action/search/ScrollIdForNode.java

@@ -19,12 +19,16 @@
 
 package org.elasticsearch.action.search;
 
+import org.elasticsearch.common.inject.internal.Nullable;
+
 class ScrollIdForNode {
     private final String node;
     private final long scrollId;
+    private final String clusterAlias;
 
-    ScrollIdForNode(String node, long scrollId) {
+    ScrollIdForNode(@Nullable String clusterAlias, String node, long scrollId) {
         this.node = node;
+        this.clusterAlias = clusterAlias;
         this.scrollId = scrollId;
     }
 
@@ -32,7 +36,20 @@ class ScrollIdForNode {
         return node;
     }
 
+    public String getClusterAlias() {
+        return clusterAlias;
+    }
+
     public long getScrollId() {
         return scrollId;
     }
+
+    @Override
+    public String toString() {
+        return "ScrollIdForNode{" +
+            "node='" + node + '\'' +
+            ", scrollId=" + scrollId +
+            ", clusterAlias='" + clusterAlias + '\'' +
+            '}';
+    }
 }

+ 107 - 47
core/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java

@@ -32,11 +32,17 @@ import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.internal.InternalScrollSearchRequest;
 import org.elasticsearch.search.internal.InternalSearchResponse;
+import org.elasticsearch.transport.RemoteClusterService;
+import org.elasticsearch.transport.Transport;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
 
 import static org.elasticsearch.action.search.TransportSearchHelper.internalScrollSearchRequest;
 
@@ -67,13 +73,15 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> implements R
     protected final DiscoveryNodes nodes;
     protected final SearchPhaseController searchPhaseController;
     protected final SearchScrollRequest request;
+    protected final SearchTransportService searchTransportService;
     private final long startTime;
     private final List<ShardSearchFailure> shardFailures = new ArrayList<>();
     private final AtomicInteger successfulOps;
 
     protected SearchScrollAsyncAction(ParsedScrollId scrollId, Logger logger, DiscoveryNodes nodes,
                                       ActionListener<SearchResponse> listener, SearchPhaseController searchPhaseController,
-                                      SearchScrollRequest request) {
+                                      SearchScrollRequest request,
+                                      SearchTransportService searchTransportService) {
         this.startTime = System.currentTimeMillis();
         this.scrollId = scrollId;
         this.successfulOps = new AtomicInteger(scrollId.getContext().length);
@@ -82,6 +90,7 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> implements R
         this.nodes = nodes;
         this.searchPhaseController = searchPhaseController;
         this.request = request;
+        this.searchTransportService = searchTransportService;
     }
 
     /**
@@ -97,57 +106,104 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> implements R
         final ScrollIdForNode[] context = scrollId.getContext();
         if (context.length == 0) {
             listener.onFailure(new SearchPhaseExecutionException("query", "no nodes to search on", ShardSearchFailure.EMPTY_ARRAY));
-            return;
+        } else {
+            collectNodesAndRun(Arrays.asList(context), nodes, searchTransportService, ActionListener.wrap(lookup -> run(lookup, context),
+                listener::onFailure));
         }
+    }
+
+    /**
+     * This method collects nodes from the remote clusters asynchronously if any of the scroll IDs references a remote cluster.
+     * Otherwise the action listener will be invoked immediately with a function based on the given discovery nodes.
+     */
+    static void collectNodesAndRun(final Iterable<ScrollIdForNode> scrollIds, DiscoveryNodes nodes,
+                                   SearchTransportService searchTransportService,
+                                   ActionListener<BiFunction<String, String, DiscoveryNode>> listener) {
+        Set<String> clusters = new HashSet<>();
+        for (ScrollIdForNode target : scrollIds) {
+            if (target.getClusterAlias() != null) {
+                clusters.add(target.getClusterAlias());
+            }
+        }
+        if (clusters.isEmpty()) { // no remote clusters
+            listener.onResponse((cluster, node) -> nodes.get(node));
+        } else {
+            RemoteClusterService remoteClusterService = searchTransportService.getRemoteClusterService();
+            remoteClusterService.collectNodes(clusters, ActionListener.wrap(nodeFunction -> {
+                final BiFunction<String, String, DiscoveryNode> clusterNodeLookup = (clusterAlias, node) -> {
+                    if (clusterAlias == null) {
+                        return nodes.get(node);
+                    } else {
+                        return nodeFunction.apply(clusterAlias, node);
+                    }
+                };
+                listener.onResponse(clusterNodeLookup);
+            }, listener::onFailure));
+        }
+    }
+
+    private void run(BiFunction<String, String, DiscoveryNode> clusterNodeLookup, final ScrollIdForNode[] context) {
         final CountDown counter = new CountDown(scrollId.getContext().length);
         for (int i = 0; i < context.length; i++) {
             ScrollIdForNode target = context[i];
-            DiscoveryNode node = nodes.get(target.getNode());
             final int shardIndex = i;
-            if (node != null) { // it might happen that a node is going down in-between scrolls...
-                InternalScrollSearchRequest internalRequest = internalScrollSearchRequest(target.getScrollId(), request);
-                // we can't create a SearchShardTarget here since we don't know the index and shard ID we are talking to
-                // we only know the node and the search context ID. Yet, the response will contain the SearchShardTarget
-                // from the target node instead...that's why we pass null here
-                SearchActionListener<T> searchActionListener = new SearchActionListener<T>(null, shardIndex) {
-
-                    @Override
-                    protected void setSearchShardTarget(T response) {
-                        // don't do this - it's part of the response...
-                        assert response.getSearchShardTarget() != null : "search shard target must not be null";
+            final Transport.Connection connection;
+            try {
+                DiscoveryNode node = clusterNodeLookup.apply(target.getClusterAlias(), target.getNode());
+                if (node == null) {
+                    throw  new IllegalStateException("node [" + target.getNode() + "] is not available");
+                }
+                connection = getConnection(target.getClusterAlias(), node);
+            } catch (Exception ex) {
+                onShardFailure("query", counter, target.getScrollId(),
+                    ex, null, () -> SearchScrollAsyncAction.this.moveToNextPhase(clusterNodeLookup));
+                continue;
+            }
+            final InternalScrollSearchRequest internalRequest = internalScrollSearchRequest(target.getScrollId(), request);
+            // we can't create a SearchShardTarget here since we don't know the index and shard ID we are talking to
+            // we only know the node and the search context ID. Yet, the response will contain the SearchShardTarget
+            // from the target node instead...that's why we pass null here
+            SearchActionListener<T> searchActionListener = new SearchActionListener<T>(null, shardIndex) {
+
+                @Override
+                protected void setSearchShardTarget(T response) {
+                    // don't do this - it's part of the response...
+                    assert response.getSearchShardTarget() != null : "search shard target must not be null";
+                    if (target.getClusterAlias() != null) {
+                        // re-create the search target and add the cluster alias if there is any,
+                        // we need this down the road for subseq. phases
+                        SearchShardTarget searchShardTarget = response.getSearchShardTarget();
+                        response.setSearchShardTarget(new SearchShardTarget(searchShardTarget.getNodeId(), searchShardTarget.getShardId(),
+                            target.getClusterAlias(), null));
                     }
+                }
 
-                    @Override
-                    protected void innerOnResponse(T result) {
-                        assert shardIndex == result.getShardIndex() : "shard index mismatch: " + shardIndex + " but got: "
-                            + result.getShardIndex();
-                        onFirstPhaseResult(shardIndex, result);
-                        if (counter.countDown()) {
-                            SearchPhase phase = moveToNextPhase();
-                            try {
-                                phase.run();
-                            } catch (Exception e) {
-                                // we need to fail the entire request here - the entire phase just blew up
-                                // don't call onShardFailure or onFailure here since otherwise we'd countDown the counter
-                                // again which would result in an exception
-                                listener.onFailure(new SearchPhaseExecutionException(phase.getName(), "Phase failed", e,
-                                    ShardSearchFailure.EMPTY_ARRAY));
-                            }
+                @Override
+                protected void innerOnResponse(T result) {
+                    assert shardIndex == result.getShardIndex() : "shard index mismatch: " + shardIndex + " but got: "
+                        + result.getShardIndex();
+                    onFirstPhaseResult(shardIndex, result);
+                    if (counter.countDown()) {
+                        SearchPhase phase = moveToNextPhase(clusterNodeLookup);
+                        try {
+                            phase.run();
+                        } catch (Exception e) {
+                            // we need to fail the entire request here - the entire phase just blew up
+                            // don't call onShardFailure or onFailure here since otherwise we'd countDown the counter
+                            // again which would result in an exception
+                            listener.onFailure(new SearchPhaseExecutionException(phase.getName(), "Phase failed", e,
+                                ShardSearchFailure.EMPTY_ARRAY));
                         }
                     }
+                }
 
-                    @Override
-                    public void onFailure(Exception t) {
-                        onShardFailure("query", shardIndex, counter, target.getScrollId(), t, null,
-                            SearchScrollAsyncAction.this::moveToNextPhase);
-                    }
-                };
-                executeInitialPhase(node, internalRequest, searchActionListener);
-            } else { // the node is not available we treat this as a shard failure here
-                onShardFailure("query", shardIndex, counter, target.getScrollId(),
-                    new IllegalStateException("node [" + target.getNode() + "] is not available"), null,
-                    SearchScrollAsyncAction.this::moveToNextPhase);
-            }
+                @Override
+                public void onFailure(Exception t) {
+                    onShardFailure("query", counter, target.getScrollId(), t, null,
+                        () -> SearchScrollAsyncAction.this.moveToNextPhase(clusterNodeLookup));
+                }
+            };
+            executeInitialPhase(connection, internalRequest, searchActionListener);
         }
     }
 
@@ -164,10 +220,10 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> implements R
         shardFailures.add(failure);
     }
 
-    protected abstract void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+    protected abstract void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                                 SearchActionListener<T> searchActionListener);
 
-    protected abstract SearchPhase moveToNextPhase();
+    protected abstract SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup);
 
     protected abstract void onFirstPhaseResult(int shardId, T result);
 
@@ -199,9 +255,9 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> implements R
         }
     }
 
-    protected void onShardFailure(String phaseName, final int shardIndex, final CountDown counter, final long searchId, Exception failure,
-                                @Nullable SearchShardTarget searchShardTarget,
-                                Supplier<SearchPhase> nextPhaseSupplier) {
+    protected void onShardFailure(String phaseName, final CountDown counter, final long searchId, Exception failure,
+                                  @Nullable SearchShardTarget searchShardTarget,
+                                  Supplier<SearchPhase> nextPhaseSupplier) {
         if (logger.isDebugEnabled()) {
             logger.debug((Supplier<?>) () -> new ParameterizedMessage("[{}] Failed to execute {} phase", searchId, phaseName), failure);
         }
@@ -223,4 +279,8 @@ abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> implements R
             }
         }
     }
+
+    protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
+        return searchTransportService.getConnection(clusterAlias, node);
+    }
 }

+ 6 - 18
core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryAndFetchAsyncAction.java

@@ -20,50 +20,38 @@
 package org.elasticsearch.action.search;
 
 import org.apache.logging.log4j.Logger;
-import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.apache.logging.log4j.util.Supplier;
-import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
-import org.elasticsearch.common.util.concurrent.CountDown;
-import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.fetch.QueryFetchSearchResult;
 import org.elasticsearch.search.fetch.ScrollQueryFetchSearchResult;
 import org.elasticsearch.search.internal.InternalScrollSearchRequest;
-import org.elasticsearch.search.internal.InternalSearchResponse;
-import org.elasticsearch.search.query.ScrollQuerySearchResult;
+import org.elasticsearch.transport.Transport;
 
-import java.util.List;
-import java.util.concurrent.atomic.AtomicInteger;
-
-import static org.elasticsearch.action.search.TransportSearchHelper.internalScrollSearchRequest;
+import java.util.function.BiFunction;
 
 final class SearchScrollQueryAndFetchAsyncAction extends SearchScrollAsyncAction<ScrollQueryFetchSearchResult> {
 
-    private final SearchTransportService searchTransportService;
     private final SearchTask task;
     private final AtomicArray<QueryFetchSearchResult> queryFetchResults;
 
     SearchScrollQueryAndFetchAsyncAction(Logger logger, ClusterService clusterService, SearchTransportService searchTransportService,
                                          SearchPhaseController searchPhaseController, SearchScrollRequest request, SearchTask task,
                                          ParsedScrollId scrollId, ActionListener<SearchResponse> listener) {
-        super(scrollId, logger, clusterService.state().nodes(), listener, searchPhaseController, request);
+        super(scrollId, logger, clusterService.state().nodes(), listener, searchPhaseController, request, searchTransportService);
         this.task = task;
-        this.searchTransportService = searchTransportService;
         this.queryFetchResults = new AtomicArray<>(scrollId.getContext().length);
     }
 
     @Override
-    protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+    protected void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                        SearchActionListener<ScrollQueryFetchSearchResult> searchActionListener) {
-        searchTransportService.sendExecuteScrollFetch(node, internalRequest, task, searchActionListener);
+        searchTransportService.sendExecuteScrollFetch(connection, internalRequest, task, searchActionListener);
     }
 
     @Override
-    protected SearchPhase moveToNextPhase() {
+    protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
         return sendResponsePhase(searchPhaseController.reducedQueryPhase(queryFetchResults.asList(), true), queryFetchResults);
     }
 

+ 14 - 11
core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java

@@ -27,28 +27,28 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.util.concurrent.CountDown;
+import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.ShardFetchRequest;
 import org.elasticsearch.search.internal.InternalScrollSearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.query.ScrollQuerySearchResult;
+import org.elasticsearch.transport.Transport;
 
 import java.io.IOException;
-
-import static org.elasticsearch.action.search.TransportSearchHelper.internalScrollSearchRequest;
+import java.util.function.BiFunction;
 
 final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncAction<ScrollQuerySearchResult> {
 
     private final SearchTask task;
-    private final SearchTransportService searchTransportService;
     private final AtomicArray<FetchSearchResult> fetchResults;
     private final AtomicArray<QuerySearchResult> queryResults;
 
     SearchScrollQueryThenFetchAsyncAction(Logger logger, ClusterService clusterService, SearchTransportService searchTransportService,
                                           SearchPhaseController searchPhaseController, SearchScrollRequest request, SearchTask task,
                                           ParsedScrollId scrollId, ActionListener<SearchResponse> listener) {
-        super(scrollId, logger, clusterService.state().nodes(), listener, searchPhaseController, request);
-        this.searchTransportService = searchTransportService;
+        super(scrollId, logger, clusterService.state().nodes(), listener, searchPhaseController, request,
+            searchTransportService);
         this.task = task;
         this.fetchResults = new AtomicArray<>(scrollId.getContext().length);
         this.queryResults = new AtomicArray<>(scrollId.getContext().length);
@@ -59,13 +59,13 @@ final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncActio
     }
 
     @Override
-    protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+    protected void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                        SearchActionListener<ScrollQuerySearchResult> searchActionListener) {
-        searchTransportService.sendExecuteScrollQuery(node, internalRequest, task, searchActionListener);
+        searchTransportService.sendExecuteScrollQuery(connection, internalRequest, task, searchActionListener);
     }
 
     @Override
-    protected SearchPhase moveToNextPhase() {
+    protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
         return new SearchPhase("fetch") {
             @Override
             public void run() throws IOException {
@@ -89,8 +89,11 @@ final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncActio
                         ScoreDoc lastEmittedDoc = lastEmittedDocPerShard[index];
                         ShardFetchRequest shardFetchRequest = new ShardFetchRequest(querySearchResult.getRequestId(), docIds,
                             lastEmittedDoc);
-                        DiscoveryNode node = nodes.get(querySearchResult.getSearchShardTarget().getNodeId());
-                        searchTransportService.sendExecuteFetchScroll(node, shardFetchRequest, task,
+                        SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
+                        DiscoveryNode node = clusterNodeLookup.apply(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
+                        assert node != null : "target node is null in secondary phase";
+                        Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), node);
+                        searchTransportService.sendExecuteFetchScroll(connection, shardFetchRequest, task,
                             new SearchActionListener<FetchSearchResult>(querySearchResult.getSearchShardTarget(), index) {
                                 @Override
                                 protected void innerOnResponse(FetchSearchResult response) {
@@ -102,7 +105,7 @@ final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncActio
 
                                 @Override
                                 public void onFailure(Exception t) {
-                                    onShardFailure(getName(), querySearchResult.getShardIndex(), counter, querySearchResult.getRequestId(),
+                                    onShardFailure(getName(), counter, querySearchResult.getRequestId(),
                                         t, querySearchResult.getSearchShardTarget(),
                                         () -> sendResponsePhase(reducedQueryPhase, fetchResults));
                                 }

+ 6 - 6
core/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

@@ -145,15 +145,15 @@ public class SearchTransportService extends AbstractComponent {
             new ActionListenerResponseHandler<>(listener, QuerySearchResult::new));
     }
 
-    public void sendExecuteScrollQuery(DiscoveryNode node, final InternalScrollSearchRequest request, SearchTask task,
+    public void sendExecuteScrollQuery(Transport.Connection connection, final InternalScrollSearchRequest request, SearchTask task,
                                        final SearchActionListener<ScrollQuerySearchResult> listener) {
-        transportService.sendChildRequest(transportService.getConnection(node), QUERY_SCROLL_ACTION_NAME, request, task,
+        transportService.sendChildRequest(connection, QUERY_SCROLL_ACTION_NAME, request, task,
             new ActionListenerResponseHandler<>(listener, ScrollQuerySearchResult::new));
     }
 
-    public void sendExecuteScrollFetch(DiscoveryNode node, final InternalScrollSearchRequest request, SearchTask task,
+    public void sendExecuteScrollFetch(Transport.Connection connection, final InternalScrollSearchRequest request, SearchTask task,
                                        final SearchActionListener<ScrollQueryFetchSearchResult> listener) {
-        transportService.sendChildRequest(transportService.getConnection(node), QUERY_FETCH_SCROLL_ACTION_NAME, request, task,
+        transportService.sendChildRequest(connection, QUERY_FETCH_SCROLL_ACTION_NAME, request, task,
             new ActionListenerResponseHandler<>(listener, ScrollQueryFetchSearchResult::new));
     }
 
@@ -162,9 +162,9 @@ public class SearchTransportService extends AbstractComponent {
         sendExecuteFetch(connection, FETCH_ID_ACTION_NAME, request, task, listener);
     }
 
-    public void sendExecuteFetchScroll(DiscoveryNode node, final ShardFetchRequest request, SearchTask task,
+    public void sendExecuteFetchScroll(Transport.Connection connection, final ShardFetchRequest request, SearchTask task,
                                        final SearchActionListener<FetchSearchResult> listener) {
-        sendExecuteFetch(transportService.getConnection(node), FETCH_ID_SCROLL_ACTION_NAME, request, task, listener);
+        sendExecuteFetch(connection, FETCH_ID_SCROLL_ACTION_NAME, request, task, listener);
     }
 
     private void sendExecuteFetch(Transport.Connection connection, String action, final ShardFetchRequest request, SearchTask task,

+ 18 - 2
core/src/main/java/org/elasticsearch/action/search/TransportSearchHelper.java

@@ -23,7 +23,9 @@ import org.apache.lucene.store.ByteArrayDataInput;
 import org.apache.lucene.store.RAMOutputStream;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.internal.InternalScrollSearchRequest;
+import org.elasticsearch.transport.RemoteClusterAware;
 
 import java.io.IOException;
 import java.util.Base64;
@@ -40,7 +42,13 @@ final class TransportSearchHelper {
             out.writeVInt(searchPhaseResults.asList().size());
             for (SearchPhaseResult searchPhaseResult : searchPhaseResults.asList()) {
                 out.writeLong(searchPhaseResult.getRequestId());
-                out.writeString(searchPhaseResult.getSearchShardTarget().getNodeId());
+                SearchShardTarget searchShardTarget = searchPhaseResult.getSearchShardTarget();
+                if (searchShardTarget.getClusterAlias() != null) {
+                    out.writeString(RemoteClusterAware.buildRemoteIndexName(searchShardTarget.getClusterAlias(),
+                        searchShardTarget.getNodeId()));
+                } else {
+                    out.writeString(searchShardTarget.getNodeId());
+                }
             }
             byte[] bytes = new byte[(int) out.getFilePointer()];
             out.writeTo(bytes, 0);
@@ -57,7 +65,15 @@ final class TransportSearchHelper {
             for (int i = 0; i < context.length; ++i) {
                 long id = in.readLong();
                 String target = in.readString();
-                context[i] = new ScrollIdForNode(target, id);
+                String clusterAlias;
+                final int index = target.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR);
+                if (index == -1) {
+                    clusterAlias = null;
+                } else {
+                    clusterAlias = target.substring(0, index);
+                    target = target.substring(index+1);
+                }
+                context[i] = new ScrollIdForNode(clusterAlias, target, id);
             }
             if (in.getPosition() != bytes.length) {
                 throw new IllegalArgumentException("Not all bytes were read");

+ 101 - 49
core/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java

@@ -29,6 +29,7 @@ import org.elasticsearch.search.Scroll;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.internal.InternalScrollSearchRequest;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.transport.Transport;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -36,17 +37,18 @@ import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiFunction;
 
 public class SearchScrollAsyncActionTests extends ESTestCase {
 
     public void testSendRequestsToNodes() throws InterruptedException {
 
         ParsedScrollId scrollId = getParsedScrollId(
-            new ScrollIdForNode("node1", 1),
-            new ScrollIdForNode("node2", 2),
-            new ScrollIdForNode("node3", 17),
-            new ScrollIdForNode("node1", 0),
-            new ScrollIdForNode("node3", 0));
+            new ScrollIdForNode(null, "node1", 1),
+            new ScrollIdForNode(null, "node2", 2),
+            new ScrollIdForNode(null, "node3", 17),
+            new ScrollIdForNode(null, "node1", 0),
+            new ScrollIdForNode(null, "node3", 0));
         DiscoveryNodes discoveryNodes = DiscoveryNodes.builder()
             .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT))
             .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT))
@@ -58,22 +60,29 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
         CountDownLatch latch = new CountDownLatch(1);
         AtomicInteger movedCounter = new AtomicInteger(0);
         SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult> action =
-            new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, null, null, request)
+            new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, dummyListener(),
+                null, request, null)
             {
                 @Override
-                protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+                protected void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                                    SearchActionListener<SearchAsyncActionTests.TestSearchPhaseResult> searchActionListener)
                 {
                     new Thread(() -> {
                         SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult =
-                            new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node);
-                        testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1));
+                            new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), connection.getNode());
+                        testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(connection.getNode().getId(),
+                            new Index("test", "_na_"), 1));
                         searchActionListener.onResponse(testSearchPhaseResult);
                     }).start();
                 }
 
                 @Override
-                protected SearchPhase moveToNextPhase() {
+                protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
+                    return new SearchAsyncActionTests.MockConnection(node);
+                }
+
+                @Override
+                protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
                     assertEquals(1, movedCounter.incrementAndGet());
                     return new SearchPhase("test") {
                         @Override
@@ -104,11 +113,11 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
     public void testFailNextPhase() throws InterruptedException {
 
         ParsedScrollId scrollId = getParsedScrollId(
-            new ScrollIdForNode("node1", 1),
-            new ScrollIdForNode("node2", 2),
-            new ScrollIdForNode("node3", 17),
-            new ScrollIdForNode("node1", 0),
-            new ScrollIdForNode("node3", 0));
+            new ScrollIdForNode(null, "node1", 1),
+            new ScrollIdForNode(null, "node2", 2),
+            new ScrollIdForNode(null, "node3", 17),
+            new ScrollIdForNode(null, "node1", 0),
+            new ScrollIdForNode(null, "node3", 0));
         DiscoveryNodes discoveryNodes = DiscoveryNodes.builder()
             .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT))
             .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT))
@@ -144,21 +153,27 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
         };
         SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult> action =
             new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, listener, null,
-                request) {
+                request, null) {
                 @Override
-                protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+                protected void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                                    SearchActionListener<SearchAsyncActionTests.TestSearchPhaseResult> searchActionListener)
                 {
                     new Thread(() -> {
                         SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult =
-                            new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node);
-                        testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1));
+                            new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), connection.getNode());
+                        testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(connection.getNode().getId(),
+                            new Index("test", "_na_"), 1));
                         searchActionListener.onResponse(testSearchPhaseResult);
                     }).start();
                 }
 
                 @Override
-                protected SearchPhase moveToNextPhase() {
+                protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
+                    return new SearchAsyncActionTests.MockConnection(node);
+                }
+
+                @Override
+                protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
                     assertEquals(1, movedCounter.incrementAndGet());
                     return new SearchPhase("TEST_PHASE") {
                         @Override
@@ -188,11 +203,11 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
 
     public void testNodeNotAvailable() throws InterruptedException {
         ParsedScrollId scrollId = getParsedScrollId(
-            new ScrollIdForNode("node1", 1),
-            new ScrollIdForNode("node2", 2),
-            new ScrollIdForNode("node3", 17),
-            new ScrollIdForNode("node1", 0),
-            new ScrollIdForNode("node3", 0));
+            new ScrollIdForNode(null, "node1", 1),
+            new ScrollIdForNode(null, "node2", 2),
+            new ScrollIdForNode(null, "node3", 17),
+            new ScrollIdForNode(null, "node1", 0),
+            new ScrollIdForNode(null, "node3", 0));
         // node2 is not available
         DiscoveryNodes discoveryNodes = DiscoveryNodes.builder()
             .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT))
@@ -204,23 +219,34 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
         CountDownLatch latch = new CountDownLatch(1);
         AtomicInteger movedCounter = new AtomicInteger(0);
         SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult> action =
-            new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, null, null, request)
+            new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, dummyListener()
+                , null, request, null)
             {
                 @Override
-                protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+                protected void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                                    SearchActionListener<SearchAsyncActionTests.TestSearchPhaseResult> searchActionListener)
                 {
-                    assertNotEquals("node2 is not available", "node2", node.getId());
+                    try {
+                        assertNotEquals("node2 is not available", "node2", connection.getNode().getId());
+                    } catch (NullPointerException e) {
+                        logger.warn(e);
+                    }
                     new Thread(() -> {
                         SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult =
-                            new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node);
-                        testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1));
+                            new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), connection.getNode());
+                        testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(connection.getNode().getId(),
+                            new Index("test", "_na_"), 1));
                         searchActionListener.onResponse(testSearchPhaseResult);
                     }).start();
                 }
 
                 @Override
-                protected SearchPhase moveToNextPhase() {
+                protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
+                    return new SearchAsyncActionTests.MockConnection(node);
+                }
+
+                @Override
+                protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
                     assertEquals(1, movedCounter.incrementAndGet());
                     return new SearchPhase("test") {
                         @Override
@@ -256,11 +282,11 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
 
     public void testShardFailures() throws InterruptedException {
         ParsedScrollId scrollId = getParsedScrollId(
-            new ScrollIdForNode("node1", 1),
-            new ScrollIdForNode("node2", 2),
-            new ScrollIdForNode("node3", 17),
-            new ScrollIdForNode("node1", 0),
-            new ScrollIdForNode("node3", 0));
+            new ScrollIdForNode(null, "node1", 1),
+            new ScrollIdForNode(null, "node2", 2),
+            new ScrollIdForNode(null, "node3", 17),
+            new ScrollIdForNode(null, "node1", 0),
+            new ScrollIdForNode(null, "node3", 0));
         DiscoveryNodes discoveryNodes = DiscoveryNodes.builder()
             .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT))
             .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT))
@@ -272,10 +298,11 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
         CountDownLatch latch = new CountDownLatch(1);
         AtomicInteger movedCounter = new AtomicInteger(0);
         SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult> action =
-            new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, null, null, request)
+            new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, dummyListener(),
+                null, request, null)
             {
                 @Override
-                protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+                protected void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                                    SearchActionListener<SearchAsyncActionTests.TestSearchPhaseResult> searchActionListener)
                 {
                     new Thread(() -> {
@@ -283,15 +310,21 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
                             searchActionListener.onFailure(new IllegalArgumentException("BOOM on shard"));
                         } else {
                             SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult =
-                                new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node);
-                            testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1));
+                                new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), connection.getNode());
+                            testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(connection.getNode().getId(),
+                                new Index("test", "_na_"), 1));
                             searchActionListener.onResponse(testSearchPhaseResult);
                         }
                     }).start();
                 }
 
                 @Override
-                protected SearchPhase moveToNextPhase() {
+                protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
+                    return new SearchAsyncActionTests.MockConnection(node);
+                }
+
+                @Override
+                protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
                     assertEquals(1, movedCounter.incrementAndGet());
                     return new SearchPhase("test") {
                         @Override
@@ -327,11 +360,11 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
 
     public void testAllShardsFailed() throws InterruptedException {
         ParsedScrollId scrollId = getParsedScrollId(
-            new ScrollIdForNode("node1", 1),
-            new ScrollIdForNode("node2", 2),
-            new ScrollIdForNode("node3", 17),
-            new ScrollIdForNode("node1", 0),
-            new ScrollIdForNode("node3", 0));
+            new ScrollIdForNode(null, "node1", 1),
+            new ScrollIdForNode(null, "node2", 2),
+            new ScrollIdForNode(null, "node3", 17),
+            new ScrollIdForNode(null, "node1", 0),
+            new ScrollIdForNode(null, "node3", 0));
         DiscoveryNodes discoveryNodes = DiscoveryNodes.builder()
             .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT))
             .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT))
@@ -366,16 +399,21 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
         };
         SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult> action =
             new SearchScrollAsyncAction<SearchAsyncActionTests.TestSearchPhaseResult>(scrollId, logger, discoveryNodes, listener, null,
-                request) {
+                request, null) {
                 @Override
-                protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
+                protected void executeInitialPhase(Transport.Connection connection, InternalScrollSearchRequest internalRequest,
                                                    SearchActionListener<SearchAsyncActionTests.TestSearchPhaseResult> searchActionListener)
                 {
                     new Thread(() -> searchActionListener.onFailure(new IllegalArgumentException("BOOM on shard"))).start();
                 }
 
                 @Override
-                protected SearchPhase moveToNextPhase() {
+                protected Transport.Connection getConnection(String clusterAlias, DiscoveryNode node) {
+                    return new SearchAsyncActionTests.MockConnection(node);
+                }
+
+                @Override
+                protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
                    fail("don't move all shards failed");
                    return null;
                 }
@@ -404,4 +442,18 @@ public class SearchScrollAsyncActionTests extends ESTestCase {
         Collections.shuffle(scrollIdForNodes, random());
         return new ParsedScrollId("", "test", scrollIdForNodes.toArray(new ScrollIdForNode[0]));
     }
+
+    private ActionListener<SearchResponse> dummyListener() {
+        return new ActionListener<SearchResponse>() {
+            @Override
+            public void onResponse(SearchResponse response) {
+                fail("dummy");
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                throw new AssertionError(e);
+            }
+        };
+    }
 }

+ 64 - 0
core/src/test/java/org/elasticsearch/action/search/TransportSearchHelperTests.java

@@ -0,0 +1,64 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.action.search;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+public class TransportSearchHelperTests extends ESTestCase {
+
+    public void testParseScrollId() throws IOException {
+        AtomicArray<SearchPhaseResult> array = new AtomicArray<>(3);
+        DiscoveryNode node1 = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
+        DiscoveryNode node2 = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);
+        DiscoveryNode node3 = new DiscoveryNode("node_3", buildNewFakeTransportAddress(), Version.CURRENT);
+        SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult1 = new SearchAsyncActionTests.TestSearchPhaseResult(1, node1);
+        testSearchPhaseResult1.setSearchShardTarget(new SearchShardTarget("node_1", new ShardId("idx", "uuid1", 2), "cluster_x", null));
+        SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult2 = new SearchAsyncActionTests.TestSearchPhaseResult(12, node2);
+        testSearchPhaseResult2.setSearchShardTarget(new SearchShardTarget("node_2", new ShardId("idy", "uuid2", 42), "cluster_y", null));
+        SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult3 = new SearchAsyncActionTests.TestSearchPhaseResult(42, node3);
+        testSearchPhaseResult3.setSearchShardTarget(new SearchShardTarget("node_3", new ShardId("idy", "uuid2", 43), null, null));
+        array.setOnce(0, testSearchPhaseResult1);
+        array.setOnce(1, testSearchPhaseResult2);
+        array.setOnce(2, testSearchPhaseResult3);
+
+
+        String scrollId = TransportSearchHelper.buildScrollId(array);
+        ParsedScrollId parseScrollId = TransportSearchHelper.parseScrollId(scrollId);
+        assertEquals(3, parseScrollId.getContext().length);
+        assertEquals("node_1", parseScrollId.getContext()[0].getNode());
+        assertEquals("cluster_x", parseScrollId.getContext()[0].getClusterAlias());
+        assertEquals(1, parseScrollId.getContext()[0].getScrollId());
+
+        assertEquals("node_2", parseScrollId.getContext()[1].getNode());
+        assertEquals("cluster_y", parseScrollId.getContext()[1].getClusterAlias());
+        assertEquals(12, parseScrollId.getContext()[1].getScrollId());
+
+        assertEquals("node_3", parseScrollId.getContext()[2].getNode());
+        assertNull(parseScrollId.getContext()[2].getClusterAlias());
+        assertEquals(42, parseScrollId.getContext()[2].getScrollId());
+    }
+}

+ 40 - 0
qa/multi-cluster-search/src/test/resources/rest-api-spec/test/multi_cluster/40_scroll.yml

@@ -0,0 +1,40 @@
+---
+"Scroll on the mixed cluster":
+
+  - do:
+      search:
+        index: my_remote_cluster:test_index
+        size: 4
+        scroll: 1m
+        sort: filter_field
+        body:
+          query:
+            match_all: {}
+
+  - set: {_scroll_id: scroll_id}
+  - match: {hits.total:      6    }
+  - length: {hits.hits:      4    }
+  - match: {hits.hits.0._source.filter_field: 0 }
+  - match: {hits.hits.1._source.filter_field: 0 }
+  - match: {hits.hits.2._source.filter_field: 0 }
+  - match: {hits.hits.3._source.filter_field: 0 }
+
+  - do:
+      scroll:
+        body: { "scroll_id": "$scroll_id", "scroll": "1m"}
+
+  - match: {hits.total:      6    }
+  - length: {hits.hits:      2    }
+  - match: {hits.hits.0._source.filter_field: 1 }
+  - match: {hits.hits.1._source.filter_field: 1 }
+  - do:
+      scroll:
+        scroll_id: $scroll_id
+        scroll: 1m
+
+  - match: {hits.total:      6    }
+  - length: {hits.hits:      0    }
+
+  - do:
+      clear_scroll:
+        scroll_id: $scroll_id