Browse Source

Refresh potential lost connections at query start for `_search` (#130463)

CPS S2D9: Explicitly refresh connection(s) to remote(s) before executing query.

Previously, we'd refresh connection(s) to a remote only when skip_unavailable=false.
We now also do it when operating under a CPS context. However, to avoid waiting for
too long, the wait is bounded -- its duration is controlled by the setting
search.ccs.force_connect_timeout, which we'd eventually inject for the CPS env.
Pawan Kartik 2 months ago
parent
commit
c1c72186c1

+ 5 - 0
docs/changelog/130463.yaml

@@ -0,0 +1,5 @@
+pr: 130463
+summary: Refresh potential lost connections at query start for `_search`
+area: Search
+type: enhancement
+issues: []

+ 126 - 0
server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/RemoteSearchForceConnectTimeoutIT.java

@@ -0,0 +1,126 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.indices.cluster;
+
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.TransportSearchAction;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.CollectionUtils;
+import org.elasticsearch.plugins.ClusterPlugin;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.test.AbstractMultiClustersTestCase;
+import org.elasticsearch.test.transport.MockTransportService;
+import org.elasticsearch.transport.TransportService;
+import org.hamcrest.Matchers;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+
+/**
+ * Verifies that a cross-cluster search stops waiting for a connection to a remote cluster once the
+ * {@code search.ccs.force_connect_timeout} node setting has elapsed, and reports the remote as failed.
+ */
+public class RemoteSearchForceConnectTimeoutIT extends AbstractMultiClustersTestCase {
+    private static final String REMOTE_CLUSTER_1 = "cluster-a";
+
+    // Declared here so the plugin below can register the key that TransportSearchAction reads its timeout from.
+    private static final Setting<String> FORCE_CONNECT_TIMEOUT_SETTING = Setting.simpleString(
+        "search.ccs.force_connect_timeout",
+        Setting.Property.NodeScope
+    );
+
+    public static class ForceConnectTimeoutPlugin extends Plugin implements ClusterPlugin {
+        @Override
+        public List<Setting<?>> getSettings() {
+            return List.of(FORCE_CONNECT_TIMEOUT_SETTING);
+        }
+    }
+
+    @Override
+    protected List<String> remoteClusterAlias() {
+        return List.of(REMOTE_CLUSTER_1);
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins(String clusterAlias) {
+        return CollectionUtils.appendToCopy(super.nodePlugins(clusterAlias), ForceConnectTimeoutPlugin.class);
+    }
+
+    @Override
+    protected Settings nodeSettings() {
+        // Controls how long TransportSearchAction waits for a connection to a remote. Kept at a low 1s so the test
+        // does not stall for long; testTimeoutSetting() asserts on this exact value in the failure message.
+        return Settings.builder().put(super.nodeSettings()).put("search.ccs.force_connect_timeout", "1s").build();
+    }
+
+    @Override
+    protected Map<String, Boolean> skipUnavailableForRemoteClusters() {
+        return Map.of(REMOTE_CLUSTER_1, true);
+    }
+
+    public void testTimeoutSetting() throws Exception {
+        // Stays closed until the end of the test so that every intercepted connection attempt stalls.
+        var latch = new CountDownLatch(1);
+        for (String nodeName : cluster(LOCAL_CLUSTER).getNodeNames()) {
+            MockTransportService mts = (MockTransportService) cluster(LOCAL_CLUSTER).getInstance(TransportService.class, nodeName);
+            mts.addConnectBehavior(
+                cluster(REMOTE_CLUSTER_1).getInstance(TransportService.class, randomFrom(cluster(REMOTE_CLUSTER_1).getNodeNames())),
+                ((transport, discoveryNode, profile, listener) -> {
+                    try {
+                        latch.await();
+                    } catch (InterruptedException e) {
+                        Thread.currentThread().interrupt();
+                        throw new AssertionError(e);
+                    }
+                    transport.openConnection(discoveryNode, profile, listener);
+                })
+            );
+        }
+
+        try {
+            // Add some dummy data to prove we are communicating fine with the remote.
+            assertAcked(client(REMOTE_CLUSTER_1).admin().indices().prepareCreate("test-index"));
+            client(REMOTE_CLUSTER_1).prepareIndex("test-index").setSource("sample-field", "sample-value").get();
+            client(REMOTE_CLUSTER_1).admin().indices().prepareRefresh("test-index").get();
+
+            // Do a full restart so that our custom connect behaviour takes effect since it does not apply to
+            // pre-existing connections -- they're already established by the time this test runs.
+            cluster(REMOTE_CLUSTER_1).fullRestart();
+
+            var searchRequest = new SearchRequest("*", "*:*");
+            searchRequest.allowPartialSearchResults(false);
+            var result = safeGet(client().execute(TransportSearchAction.TYPE, searchRequest));
+            try {
+                // The remote cluster should've failed.
+                var failures = result.getClusters().getCluster(REMOTE_CLUSTER_1).getFailures();
+                assertThat(failures.size(), Matchers.equalTo(1));
+
+                // Reason should be a timed out exception. The timeout should be equal to what we've set and there should
+                // be a reference to the subscribable listener -- which is what we use to listen for a valid connection.
+                var failureReason = failures.getFirst().reason();
+                assertThat(
+                    failureReason,
+                    Matchers.containsString("org.elasticsearch.ElasticsearchTimeoutException: timed out after [1s/1000ms]")
+                );
+                assertThat(failureReason, Matchers.containsString("SubscribableListener"));
+            } finally {
+                result.decRef();
+            }
+        } finally {
+            // Release the stalled connection attempts even when the restart or an assertion above fails.
+            latch.countDown();
+        }
+    }
+}

+ 142 - 56
server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

@@ -30,6 +30,7 @@ import org.elasticsearch.action.admin.cluster.stats.CCSUsageTelemetry;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.master.MasterNodeRequest;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterState;
@@ -167,6 +168,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
     private final Client client;
     private final UsageService usageService;
     private final boolean collectTelemetry;
+    private final TimeValue forceConnectTimeoutSecs;
 
     @Inject
     public TransportSearchAction(
@@ -215,6 +217,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         this.searchResponseMetrics = searchResponseMetrics;
         this.client = client;
         this.usageService = usageService;
+        forceConnectTimeoutSecs = settings.getAsTime("search.ccs.force_connect_timeout", null);
     }
 
     private Map<String, OriginalIndices> buildPerIndexOriginalIndices(
@@ -445,7 +448,9 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                             projectState,
                             clusters,
                             searchPhaseProvider.apply(l)
-                        )
+                        ),
+                        transportService,
+                        forceConnectTimeoutSecs
                     );
                 } else {
                     final SearchContextId searchContext = resolvedIndices.getSearchContextId();
@@ -505,7 +510,8 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                                 clusters,
                                 searchPhaseProvider.apply(finalDelegate)
                             );
-                        })
+                        }),
+                        forceConnectTimeoutSecs
                     );
                 }
             }
@@ -633,6 +639,40 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
             || source.collapse().getInnerHits().isEmpty();
     }
 
+    /**
+     * Return a subscribable listener with optional timeout depending on force reconnect setting is registered or
+     * not.
+     * @param forceConnectTimeoutSecs Timeout in seconds that determines how long we'll wait to establish a connection
+     *                                to a remote.
+     * @param threadPool The thread pool that'll be used for the timeout.
+     * @param timeoutExecutor The executor that should be used for the timeout.
+     * @return SubscribableListener A listener with optionally added timeout.
+     */
+    private static SubscribableListener<Transport.Connection> getListenerWithOptionalTimeout(
+        TimeValue forceConnectTimeoutSecs,
+        ThreadPool threadPool,
+        Executor timeoutExecutor
+    ) {
+        var subscribableListener = new SubscribableListener<Transport.Connection>();
+        if (forceConnectTimeoutSecs != null) {
+            subscribableListener.addTimeout(forceConnectTimeoutSecs, threadPool, timeoutExecutor);
+        }
+
+        return subscribableListener;
+    }
+
+    /**
+     * Decides whether a connection to the remote must be (re-)established before searching it. The default
+     * disconnected strategy for Elasticsearch is RECONNECT_UNLESS_SKIP_UNAVAILABLE, so we force connect either when
+     * a force connect timeout is configured (like in CPS) or when skip unavailable is false for the cluster.
+     * @param forceConnectTimeoutSecs The timeout value from the force connect setting; when non-null it takes precedence.
+     * @param skipUnavailable The usual skip unavailable setting.
+     * @return boolean If we should always force reconnect.
+     */
+    private static boolean shouldEstablishConnection(TimeValue forceConnectTimeoutSecs, boolean skipUnavailable) {
+        return (forceConnectTimeoutSecs == null && skipUnavailable) == false;
+    }
+
     /**
      * Handles ccs_minimize_roundtrips=true
      */
@@ -647,7 +687,9 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         RemoteClusterService remoteClusterService,
         ThreadPool threadPool,
         ActionListener<SearchResponse> listener,
-        BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer
+        BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer,
+        TransportService transportService,
+        TimeValue forceConnectTimeoutSecs
     ) {
         final var remoteClientResponseExecutor = threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION);
         if (resolvedIndices.getLocalIndices() == null && resolvedIndices.getRemoteClusterIndices().size() == 1) {
@@ -665,12 +707,9 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                 timeProvider.absoluteStartMillis(),
                 true
             );
-            var remoteClusterClient = remoteClusterService.getRemoteClusterClient(
-                clusterAlias,
-                remoteClientResponseExecutor,
-                RemoteClusterService.DisconnectedStrategy.RECONNECT_UNLESS_SKIP_UNAVAILABLE
-            );
-            remoteClusterClient.execute(TransportSearchAction.REMOTE_TYPE, ccsSearchRequest, new ActionListener<>() {
+
+            var connectionListener = getListenerWithOptionalTimeout(forceConnectTimeoutSecs, threadPool, remoteClientResponseExecutor);
+            var searchListener = new ActionListener<SearchResponse>() {
                 @Override
                 public void onResponse(SearchResponse searchResponse) {
                     // overwrite the existing cluster entry with the updated one
@@ -713,7 +752,25 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                         listener.onFailure(wrapRemoteClusterFailure(clusterAlias, e));
                     }
                 }
-            });
+            };
+
+            connectionListener.addListener(
+                searchListener.delegateFailure(
+                    (responseListener, connection) -> transportService.sendRequest(
+                        connection,
+                        TransportSearchAction.TYPE.name(),
+                        ccsSearchRequest,
+                        TransportRequestOptions.EMPTY,
+                        new ActionListenerResponseHandler<>(responseListener, SearchResponse::new, remoteClientResponseExecutor)
+                    )
+                )
+            );
+
+            remoteClusterService.maybeEnsureConnectedAndGetConnection(
+                clusterAlias,
+                shouldEstablishConnection(forceConnectTimeoutSecs, skipUnavailable),
+                connectionListener
+            );
         } else {
             SearchResponseMerger searchResponseMerger = createSearchResponseMerger(
                 searchRequest.source(),
@@ -748,12 +805,30 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     task.getProgressListener(),
                     listener
                 );
-                final var remoteClusterClient = remoteClusterService.getRemoteClusterClient(
+
+                SubscribableListener<Transport.Connection> connectionListener = getListenerWithOptionalTimeout(
+                    forceConnectTimeoutSecs,
+                    threadPool,
+                    remoteClientResponseExecutor
+                );
+
+                connectionListener.addListener(
+                    ccsListener.delegateFailure(
+                        (responseListener, connection) -> transportService.sendRequest(
+                            connection,
+                            TransportSearchAction.REMOTE_TYPE.name(),
+                            ccsSearchRequest,
+                            TransportRequestOptions.EMPTY,
+                            new ActionListenerResponseHandler<>(responseListener, SearchResponse::new, remoteClientResponseExecutor)
+                        )
+                    )
+                );
+
+                remoteClusterService.maybeEnsureConnectedAndGetConnection(
                     clusterAlias,
-                    remoteClientResponseExecutor,
-                    RemoteClusterService.DisconnectedStrategy.RECONNECT_UNLESS_SKIP_UNAVAILABLE
+                    shouldEstablishConnection(forceConnectTimeoutSecs, skipUnavailable),
+                    connectionListener
                 );
-                remoteClusterClient.execute(TransportSearchAction.REMOTE_TYPE, ccsSearchRequest, ccsListener);
             }
             if (resolvedIndices.getLocalIndices() != null) {
                 ActionListener<SearchResponse> ccsListener = createCCSListener(
@@ -819,7 +894,8 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
         SearchResponse.Clusters clusters,
         SearchTimeProvider timeProvider,
         TransportService transportService,
-        ActionListener<Map<String, SearchShardsResponse>> listener
+        ActionListener<Map<String, SearchShardsResponse>> listener,
+        TimeValue forceConnectTimeoutSecs
     ) {
         RemoteClusterService remoteClusterService = transportService.getRemoteClusterService();
         final CountDown responsesCountDown = new CountDown(remoteIndicesByCluster.size());
@@ -848,49 +924,59 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
                     return searchShardsResponses;
                 }
             };
+
+            var threadPool = transportService.getThreadPool();
+            var connectionListener = getListenerWithOptionalTimeout(
+                forceConnectTimeoutSecs,
+                threadPool,
+                threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION)
+            );
+
+            connectionListener.addListener(singleListener.delegateFailure((responseListener, connection) -> {
+                final String[] indices = entry.getValue().indices();
+                final Executor responseExecutor = transportService.getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION);
+                // TODO: support point-in-time
+                if (searchContext == null && connection.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X)) {
+                    SearchShardsRequest searchShardsRequest = new SearchShardsRequest(
+                        indices,
+                        indicesOptions,
+                        query,
+                        routing,
+                        preference,
+                        allowPartialResults,
+                        clusterAlias
+                    );
+                    transportService.sendRequest(
+                        connection,
+                        TransportSearchShardsAction.TYPE.name(),
+                        searchShardsRequest,
+                        TransportRequestOptions.EMPTY,
+                        new ActionListenerResponseHandler<>(responseListener, SearchShardsResponse::new, responseExecutor)
+                    );
+                } else {
+                    // does not do a can-match
+                    ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(
+                        MasterNodeRequest.INFINITE_MASTER_NODE_TIMEOUT,
+                        indices
+                    ).indicesOptions(indicesOptions).local(true).preference(preference).routing(routing);
+                    transportService.sendRequest(
+                        connection,
+                        TransportClusterSearchShardsAction.TYPE.name(),
+                        searchShardsRequest,
+                        TransportRequestOptions.EMPTY,
+                        new ActionListenerResponseHandler<>(
+                            singleListener.map(SearchShardsResponse::fromLegacyResponse),
+                            ClusterSearchShardsResponse::new,
+                            responseExecutor
+                        )
+                    );
+                }
+            }));
+
             remoteClusterService.maybeEnsureConnectedAndGetConnection(
                 clusterAlias,
-                skipUnavailable == false,
-                singleListener.delegateFailureAndWrap((delegate, connection) -> {
-                    final String[] indices = entry.getValue().indices();
-                    final Executor responseExecutor = transportService.getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION);
-                    // TODO: support point-in-time
-                    if (searchContext == null && connection.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X)) {
-                        SearchShardsRequest searchShardsRequest = new SearchShardsRequest(
-                            indices,
-                            indicesOptions,
-                            query,
-                            routing,
-                            preference,
-                            allowPartialResults,
-                            clusterAlias
-                        );
-                        transportService.sendRequest(
-                            connection,
-                            TransportSearchShardsAction.TYPE.name(),
-                            searchShardsRequest,
-                            TransportRequestOptions.EMPTY,
-                            new ActionListenerResponseHandler<>(delegate, SearchShardsResponse::new, responseExecutor)
-                        );
-                    } else {
-                        // does not do a can-match
-                        ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest(
-                            MasterNodeRequest.INFINITE_MASTER_NODE_TIMEOUT,
-                            indices
-                        ).indicesOptions(indicesOptions).local(true).preference(preference).routing(routing);
-                        transportService.sendRequest(
-                            connection,
-                            TransportClusterSearchShardsAction.TYPE.name(),
-                            searchShardsRequest,
-                            TransportRequestOptions.EMPTY,
-                            new ActionListenerResponseHandler<>(
-                                delegate.map(SearchShardsResponse::fromLegacyResponse),
-                                ClusterSearchShardsResponse::new,
-                                responseExecutor
-                            )
-                        );
-                    }
-                })
+                shouldEstablishConnection(forceConnectTimeoutSecs, skipUnavailable),
+                connectionListener
             );
         }
     }

+ 31 - 12
server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java

@@ -546,7 +546,9 @@ public class TransportSearchActionTests extends ESTestCase {
                 remoteClusterService,
                 threadPool,
                 listener,
-                (r, l) -> setOnce.set(Tuple.tuple(r, l))
+                (r, l) -> setOnce.set(Tuple.tuple(r, l)),
+                service,
+                null
             );
             if (localIndices == null) {
                 assertNull(setOnce.get());
@@ -621,7 +623,9 @@ public class TransportSearchActionTests extends ESTestCase {
                     remoteClusterService,
                     threadPool,
                     listener,
-                    (r, l) -> setOnce.set(Tuple.tuple(r, l))
+                    (r, l) -> setOnce.set(Tuple.tuple(r, l)),
+                    service,
+                    null
                 );
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -677,7 +681,9 @@ public class TransportSearchActionTests extends ESTestCase {
                     remoteClusterService,
                     threadPool,
                     listener,
-                    (r, l) -> setOnce.set(Tuple.tuple(r, l))
+                    (r, l) -> setOnce.set(Tuple.tuple(r, l)),
+                    service,
+                    null
                 );
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -765,7 +771,9 @@ public class TransportSearchActionTests extends ESTestCase {
                     remoteClusterService,
                     threadPool,
                     listener,
-                    (r, l) -> setOnce.set(Tuple.tuple(r, l))
+                    (r, l) -> setOnce.set(Tuple.tuple(r, l)),
+                    service,
+                    null
                 );
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -864,7 +872,9 @@ public class TransportSearchActionTests extends ESTestCase {
                     remoteClusterService,
                     threadPool,
                     listener,
-                    (r, l) -> setOnce.set(Tuple.tuple(r, l))
+                    (r, l) -> setOnce.set(Tuple.tuple(r, l)),
+                    service,
+                    null
                 );
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -915,7 +925,9 @@ public class TransportSearchActionTests extends ESTestCase {
                     remoteClusterService,
                     threadPool,
                     listener,
-                    (r, l) -> setOnce.set(Tuple.tuple(r, l))
+                    (r, l) -> setOnce.set(Tuple.tuple(r, l)),
+                    service,
+                    null
                 );
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -988,7 +1000,9 @@ public class TransportSearchActionTests extends ESTestCase {
                     remoteClusterService,
                     threadPool,
                     listener,
-                    (r, l) -> setOnce.set(Tuple.tuple(r, l))
+                    (r, l) -> setOnce.set(Tuple.tuple(r, l)),
+                    service,
+                    null
                 );
                 if (localIndices == null) {
                     assertNull(setOnce.get());
@@ -1083,7 +1097,8 @@ public class TransportSearchActionTests extends ESTestCase {
                     clusters,
                     timeProvider,
                     service,
-                    new LatchedActionListener<>(ActionTestUtils.assertNoFailureListener(response::set), latch)
+                    new LatchedActionListener<>(ActionTestUtils.assertNoFailureListener(response::set), latch),
+                    null
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertNotNull(response.get());
@@ -1112,7 +1127,8 @@ public class TransportSearchActionTests extends ESTestCase {
                     clusters,
                     timeProvider,
                     service,
-                    new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
+                    new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch),
+                    null
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertEquals(numClusters, clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED));
@@ -1160,7 +1176,8 @@ public class TransportSearchActionTests extends ESTestCase {
                     clusters,
                     timeProvider,
                     service,
-                    new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch)
+                    new LatchedActionListener<>(ActionListener.wrap(r -> fail("no response expected"), failure::set), latch),
+                    null
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertEquals(numDisconnectedClusters, clusters.getClusterStateCount(SearchResponse.Cluster.Status.FAILED));
@@ -1190,7 +1207,8 @@ public class TransportSearchActionTests extends ESTestCase {
                     clusters,
                     timeProvider,
                     service,
-                    new LatchedActionListener<>(ActionTestUtils.assertNoFailureListener(response::set), latch)
+                    new LatchedActionListener<>(ActionTestUtils.assertNoFailureListener(response::set), latch),
+                    null
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertNotNull(response.get());
@@ -1236,7 +1254,8 @@ public class TransportSearchActionTests extends ESTestCase {
                     clusters,
                     timeProvider,
                     service,
-                    new LatchedActionListener<>(ActionTestUtils.assertNoFailureListener(response::set), latch)
+                    new LatchedActionListener<>(ActionTestUtils.assertNoFailureListener(response::set), latch),
+                    null
                 );
                 awaitLatch(latch, 5, TimeUnit.SECONDS);
                 assertEquals(0, clusters.getClusterStateCount(SearchResponse.Cluster.Status.SKIPPED));