Browse Source

Catch and handle disconnect exceptions in search (#115836) (#117373)

Getting a connection can throw an exception when the target node is disconnected.
The call sites changed in this commit did not handle that exception, so it failed
the whole search phase (and could leak memory for outstanding operations) instead
of being recorded as a per-shard failure.
Armin Braun 10 months ago
parent
commit
bde7828eb7

+ 5 - 0
docs/changelog/115836.yaml

@@ -0,0 +1,5 @@
+pr: 115836
+summary: Catch and handle disconnect exceptions in search
+area: Search
+type: bug
+issues: []

+ 23 - 9
server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

@@ -84,15 +84,20 @@ final class DfsQueryPhase extends SearchPhase {
 
         for (final DfsSearchResult dfsResult : searchResults) {
             final SearchShardTarget shardTarget = dfsResult.getSearchShardTarget();
-            Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
-            ShardSearchRequest shardRequest = rewriteShardSearchRequest(dfsResult.getShardSearchRequest());
+            final int shardIndex = dfsResult.getShardIndex();
             QuerySearchRequest querySearchRequest = new QuerySearchRequest(
-                context.getOriginalIndices(dfsResult.getShardIndex()),
+                context.getOriginalIndices(shardIndex),
                 dfsResult.getContextId(),
-                shardRequest,
+                rewriteShardSearchRequest(dfsResult.getShardSearchRequest()),
                 dfs
             );
-            final int shardIndex = dfsResult.getShardIndex();
+            final Transport.Connection connection;
+            try { // looking up the connection can throw, e.g. when the target node is disconnected
+                connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
+            } catch (Exception e) {
+                shardFailure(e, querySearchRequest, shardIndex, shardTarget, counter);
+                continue; // only this shard failed: keep dispatching to the remaining shards instead of leaving the loop
+            }
             searchTransportService.sendExecuteQuery(
                 connection,
                 querySearchRequest,
@@ -112,10 +117,7 @@ final class DfsQueryPhase extends SearchPhase {
                     @Override
                     public void onFailure(Exception exception) {
                         try {
-                            context.getLogger()
-                                .debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception);
-                            progressListener.notifyQueryFailure(shardIndex, shardTarget, exception);
-                            counter.onFailure(shardIndex, shardTarget, exception);
+                            shardFailure(exception, querySearchRequest, shardIndex, shardTarget, counter);
                         } finally {
                             if (context.isPartOfPointInTime(querySearchRequest.contextId()) == false) {
                                 // the query might not have been executed at all (for example because thread pool rejected
@@ -134,6 +136,18 @@ final class DfsQueryPhase extends SearchPhase {
         }
     }
 
+    private void shardFailure( // shared per-shard failure handling: debug-log, notify the progress listener, report to the counter
+        Exception exception,
+        QuerySearchRequest querySearchRequest,
+        int shardIndex,
+        SearchShardTarget shardTarget,
+        CountedCollector<SearchPhaseResult> counter
+    ) {
+        context.getLogger().debug(() -> "[" + querySearchRequest.contextId() + "] Failed to execute query phase", exception);
+        progressListener.notifyQueryFailure(shardIndex, shardTarget, exception);
+        counter.onFailure(shardIndex, shardTarget, exception); // NOTE(review): presumably also counts this shard as completed - confirm
+    }
+
     // package private for testing
     ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
         SearchSourceBuilder source = request.source();

+ 35 - 26
server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

@@ -21,6 +21,7 @@ import org.elasticsearch.search.fetch.ShardFetchSearchRequest;
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.rank.RankDocShardInfo;
+import org.elasticsearch.transport.Transport;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -214,9 +215,41 @@ final class FetchSearchPhase extends SearchPhase {
         final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null
             ? shardPhaseResult.queryResult().getContextId()
             : shardPhaseResult.rankFeatureResult().getContextId();
+        var listener = new SearchActionListener<FetchSearchResult>(shardTarget, shardIndex) {
+            @Override
+            public void innerOnResponse(FetchSearchResult result) {
+                try {
+                    progressListener.notifyFetchResult(shardIndex);
+                    counter.onResult(result);
+                } catch (Exception e) {
+                    context.onPhaseFailure(FetchSearchPhase.this, "", e);
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                try {
+                    logger.debug(() -> "[" + contextId + "] Failed to execute fetch phase", e);
+                    progressListener.notifyFetchFailure(shardIndex, shardTarget, e);
+                    counter.onFailure(shardIndex, shardTarget, e);
+                } finally {
+                    // the search context might not be cleared on the node where the fetch was executed for example
+                    // because the action was rejected by the thread pool. in this case we need to send a dedicated
+                    // request to clear the search context.
+                    releaseIrrelevantSearchContext(shardPhaseResult, context);
+                }
+            }
+        };
+        final Transport.Connection connection;
+        try {
+            connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
         context.getSearchTransport()
             .sendExecuteFetch(
-                context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()),
+                connection,
                 new ShardFetchSearchRequest(
                     context.getOriginalIndices(shardPhaseResult.getShardIndex()),
                     contextId,
@@ -228,31 +261,7 @@ final class FetchSearchPhase extends SearchPhase {
                     aggregatedDfs
                 ),
                 context.getTask(),
-                new SearchActionListener<>(shardTarget, shardIndex) {
-                    @Override
-                    public void innerOnResponse(FetchSearchResult result) {
-                        try {
-                            progressListener.notifyFetchResult(shardIndex);
-                            counter.onResult(result);
-                        } catch (Exception e) {
-                            context.onPhaseFailure(FetchSearchPhase.this, "", e);
-                        }
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {
-                        try {
-                            logger.debug(() -> "[" + contextId + "] Failed to execute fetch phase", e);
-                            progressListener.notifyFetchFailure(shardIndex, shardTarget, e);
-                            counter.onFailure(shardIndex, shardTarget, e);
-                        } finally {
-                            // the search context might not be cleared on the node where the fetch was executed for example
-                            // because the action was rejected by the thread pool. in this case we need to send a dedicated
-                            // request to clear the search context.
-                            releaseIrrelevantSearchContext(shardPhaseResult, context);
-                        }
-                    }
-                }
+                listener
             );
     }
 

+ 32 - 23
server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

@@ -24,6 +24,7 @@ import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorCont
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
 import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
+import org.elasticsearch.transport.Transport;
 
 import java.util.List;
 
@@ -136,9 +137,38 @@ public class RankFeaturePhase extends SearchPhase {
         final SearchShardTarget shardTarget = queryResult.queryResult().getSearchShardTarget();
         final ShardSearchContextId contextId = queryResult.queryResult().getContextId();
         final int shardIndex = queryResult.getShardIndex();
+        var listener = new SearchActionListener<RankFeatureResult>(shardTarget, shardIndex) {
+            @Override
+            protected void innerOnResponse(RankFeatureResult response) {
+                try {
+                    progressListener.notifyRankFeatureResult(shardIndex);
+                    rankRequestCounter.onResult(response);
+                } catch (Exception e) {
+                    context.onPhaseFailure(RankFeaturePhase.this, "", e);
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                try {
+                    logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e);
+                    progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e);
+                    rankRequestCounter.onFailure(shardIndex, shardTarget, e);
+                } finally {
+                    releaseIrrelevantSearchContext(queryResult, context);
+                }
+            }
+        };
+        final Transport.Connection connection;
+        try {
+            connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
         context.getSearchTransport()
             .sendExecuteRankFeature(
-                context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()),
+                connection,
                 new RankFeatureShardRequest(
                     context.getOriginalIndices(queryResult.getShardIndex()),
                     queryResult.getContextId(),
@@ -146,28 +176,7 @@ public class RankFeaturePhase extends SearchPhase {
                     entry
                 ),
                 context.getTask(),
-                new SearchActionListener<>(shardTarget, shardIndex) {
-                    @Override
-                    protected void innerOnResponse(RankFeatureResult response) {
-                        try {
-                            progressListener.notifyRankFeatureResult(shardIndex);
-                            rankRequestCounter.onResult(response);
-                        } catch (Exception e) {
-                            context.onPhaseFailure(RankFeaturePhase.this, "", e);
-                        }
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {
-                        try {
-                            logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e);
-                            progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e);
-                            rankRequestCounter.onFailure(shardIndex, shardTarget, e);
-                        } finally {
-                            releaseIrrelevantSearchContext(queryResult, context);
-                        }
-                    }
-                }
+                listener
             );
     }
 

+ 8 - 6
server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

@@ -87,12 +87,14 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
         final SearchShardTarget shard,
         final SearchActionListener<DfsSearchResult> listener
     ) {
-        getSearchTransport().sendExecuteDfs(
-            getConnection(shard.getClusterAlias(), shard.getNodeId()),
-            buildShardSearchRequest(shardIt, listener.requestIndex),
-            getTask(),
-            listener
-        );
+        final Transport.Connection connection;
+        try {
+            connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
+        getSearchTransport().sendExecuteDfs(connection, buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener);
     }
 
     @Override

+ 8 - 1
server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

@@ -93,8 +93,15 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
         final SearchShardTarget shard,
         final SearchActionListener<SearchPhaseResult> listener
     ) {
+        final Transport.Connection connection;
+        try {
+            connection = getConnection(shard.getClusterAlias(), shard.getNodeId());
+        } catch (Exception e) {
+            listener.onFailure(e);
+            return;
+        }
         ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex));
-        getSearchTransport().sendExecuteQuery(getConnection(shard.getClusterAlias(), shard.getNodeId()), request, getTask(), listener);
+        getSearchTransport().sendExecuteQuery(connection, request, getTask(), listener);
     }
 
     @Override

+ 10 - 6
server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java

@@ -16,6 +16,7 @@ import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -733,17 +734,20 @@ public class SearchQueryThenFetchAsyncActionTests extends ESTestCase {
             assertThat(phase.totalHits().relation, equalTo(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO));
 
             SearchShardTarget searchShardTarget = new SearchShardTarget("node3", shardIt.shardId(), null);
+            final PlainActionFuture<Void> f = new PlainActionFuture<>();
             SearchActionListener<SearchPhaseResult> listener = new SearchActionListener<SearchPhaseResult>(searchShardTarget, 0) {
                 @Override
-                public void onFailure(Exception e) {}
+                public void onFailure(Exception e) {
+                    f.onFailure(e);
+                }
 
                 @Override
-                protected void innerOnResponse(SearchPhaseResult response) {}
+                protected void innerOnResponse(SearchPhaseResult response) {
+                    fail("should not be called");
+                }
             };
-            Exception e = expectThrows(
-                VersionMismatchException.class,
-                () -> action.executePhaseOnShard(shardIt, searchShardTarget, listener)
-            );
+            action.executePhaseOnShard(shardIt, searchShardTarget, listener);
+            Exception e = expectThrows(VersionMismatchException.class, f::actionGet);
             assertThat(e.getMessage(), equalTo("One of the shards is incompatible with the required minimum version [" + minVersion + "]"));
         }
     }