Browse Source

Simplify DataNodeRequestSender (#126664)

Ievgen Degtiarenko 6 months ago
parent
commit
b96a2f6c89

+ 0 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java

@@ -199,7 +199,6 @@ final class DataNodeComputeHandler implements TransportRequestHandler<DataNodeRe
                 );
             }
         }.startComputeOnDataNodes(
-            clusterAlias,
             concreteIndices,
             originalIndices,
             PlannerUtils.canMatchFilter(dataNodePlan),

+ 3 - 7
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSender.java

@@ -30,7 +30,6 @@ import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.tasks.CancellableTask;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportRequestOptions;
@@ -104,7 +103,6 @@ abstract class DataNodeRequestSender {
     }
 
     final void startComputeOnDataNodes(
-        String clusterAlias,
         Set<String> concreteIndices,
         OriginalIndices originalIndices,
         QueryBuilder requestFilter,
@@ -112,7 +110,7 @@ abstract class DataNodeRequestSender {
         ActionListener<ComputeResponse> listener
     ) {
         final long startTimeInNanos = System.nanoTime();
-        searchShards(rootTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
+        searchShards(requestFilter, concreteIndices, originalIndices, ActionListener.wrap(targetShards -> {
             try (var computeListener = new ComputeListener(transportService.getThreadPool(), runOnTaskFailure, listener.map(profiles -> {
                 return new ComputeResponse(
                     profiles,
@@ -321,7 +319,7 @@ abstract class DataNodeRequestSender {
     }
 
     /**
-     * Result from {@link #searchShards(Task, String, QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
+     * Result from {@link #searchShards(QueryBuilder, Set, OriginalIndices, ActionListener)} where can_match is performed to
      * determine what shards can be skipped and which target nodes are needed for running the ES|QL query
      *
      * @param shards        List of target shards to perform the ES|QL query on
@@ -412,8 +410,6 @@ abstract class DataNodeRequestSender {
      * to a situation where the column structure (i.e., matched data types) differs depending on the query.
      */
     void searchShards(
-        Task parentTask,
-        String clusterAlias,
         QueryBuilder filter,
         Set<String> concreteIndices,
         OriginalIndices originalIndices,
@@ -459,7 +455,7 @@ abstract class DataNodeRequestSender {
             transportService.getLocalNode(),
             EsqlSearchShardsAction.TYPE.name(),
             searchShardsRequest,
-            parentTask,
+            rootTask,
             TransportRequestOptions.EMPTY,
             new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor)
         );

+ 50 - 46
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSenderTests.java

@@ -23,13 +23,11 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.compute.test.ComputeTestCase;
-import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.tasks.CancellableTask;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.FixedExecutorBuilder;
@@ -41,6 +39,7 @@ import org.junit.Before;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -59,8 +58,11 @@ import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_COLD_NODE_RO
 import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_FROZEN_NODE_ROLE;
 import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_HOT_NODE_ROLE;
 import static org.elasticsearch.cluster.node.DiscoveryNodeRole.DATA_WARM_NODE_ROLE;
+import static org.elasticsearch.core.TimeValue.timeValueNanos;
 import static org.elasticsearch.xpack.esql.plugin.DataNodeRequestSender.NodeRequest;
 import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
@@ -120,12 +122,12 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         );
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
         });
         safeGet(future);
         assertThat(sent.size(), equalTo(2));
-        assertThat(groupRequests(sent, 2), equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2, shard4))));
+        assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2, shard4)));
     }
 
     public void testMissingShards() {
@@ -163,7 +165,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         );
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             Map<ShardId, Exception> failures = new HashMap<>();
             if (node.equals(node1) && shardIds.contains(shard5)) {
                 failures.put(shard5, new IOException("test"));
@@ -179,10 +181,11 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
             throw new AssertionError(e);
         }
         assertThat(sent, hasSize(5));
-        var firstRound = groupRequests(sent, 3);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node4, List.of(shard2), node2, List.of(shard3, shard4))));
-        var secondRound = groupRequests(sent, 2);
-        assertThat(secondRound, equalTo(Map.of(node2, List.of(shard2), node3, List.of(shard5))));
+        assertThat(
+            take(sent, 3),
+            containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node4, shard2), nodeRequest(node2, shard3, shard4))
+        );
+        assertThat(take(sent, 2), containsInAnyOrder(nodeRequest(node2, shard2), nodeRequest(node3, shard5)));
     }
 
     public void testRetryButFail() {
@@ -195,7 +198,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         );
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             Map<ShardId, Exception> failures = new HashMap<>();
             if (shardIds.contains(shard5)) {
                 failures.put(shard5, new IOException("test failure for shard5"));
@@ -206,14 +209,12 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         assertNotNull(ExceptionsHelper.unwrap(error, IOException.class));
         // {node-1, node-2, node-4}, {node-3}, {node-2}
         assertThat(sent.size(), equalTo(5));
-        var firstRound = groupRequests(sent, 3);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard5), node2, List.of(shard3, shard4), node4, List.of(shard2))));
-        NodeRequest fourth = sent.remove();
-        assertThat(fourth.node(), equalTo(node3));
-        assertThat(fourth.shardIds(), equalTo(List.of(shard5)));
-        NodeRequest fifth = sent.remove();
-        assertThat(fifth.node(), equalTo(node2));
-        assertThat(fifth.shardIds(), equalTo(List.of(shard5)));
+        assertThat(
+            take(sent, 3),
+            containsInAnyOrder(nodeRequest(node1, shard1, shard5), nodeRequest(node2, shard3, shard4), nodeRequest(node4, shard2))
+        );
+        assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node3, shard5)));
+        assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node2, shard5)));
     }
 
     public void testDoNotRetryOnRequestLevelFailure() {
@@ -221,7 +222,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         AtomicBoolean failed = new AtomicBoolean();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             if (node1.equals(node) && failed.compareAndSet(false, true)) {
                 runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
             } else {
@@ -232,8 +233,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         assertNotNull(ExceptionsHelper.unwrap(exception, IOException.class));
         // one round: {node-1, node-2}
         assertThat(sent.size(), equalTo(2));
-        var firstRound = groupRequests(sent, 2);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
+        assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
     }
 
     public void testAllowPartialResults() {
@@ -241,28 +241,27 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         Queue<NodeRequest> sent = ConcurrentCollections.newQueue();
         AtomicBoolean failed = new AtomicBoolean();
         var future = sendRequests(targetShards, true, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             if (node1.equals(node) && failed.compareAndSet(false, true)) {
                 runWithDelay(() -> listener.onFailure(new IOException("test request level failure"), true));
             } else {
                 runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
             }
         });
-        ComputeResponse resp = safeGet(future);
+        var response = safeGet(future);
+        assertThat(response.totalShards, equalTo(3));
+        assertThat(response.failedShards, equalTo(2));
+        assertThat(response.successfulShards, equalTo(1));
         // one round: {node-1, node-2}
         assertThat(sent.size(), equalTo(2));
-        var firstRound = groupRequests(sent, 2);
-        assertThat(firstRound, equalTo(Map.of(node1, List.of(shard1, shard3), node2, List.of(shard2))));
-        assertThat(resp.totalShards, equalTo(3));
-        assertThat(resp.failedShards, equalTo(2));
-        assertThat(resp.successfulShards, equalTo(1));
+        assertThat(sent, containsInAnyOrder(nodeRequest(node1, shard1, shard3), nodeRequest(node2, shard2)));
     }
 
     public void testNonFatalErrorIsRetriedOnAnotherShard() {
         var targetShards = List.of(targetShard(shard1, node1, node2));
         var sent = ConcurrentCollections.<NodeRequest>newQueue();
         var response = safeGet(sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             if (Objects.equals(node1, node)) {
                 runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
             } else {
@@ -279,7 +278,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         var targetShards = List.of(targetShard(shard1, node1, node2));
         var sent = ConcurrentCollections.<NodeRequest>newQueue();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onFailure(new RuntimeException("test request level non fatal failure"), false));
         });
         expectThrows(RuntimeException.class, equalTo("test request level non fatal failure"), future::actionGet);
@@ -290,7 +289,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         var targetShards = List.of(targetShard(shard1, node1, node2));
         var sent = ConcurrentCollections.<NodeRequest>newQueue();
         var future = sendRequests(targetShards, false, -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onFailure(new CircuitBreakingException("cbe", randomFrom(Durability.values())), false));
         });
         expectThrows(CircuitBreakingException.class, equalTo("cbe"), future::actionGet);
@@ -321,7 +320,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
                 }
             }
 
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> {
                 concurrentRequests.decrementAndGet();
                 listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of()));
@@ -364,7 +363,7 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
 
         var sent = ConcurrentCollections.<NodeRequest>newQueue();
         var response = safeGet(sendRequests(targetShards, randomBoolean(), 1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> {
                 if (Objects.equals(node.getId(), node1.getId()) && shardIds.equals(List.of(shard1))) {
                     listener.onFailure(new RuntimeException("test request level non fatal failure"), false);
@@ -406,29 +405,38 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         );
         var sent = ConcurrentCollections.<NodeRequest>newQueue();
         safeGet(sendRequests(targetShards, randomBoolean(), -1, (node, shardIds, aliasFilters, listener) -> {
-            sent.add(new NodeRequest(node, shardIds, aliasFilters));
+            sent.add(nodeRequest(node, shardIds));
             runWithDelay(() -> listener.onResponse(new DataNodeComputeResponse(List.of(), Map.of())));
         }));
-        assertThat(groupRequests(sent, 1), equalTo(Map.of(node1, List.of(shard1))));
-        assertThat(groupRequests(sent, 1), anyOf(equalTo(Map.of(node2, List.of(shard2))), equalTo(Map.of(warmNode2, List.of(shard2)))));
+        assertThat(take(sent, 1), containsInAnyOrder(nodeRequest(node1, shard1)));
+        assertThat(take(sent, 1), anyOf(contains(nodeRequest(node2, shard2)), contains(nodeRequest(warmNode2, shard2))));
     }
 
     static DataNodeRequestSender.TargetShard targetShard(ShardId shardId, DiscoveryNode... nodes) {
         return new DataNodeRequestSender.TargetShard(shardId, new ArrayList<>(Arrays.asList(nodes)), null);
     }
 
-    static Map<DiscoveryNode, List<ShardId>> groupRequests(Queue<NodeRequest> sent, int limit) {
-        Map<DiscoveryNode, List<ShardId>> map = new HashMap<>();
+    static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, ShardId... shardIds) {
+        return nodeRequest(node, Arrays.asList(shardIds));
+    }
+
+    static DataNodeRequestSender.NodeRequest nodeRequest(DiscoveryNode node, List<ShardId> shardIds) {
+        var copy = new ArrayList<>(shardIds);
+        Collections.sort(copy);
+        return new NodeRequest(node, copy, Map.of());
+    }
+
+    static <T> Collection<T> take(Queue<T> queue, int limit) {
+        var result = new ArrayList<T>(limit);
         for (int i = 0; i < limit; i++) {
-            NodeRequest r = sent.remove();
-            assertNull(map.put(r.node(), r.shardIds().stream().sorted().toList()));
+            result.add(queue.remove());
         }
-        return map;
+        return result;
     }
 
     void runWithDelay(Runnable runnable) {
         if (randomBoolean()) {
-            threadPool.schedule(runnable, TimeValue.timeValueNanos(between(0, 5000)), executor);
+            threadPool.schedule(runnable, timeValueNanos(between(0, 5000)), executor);
         } else {
             executor.execute(runnable);
         }
@@ -465,8 +473,6 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
         ) {
             @Override
             void searchShards(
-                Task parentTask,
-                String clusterAlias,
                 QueryBuilder filter,
                 Set<String> concreteIndices,
                 OriginalIndices originalIndices,
@@ -477,7 +483,6 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
                     shards.size(),
                     0
                 );
-                assertSame(parentTask, task);
                 runWithDelay(() -> listener.onResponse(targetShards));
             }
 
@@ -492,7 +497,6 @@ public class DataNodeRequestSenderTests extends ComputeTestCase {
             }
         };
         requestSender.startComputeOnDataNodes(
-            "",
             Set.of(randomAlphaOfLength(10)),
             new OriginalIndices(new String[0], SearchRequest.DEFAULT_INDICES_OPTIONS),
             null,