Răsfoiți Sursa

Support canceling cross-clusters search requests (#66206)

This commit supports canceling cross-clusters search requests. Several
important changes in this commit:

- Set the parent task for CCS search requests
- Keep track of underlying connections instead of proxy connections
- Assign the parent task for proxy requests
Nhat Nguyen 4 ani în urmă
părinte
comite
273ac15af2

+ 11 - 7
server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java

@@ -67,6 +67,7 @@ import org.elasticsearch.search.aggregations.support.ValueType;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.fetch.FetchSubPhase;
 import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESIntegTestCase;
 
 import java.io.IOException;
@@ -131,9 +132,10 @@ public class TransportSearchIT extends ESIntegTestCase {
         indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
         IndexResponse indexResponse = client().index(indexRequest).actionGet();
         assertEquals(RestStatus.CREATED, indexResponse.status());
+        TaskId parentTaskId = new TaskId("node", randomNonNegativeLong());
 
         {
-            SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
+            SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY,
                 "local", nowInMillis, randomBoolean());
             SearchResponse searchResponse = client().search(searchRequest).actionGet();
             assertEquals(1, searchResponse.getHits().getTotalHits().value);
@@ -145,7 +147,7 @@ public class TransportSearchIT extends ESIntegTestCase {
             assertEquals("1", hit.getId());
         }
         {
-            SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
+            SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY,
                 "", nowInMillis, randomBoolean());
             SearchResponse searchResponse = client().search(searchRequest).actionGet();
             assertEquals(1, searchResponse.getHits().getTotalHits().value);
@@ -159,6 +161,7 @@ public class TransportSearchIT extends ESIntegTestCase {
     }
 
     public void testAbsoluteStartMillis() {
+        TaskId parentTaskId = new TaskId("node", randomNonNegativeLong());
         {
             IndexRequest indexRequest = new IndexRequest("test-1970.01.01");
             indexRequest.id("1");
@@ -187,13 +190,13 @@ public class TransportSearchIT extends ESIntegTestCase {
             assertEquals(0, searchResponse.getTotalShards());
         }
         {
-            SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
+            SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
                 Strings.EMPTY_ARRAY, "", 0, randomBoolean());
             SearchResponse searchResponse = client().search(searchRequest).actionGet();
             assertEquals(2, searchResponse.getHits().getTotalHits().value);
         }
         {
-            SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
+            SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
                 Strings.EMPTY_ARRAY, "", 0, randomBoolean());
             searchRequest.indices("<test-{now/d}>");
             SearchResponse searchResponse = client().search(searchRequest).actionGet();
@@ -201,7 +204,7 @@ public class TransportSearchIT extends ESIntegTestCase {
             assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex());
         }
         {
-            SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
+            SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
                 Strings.EMPTY_ARRAY, "", 0, randomBoolean());
             SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
             RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date");
@@ -217,6 +220,7 @@ public class TransportSearchIT extends ESIntegTestCase {
 
     public void testFinalReduce()  {
         long nowInMillis = randomLongBetween(0, Long.MAX_VALUE);
+        TaskId taskId = new TaskId("node", randomNonNegativeLong());
         {
             IndexRequest indexRequest = new IndexRequest("test");
             indexRequest.id("1");
@@ -243,7 +247,7 @@ public class TransportSearchIT extends ESIntegTestCase {
         source.aggregation(terms);
 
         {
-            SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest,
+            SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(taskId, originalRequest,
                 Strings.EMPTY_ARRAY, "remote", nowInMillis, true);
             SearchResponse searchResponse = client().search(searchRequest).actionGet();
             assertEquals(2, searchResponse.getHits().getTotalHits().value);
@@ -252,7 +256,7 @@ public class TransportSearchIT extends ESIntegTestCase {
             assertEquals(1, longTerms.getBuckets().size());
         }
         {
-            SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest,
+            SearchRequest searchRequest = SearchRequest.subSearchRequest(taskId, originalRequest,
                 Strings.EMPTY_ARRAY, "remote", nowInMillis, false);
             SearchResponse searchResponse = client().search(searchRequest).actionGet();
             assertEquals(2, searchResponse.getHits().getTotalHits().value);

+ 71 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java

@@ -19,6 +19,10 @@
 
 package org.elasticsearch.search.ccs;
 
+import org.elasticsearch.action.ActionFuture;
+import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
+import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
+import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.PlainActionFuture;
@@ -27,6 +31,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
@@ -36,11 +41,13 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.TaskInfo;
 import org.elasticsearch.test.AbstractMultiClustersTestCase;
 import org.elasticsearch.test.InternalTestCluster;
 import org.elasticsearch.test.NodeRoles;
 import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
 import org.elasticsearch.transport.TransportService;
+import org.hamcrest.Matchers;
 import org.junit.Before;
 
 import java.util.Collection;
@@ -146,6 +153,70 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase {
         }
     }
 
+    public void testCancel() throws Exception {
+        assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo"));
+        indexDocs(client(LOCAL_CLUSTER), "demo");
+        final InternalTestCluster remoteCluster = cluster("cluster_a");
+        remoteCluster.ensureAtLeastNumDataNodes(1);
+        final Settings.Builder allocationFilter = Settings.builder();
+        if (randomBoolean()) {
+            remoteCluster.ensureAtLeastNumDataNodes(3);
+            List<String> remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false)
+                .filter(DiscoveryNode::isDataNode)
+                .map(DiscoveryNode::getName)
+                .collect(Collectors.toList());
+            assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(3));
+            List<String> seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes);
+            disconnectFromRemoteClusters();
+            configureRemoteCluster("cluster_a", seedNodes);
+            if (randomBoolean()) {
+                // Using proxy connections
+                allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes));
+            } else {
+                allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes));
+            }
+        }
+        assertAcked(client("cluster_a").admin().indices().prepareCreate("prod")
+            .setSettings(Settings.builder().put(allocationFilter.build()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)));
+        assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod")
+            .setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut());
+        indexDocs(client("cluster_a"), "prod");
+        SearchListenerPlugin.blockQueryPhase();
+        PlainActionFuture<SearchResponse> queryFuture = new PlainActionFuture<>();
+        SearchRequest searchRequest = new SearchRequest("demo", "cluster_a:prod");
+        searchRequest.allowPartialSearchResults(false);
+        searchRequest.setCcsMinimizeRoundtrips(false);
+        searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1000));
+        client(LOCAL_CLUSTER).search(searchRequest, queryFuture);
+        SearchListenerPlugin.waitSearchStarted();
+        // Get the search task and cancelled
+        final TaskInfo rootTask = client().admin().cluster().prepareListTasks()
+            .setActions(SearchAction.INSTANCE.name())
+            .get().getTasks().stream().filter(t -> t.getParentTaskId().isSet() == false)
+            .findFirst().get();
+        final CancelTasksRequest cancelRequest = new CancelTasksRequest().setTaskId(rootTask.getTaskId());
+        cancelRequest.setWaitForCompletion(randomBoolean());
+        final ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().cancelTasks(cancelRequest);
+        assertBusy(() -> {
+            final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
+            for (TransportService transportService : transportServices) {
+                Collection<CancellableTask> cancellableTasks = transportService.getTaskManager().getCancellableTasks().values();
+                for (CancellableTask cancellableTask : cancellableTasks) {
+                    assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled());
+                }
+            }
+        });
+        SearchListenerPlugin.allowQueryPhase();
+        assertBusy(() -> assertTrue(queryFuture.isDone()));
+        assertBusy(() -> assertTrue(cancelFuture.isDone()));
+        assertBusy(() -> {
+            final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
+            for (TransportService transportService : transportServices) {
+                assertThat(transportService.getTaskManager().getBannedTaskIds(), Matchers.empty());
+            }
+        });
+    }
+
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins(String clusterAlias) {
         if (clusterAlias.equals(LOCAL_CLUSTER)) {

+ 7 - 3
server/src/main/java/org/elasticsearch/action/search/SearchRequest.java

@@ -140,21 +140,25 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
      * Used when a {@link SearchRequest} is created and executed as part of a cross-cluster search request
      * performing reduction on each cluster in order to minimize network round-trips between the coordinating node and the remote clusters.
      *
+     * @param parentTaskId the parent taskId of the original search request
      * @param originalSearchRequest the original search request
      * @param indices the indices to search against
      * @param clusterAlias the alias to prefix index names with in the returned search results
      * @param absoluteStartMillis the absolute start time to be used on the remote clusters to ensure that the same value is used
      * @param finalReduce whether the reduction should be final or not
      */
-    static SearchRequest subSearchRequest(SearchRequest originalSearchRequest, String[] indices,
+    static SearchRequest subSearchRequest(TaskId parentTaskId, SearchRequest originalSearchRequest, String[] indices,
                                           String clusterAlias, long absoluteStartMillis, boolean finalReduce) {
+        Objects.requireNonNull(parentTaskId, "parentTaskId must be specified");
         Objects.requireNonNull(originalSearchRequest, "search request must not be null");
         validateIndices(indices);
         Objects.requireNonNull(clusterAlias, "cluster alias must not be null");
         if (absoluteStartMillis < 0) {
             throw new IllegalArgumentException("absoluteStartMillis must not be negative but was [" + absoluteStartMillis + "]");
         }
-        return new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce);
+        final SearchRequest request = new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce);
+        request.setParentTask(parentTaskId);
+        return request;
     }
 
     private SearchRequest(SearchRequest searchRequest, String[] indices, String localClusterAlias, long absoluteStartMillis,
@@ -304,7 +308,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
     /**
      * Returns the current time in milliseconds from the time epoch, to be used for the execution of this search request. Used to
      * ensure that the same value, determined by the coordinating node, is used on all nodes involved in the execution of the search
-     * request. When created through {@link #subSearchRequest(SearchRequest, String[], String, long, boolean)}, this method returns
+     * request. When created through {@link #subSearchRequest(TaskId, SearchRequest, String[], String, long, boolean)}, this method returns
      * the provided current time, otherwise it will return {@link System#currentTimeMillis()}.
      */
     long getOrCreateAbsoluteStartMillis() {

+ 9 - 6
server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

@@ -66,6 +66,7 @@ import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.profile.ProfileShardResult;
 import org.elasticsearch.search.profile.SearchProfileShardResults;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.RemoteClusterAware;
 import org.elasticsearch.transport.RemoteClusterService;
@@ -295,7 +296,8 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     task, timeProvider, searchRequest, localIndices, clusterState, listener, searchContext, searchAsyncActionProvider);
             } else {
                 if (shouldMinimizeRoundtrips(searchRequest)) {
-                    ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider,
+                    final TaskId parentTaskId = task.taskInfo(clusterService.localNode().getId(), false).getTaskId();
+                    ccsRemoteReduce(parentTaskId, searchRequest, localIndices, remoteClusterIndices, timeProvider,
                         searchService.aggReduceContextBuilder(searchRequest),
                         remoteClusterService, threadPool, listener,
                         (r, l) -> executeLocalSearch(
@@ -357,8 +359,9 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
             source.collapse().getInnerHits().isEmpty();
     }
 
-    static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map<String, OriginalIndices> remoteIndices,
-                                SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
+    static void ccsRemoteReduce(TaskId parentTaskId, SearchRequest searchRequest, OriginalIndices localIndices,
+                                Map<String, OriginalIndices> remoteIndices, SearchTimeProvider timeProvider,
+                                InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
                                 RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener<SearchResponse> listener,
                                 BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer) {
 
@@ -369,7 +372,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
             String clusterAlias = entry.getKey();
             boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
             OriginalIndices indices = entry.getValue();
-            SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
+            SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(),
                 clusterAlias, timeProvider.getAbsoluteStartMillis(), true);
             Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
             remoteClusterClient.search(ccsSearchRequest, new ActionListener<SearchResponse>() {
@@ -407,7 +410,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 String clusterAlias = entry.getKey();
                 boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
                 OriginalIndices indices = entry.getValue();
-                SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
+                SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(),
                     clusterAlias, timeProvider.getAbsoluteStartMillis(), false);
                 ActionListener<SearchResponse> ccsListener = createCCSListener(clusterAlias, skipUnavailable, countDown,
                     skippedClusters, exceptions, searchResponseMerger, totalClusters,  listener);
@@ -417,7 +420,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
             if (localIndices != null) {
                 ActionListener<SearchResponse> ccsListener = createCCSListener(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
                     false, countDown, skippedClusters, exceptions, searchResponseMerger, totalClusters, listener);
-                SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(searchRequest, localIndices.indices(),
+                SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, localIndices.indices(),
                     RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false);
                 localSearchConsumer.accept(ccsLocalSearchRequest, ccsListener);
             }

+ 2 - 0
server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java

@@ -145,6 +145,7 @@ public class TaskCancellationService {
         GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size());
         final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
         for (Transport.Connection connection : childConnections) {
+            assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
             transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, banRequest, TransportRequestOptions.EMPTY,
                 new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
                     @Override
@@ -167,6 +168,7 @@ public class TaskCancellationService {
         final BanParentTaskRequest request =
             BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
         for (Transport.Connection connection : childConnections) {
+            assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
             logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
             transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, request, TransportRequestOptions.EMPTY,
                 new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {

+ 2 - 0
server/src/main/java/org/elasticsearch/tasks/TaskManager.java

@@ -49,6 +49,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TcpChannel;
 import org.elasticsearch.transport.Transport;
+import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -250,6 +251,7 @@ public class TaskManager implements ClusterStateApplier {
      * to unregister the child connection once the child task is completed or failed.
      */
     public Releasable registerChildConnection(long taskId, Transport.Connection childConnection) {
+        assert TransportService.unwrapConnection(childConnection) == childConnection : "Child connection must be unwrapped";
         final CancellableTaskHolder holder = cancellableTasks.get(taskId);
         if (holder != null) {
             logger.trace("register child connection [{}] task [{}]", childConnection, taskId);

+ 4 - 0
server/src/main/java/org/elasticsearch/transport/RemoteConnectionManager.java

@@ -203,5 +203,9 @@ public class RemoteConnectionManager implements ConnectionManager {
         public Object getCacheKey() {
             return connection.getCacheKey();
         }
+
+        Transport.Connection getConnection() {
+            return connection;
+        }
     }
 }

+ 1 - 0
server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java

@@ -121,6 +121,7 @@ public final class TransportActionProxy {
             super(in);
             targetNode = new DiscoveryNode(in);
             wrapped = reader.read(in);
+            setParentTask(wrapped.getParentTask());
         }
 
         @Override

+ 23 - 1
server/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -589,6 +589,17 @@ public class TransportService extends AbstractLifecycleComponent
         sendRequest(connection, action, request, options, handler);
     }
 
+    /**
+     * Unwraps and returns the actual underlying connection of the given connection.
+     */
+    public static Transport.Connection unwrapConnection(Transport.Connection connection) {
+        Transport.Connection unwrapped = connection;
+        while (unwrapped instanceof RemoteConnectionManager.ProxyConnection) {
+            unwrapped = ((RemoteConnectionManager.ProxyConnection) unwrapped).getConnection();
+        }
+        return unwrapped;
+    }
+
     /**
      * Sends a request on the specified connection. If there is a failure sending the request, the specified handler is invoked.
      *
@@ -606,7 +617,18 @@ public class TransportService extends AbstractLifecycleComponent
         try {
             final TransportResponseHandler<T> delegate;
             if (request.getParentTask().isSet()) {
-                final Releasable unregisterChildNode = taskManager.registerChildConnection(request.getParentTask().getId(), connection);
+                // If the connection is a proxy connection, then we will create a cancellable proxy task on the proxy node and an actual
+                // child task on the target node of the remote cluster.
+                //  ----> a parent task on the local cluster
+                //        |
+                //         ----> a proxy task on the proxy node on the remote cluster
+                //               |
+                //                ----> an actual child task on the target node on the remote cluster
+                // To cancel the child task on the remote cluster, we must send a cancel request to the proxy node instead of the target
+                // node as the parent task of the child task is the proxy task not the parent task on the local cluster. Hence, here we
+                // unwrap the connection and keep track of the connection to the proxy node instead of the proxy connection.
+                final Transport.Connection unwrappedConn = unwrapConnection(connection);
+                final Releasable unregisterChildNode = taskManager.registerChildConnection(request.getParentTask().getId(), unwrappedConn);
                 delegate = new TransportResponseHandler<>() {
                     @Override
                     public void handleResponse(T response) {

+ 2 - 1
server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java

@@ -69,6 +69,7 @@ import org.elasticsearch.search.suggest.Suggest;
 import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
 import org.elasticsearch.search.suggest.phrase.PhraseSuggestion;
 import org.elasticsearch.search.suggest.term.TermSuggestion;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.InternalAggregationTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -393,7 +394,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
     }
 
     private static SearchRequest randomSearchRequest() {
-        return randomBoolean() ? new SearchRequest() : SearchRequest.subSearchRequest(new SearchRequest(),
+        return randomBoolean() ? new SearchRequest() : SearchRequest.subSearchRequest(new TaskId("n", 1), new SearchRequest(),
             Strings.EMPTY_ARRAY, "remote", 0, randomBoolean());
     }
 

+ 9 - 7
server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java

@@ -52,21 +52,23 @@ public class SearchRequestTests extends AbstractSearchTestCase {
             return request;
         }
         //clusterAlias and absoluteStartMillis do not have public getters/setters hence we randomize them only in this test specifically.
-        return SearchRequest.subSearchRequest(request, request.indices(),
+        return SearchRequest.subSearchRequest(new TaskId("node", 1), request, request.indices(),
             randomAlphaOfLengthBetween(5, 10), randomNonNegativeLong(), randomBoolean());
     }
 
     public void testWithLocalReduction() {
-        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(null, Strings.EMPTY_ARRAY, "", 0, randomBoolean()));
+        final TaskId taskId = new TaskId("n", 1);
+        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(
+            taskId, null, Strings.EMPTY_ARRAY, "", 0, randomBoolean()));
         SearchRequest request = new SearchRequest();
-        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(request, null, "", 0, randomBoolean()));
-        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(request,
+        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(taskId, request, null, "", 0, randomBoolean()));
+        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(taskId, request,
             new String[]{null}, "", 0, randomBoolean()));
-        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(request,
+        expectThrows(NullPointerException.class, () -> SearchRequest.subSearchRequest(taskId, request,
             Strings.EMPTY_ARRAY, null, 0, randomBoolean()));
-        expectThrows(IllegalArgumentException.class, () -> SearchRequest.subSearchRequest(request,
+        expectThrows(IllegalArgumentException.class, () -> SearchRequest.subSearchRequest(taskId, request,
             Strings.EMPTY_ARRAY, "", -1, randomBoolean()));
-        SearchRequest searchRequest = SearchRequest.subSearchRequest(request, Strings.EMPTY_ARRAY, "", 0, randomBoolean());
+        SearchRequest searchRequest = SearchRequest.subSearchRequest(taskId, request, Strings.EMPTY_ARRAY, "", 0, randomBoolean());
         assertNull(searchRequest.validate());
     }
 

+ 7 - 6
server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java

@@ -62,6 +62,7 @@ import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.search.internal.InternalSearchResponse;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.sort.SortBuilders;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -391,7 +392,7 @@ public class TransportSearchActionTests extends ESTestCase {
             AtomicReference<Exception> failure = new AtomicReference<>();
             LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                 ActionListener.wrap(r -> fail("no response expected"), failure::set), latch);
-            TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+            TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
                     emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
             if (localIndices == null) {
                 assertNull(setOnce.get());
@@ -436,7 +437,7 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<SearchResponse> response = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(response::set, e -> fail("no failures expected")), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
                         emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -462,7 +463,7 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<Exception> failure = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(r -> fail("no response expected"), failure::set), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
                         emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -509,7 +510,7 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<Exception> failure = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(r -> fail("no response expected"), failure::set), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
                         emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -538,7 +539,7 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<SearchResponse> response = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(response::set, e -> fail("no failures expected")), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
                         emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -578,7 +579,7 @@ public class TransportSearchActionTests extends ESTestCase {
                 AtomicReference<SearchResponse> response = new AtomicReference<>();
                 LatchedActionListener<SearchResponse> listener = new LatchedActionListener<>(
                     ActionListener.wrap(response::set, e -> fail("no failures expected")), latch);
-                TransportSearchAction.ccsRemoteReduce(searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
+                TransportSearchAction.ccsRemoteReduce(new TaskId("n", 1), searchRequest, localIndices, remoteIndicesByCluster, timeProvider,
                         emptyReduceContextBuilder(), remoteClusterService, threadPool, listener, (r, l) -> setOnce.set(Tuple.tuple(r, l)));
                 if (localIndices == null) {
                     assertNull(setOnce.get());

+ 15 - 17
test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java

@@ -52,7 +52,7 @@ import static org.elasticsearch.discovery.DiscoveryModule.DISCOVERY_SEED_PROVIDE
 import static org.elasticsearch.discovery.SettingsBasedSeedHostsProvider.DISCOVERY_SEED_HOSTS_SETTING;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasKey;
-import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
 
 public abstract class AbstractMultiClustersTestCase extends ESTestCase {
     public static final String LOCAL_CLUSTER = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
@@ -144,34 +144,32 @@ public abstract class AbstractMultiClustersTestCase extends ESTestCase {
     }
 
     protected void configureAndConnectsToRemoteClusters() throws Exception {
-        Map<String, List<String>> seedNodes = new HashMap<>();
         for (String clusterAlias : clusterGroup.clusterAliases()) {
             if (clusterAlias.equals(LOCAL_CLUSTER) == false) {
                 final InternalTestCluster cluster = clusterGroup.getCluster(clusterAlias);
                 final String[] allNodes = cluster.getNodeNames();
-                final List<String> selectedNodes = randomSubsetOf(randomIntBetween(1, Math.min(3, allNodes.length)), allNodes);
-                seedNodes.put(clusterAlias, selectedNodes);
+                final List<String> seedNodes = randomSubsetOf(randomIntBetween(1, Math.min(3, allNodes.length)), allNodes);
+                configureRemoteCluster(clusterAlias, seedNodes);
             }
         }
-        if (seedNodes.isEmpty()) {
-            return;
-        }
+    }
+
+    protected void configureRemoteCluster(String clusterAlias, Collection<String> seedNodes) throws Exception {
         Settings.Builder settings = Settings.builder();
-        for (Map.Entry<String, List<String>> entry : seedNodes.entrySet()) {
-            final String clusterAlias = entry.getKey();
-            final String seeds = entry.getValue().stream()
-                .map(node -> cluster(clusterAlias).getInstance(TransportService.class, node).boundAddress().publishAddress().toString())
-                .collect(Collectors.joining(","));
-            settings.put("cluster.remote." + clusterAlias + ".seeds", seeds);
-        }
+        final String seed = seedNodes.stream()
+            .map(node -> {
+                final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, node);
+                return transportService.boundAddress().publishAddress().toString();
+            })
+            .collect(Collectors.joining(","));
+        settings.put("cluster.remote." + clusterAlias + ".seeds", seed);
         client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();
         assertBusy(() -> {
             List<RemoteConnectionInfo> remoteConnectionInfos = client()
                 .execute(RemoteInfoAction.INSTANCE, new RemoteInfoRequest()).actionGet().getInfos()
-                .stream().filter(RemoteConnectionInfo::isConnected)
+                .stream().filter(c -> c.isConnected() && c.getClusterAlias().equals(clusterAlias))
                 .collect(Collectors.toList());
-            final long totalConnections = seedNodes.values().stream().map(List::size).count();
-            assertThat(remoteConnectionInfos, hasSize(Math.toIntExact(totalConnections)));
+            assertThat(remoteConnectionInfos, not(empty()));
         });
     }