Bläddra i källkod

Add tests for non-fatal errors in data node request sender (#124203) (#124305)

Ievgen Degtiarenko 7 månader sedan
förälder
incheckning
0a86a523bd

+ 17 - 27
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java

@@ -81,15 +81,13 @@ abstract class DataNodeRequestSender {
         final long startTimeInNanos = System.nanoTime();
         searchShards(rootTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
             try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> {
-                TimeValue took = TimeValue.timeValueNanos(System.nanoTime() - startTimeInNanos);
-                final int failedShards = shardFailures.size();
                 return new ComputeResponse(
                     profiles,
-                    took,
+                    TimeValue.timeValueNanos(System.nanoTime() - startTimeInNanos),
                     targetShards.totalShards(),
-                    targetShards.totalShards() - failedShards,
+                    targetShards.totalShards() - shardFailures.size(),
                     targetShards.skippedShards(),
-                    failedShards
+                    shardFailures.size()
                 );
             }))) {
                 for (TargetShard shard : targetShards.shards.values()) {
@@ -128,8 +126,7 @@ abstract class DataNodeRequestSender {
                         reportedFailure = true;
                         reportFailures(computeListener);
                     } else {
-                        var nodeRequests = selectNodeRequests(targetShards);
-                        for (NodeRequest request : nodeRequests) {
+                        for (NodeRequest request : selectNodeRequests(targetShards)) {
                             sendOneNodeRequest(targetShards, computeListener, request);
                         }
                     }
@@ -211,18 +208,17 @@ abstract class DataNodeRequestSender {
 
     /**
      * Records a failure for {@code shardId}, merging it with any failure already tracked for that shard.
      * The merged entry is fatal if this call is fatal, if {@code ExceptionsHelper.unwrap} finds a
      * {@code TaskCancelledException} in the new exception, or if the existing entry was already fatal.
      *
      * @param shardId    the shard the failure belongs to
      * @param fatal      whether the caller reports this failure as fatal
      * @param originalEx the failure as received; it is passed through {@code unwrapFailure} before being stored
      */
     private void trackShardLevelFailure(ShardId shardId, boolean fatal, Exception originalEx) {
         final Exception e = unwrapFailure(originalEx);
-        // Retain only one meaningful exception and avoid suppressing previous failures to minimize memory usage, especially when handling
-        // many shards.
+        final boolean isTaskCanceledException = ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null;
         shardFailures.compute(shardId, (k, current) -> {
-            boolean mergedFatal = fatal || ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null;
-            if (current == null) {
-                return new ShardFailure(mergedFatal, e);
-            }
-            mergedFatal |= current.fatal;
-            if (e instanceof NoShardAvailableActionException || ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) {
-                return new ShardFailure(mergedFatal, current.failure);
-            }
-            return new ShardFailure(mergedFatal, e);
+            boolean mergedFatal = fatal || isTaskCanceledException;
+            return current == null
+                ? new ShardFailure(mergedFatal, e)
+                : new ShardFailure(
+                    mergedFatal || current.fatal,
+                    // Retain only one meaningful exception and avoid suppressing previous failures to minimize memory usage,
+                    // especially when handling many shards.
+                    isTaskCanceledException || e instanceof NoShardAvailableActionException ? current.failure : e // keep the earlier exception for these two cases
+                );
         });
     }
 
@@ -243,17 +239,11 @@ abstract class DataNodeRequestSender {
     /**
      * (Remaining) allocated nodes of a given shard id and its alias filter
      */
-    record TargetShard(ShardId shardId, List<DiscoveryNode> remainingNodes, AliasFilter aliasFilter) {
-
-    }
+    record TargetShard(ShardId shardId, List<DiscoveryNode> remainingNodes, AliasFilter aliasFilter) {}
 
-    record NodeRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) {
+    record NodeRequest(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) {} // a node, the shard ids sent to it, and alias filters keyed by index
 
-    }
-
-    private record ShardFailure(boolean fatal, Exception failure) {
-
-    }
+    private record ShardFailure(boolean fatal, Exception failure) {} // the exception retained for a shard and whether it was recorded as fatal
 
     /**
      * Selects the next nodes to send requests to. Limits to at most one outstanding request per node.

+ 32 - 5
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java

@@ -40,6 +40,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.Executor;
@@ -85,7 +86,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
     }
 
     /** JUnit {@code @After} hook: terminates {@code threadPool}. */
     @After
-    public void shutdownThreadPool() throws Exception {
+    public void shutdownThreadPool() {
         terminate(threadPool);
     }
 
@@ -109,8 +110,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, randomBoolean(), (node, shardIds, aliasFilters, listener) -> {
             sent.add(new NodeRequest(node, shardIds, aliasFilters));
-            var resp = new DataNodeComputeResponse(List.of(), Map.of());
-            runWithDelay(() -> listener.onResponse(resp));
+            runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
         });
         safeGet(future);
         assertThat(sent.size(), equalTo(2));
@@ -123,8 +123,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
             var future = sendRequests(targetShards, false, (node, shardIds, aliasFilters, listener) -> {
                 fail("expect no data-node request is sent when target shards are missing");
             });
-            var error = expectThrows(NoShardAvailableActionException.class, future::actionGet);
-            assertThat(error.getMessage(), containsString("no shard copies found"));
+            expectThrows(NoShardAvailableActionException.class, containsString("no shard copies found"), future::actionGet);
         }
         {
             var targetShards = List.of(targetShard(shard1, node1), targetShard(shard3), targetShard(shard4, node2, node3));
@@ -244,6 +243,34 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         assertThat(resp.successfulShards, equalTo(1));
     }
 
+    public void testNonFatalErrorIsRetriedOnAnotherShard() { // node1 fails its request, every other node succeeds
+        var targetShards = List.of(targetShard(shard1, node1, node2)); // NOTE(review): only shard1 is targeted, so the retry is on another node, not another shard as the name suggests
+        Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
+        var response = safeGet(sendRequests(targetShards, false, (node, shardIds, aliasFilters, listener) -> {
+            sent.add(new NodeRequest(node, shardIds, aliasFilters)); // record every request so the number of attempts can be asserted
+            if (Objects.equals(node1, node)) {
+                runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
+            } else {
+                runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
+            }
+        }));
+        assertThat(response.totalShards, equalTo(1));
+        assertThat(response.successfulShards, equalTo(1));
+        assertThat(response.failedShards, equalTo(0)); // the failed attempt on node1 must not be counted as a failed shard
+        assertThat(sent.size(), equalTo(2)); // two requests in total: the failed attempt plus the retry
+    }
+
+    public void testNonFatalFailedOnAllNodes() { // every request fails, so the failure must surface to the caller
+        var targetShards = List.of(targetShard(shard1, node1, node2));
+        Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
+        var future = sendRequests(targetShards, false, (node, shardIds, aliasFilters, listener) -> {
+            sent.add(new NodeRequest(node, shardIds, aliasFilters)); // record every request so the number of attempts can be asserted
+            runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
+        });
+        expectThrows(RuntimeException.class, equalTo("test request level non fatal failure"), future::actionGet); // the original message is expected unchanged
+        assertThat(sent.size(), equalTo(2)); // expects two requests in total (presumably one per node - confirm)
+    }
+
     /**
      * Test helper that builds a {@code TargetShard} for {@code shardId}.
      * The given nodes are copied into a new {@code ArrayList}, so the resulting node list is mutable
      * (presumably so the sender can drop nodes as it tries them - confirm). The alias filter is {@code null}.
      *
      * @param shardId the shard to target
      * @param nodes   the nodes passed as the shard's remaining nodes; may be empty
      * @return a new {@code TargetShard} with the given shard id and nodes
      */
     static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
         return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null);
     }