
Fix bulk NPE when retrying failure redirect after cluster block (#107598)

This PR fixes a bug in the bulk operation when retrying a blocked cluster state before
executing a failure store write: the retry runnable is now wrapped correctly so that it
no longer completes the listener prematurely with a null response.
James Baiera, 1 year ago
parent commit a912cb0371

+ 5 - 0
docs/changelog/107598.yaml

@@ -0,0 +1,5 @@
+pr: 107598
+summary: Fix bulk NPE when retrying failure redirect after cluster block
+area: Data streams
+type: bug
+issues: []

+ 5 - 1
server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java

@@ -161,7 +161,11 @@ final class BulkOperation extends ActionRunnable<BulkResponse> {
         assert failureStoreRedirects.isEmpty() != true : "Attempting to redirect failures, but none were present in the queue";
         final ClusterState clusterState = observer.setAndGetObservedState();
         // If the cluster is blocked at this point, discard the failure store redirects and complete the response with the original failures
-        if (handleBlockExceptions(clusterState, ActionRunnable.run(listener, this::doRedirectFailures), this::discardRedirectsAndFinish)) {
+        if (handleBlockExceptions(
+            clusterState,
+            ActionRunnable.wrap(listener, (l) -> this.doRedirectFailures()),
+            this::discardRedirectsAndFinish
+        )) {
             return;
         }
         Map<ShardId, List<BulkItemRequest>> requestsByShard = drainAndGroupRedirectsByShards(clusterState);
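
The one-line change above is the crux of the fix. ActionRunnable.run(listener, runnable) completes its listener with a null response once the runnable finishes, which is harmless for Void listeners but here handed a null BulkResponse to the caller waiting on the retry, producing the NPE. ActionRunnable.wrap(listener, consumer) instead passes the listener into the consumer, leaving doRedirectFailures() responsible for completing it with a real response. Below is a minimal standalone sketch of the difference in semantics; the types are simplified stand-ins, not the actual Elasticsearch classes.

import java.util.function.Consumer;

// Simplified stand-in for org.elasticsearch.action.ActionListener (illustrative only).
interface SimpleListener<T> {
    void onResponse(T response);
}

final class ActionRunnableSketch {

    // Approximates ActionRunnable.run: once the runnable returns, the listener
    // is completed with null. Retrying the failure redirect through this path
    // delivered a null BulkResponse to the waiting listener.
    static <T> Runnable run(SimpleListener<T> listener, Runnable runnable) {
        return () -> {
            runnable.run();
            listener.onResponse(null); // premature null completion
        };
    }

    // Approximates ActionRunnable.wrap: the consumer receives the listener and
    // completes it with a real response when the work is done.
    static <T> Runnable wrap(SimpleListener<T> listener, Consumer<SimpleListener<T>> consumer) {
        return () -> consumer.accept(listener);
    }
}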

+ 193 - 54
server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java

@@ -38,7 +38,6 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.core.CheckedFunction;
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.mapper.MapperException;
@@ -49,6 +48,7 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.client.NoOpNodeClient;
 import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.junit.After;
 import org.junit.Assume;
 import org.junit.Before;
@@ -58,8 +58,12 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -168,10 +172,7 @@ public class BulkOperationTests extends ESTestCase {
      * If a bulk operation begins and the cluster is experiencing a non-retryable block, the bulk operation should fail
      */
     public void testClusterBlockedFailsBulk() {
-        NodeClient client = getNodeClient((r) -> {
-            fail("Should not have executed shard action on blocked cluster");
-            return null;
-        });
+        NodeClient client = getNodeClient(assertNoClientInteraction());
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
         ActionListener<BulkResponse> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
@@ -196,10 +197,7 @@ public class BulkOperationTests extends ESTestCase {
      * If a bulk operation times out while waiting for cluster blocks to be cleared, it should fail the request.
      */
     public void testTimeoutOnRetryableClusterBlockedFailsBulk() {
-        NodeClient client = getNodeClient((r) -> {
-            fail("Should not have executed shard action on blocked cluster");
-            return null;
-        });
+        NodeClient client = getNodeClient(assertNoClientInteraction());
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
         ActionListener<BulkResponse> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
@@ -234,10 +232,7 @@ public class BulkOperationTests extends ESTestCase {
      * If the cluster service closes while a bulk operation is waiting for cluster blocks to be cleared, it should fail the request.
      */
     public void testNodeClosedOnRetryableClusterBlockedFailsBulk() {
-        NodeClient client = getNodeClient((r) -> {
-            fail("Should not have executed shard action on blocked cluster");
-            return null;
-        });
+        NodeClient client = getNodeClient(assertNoClientInteraction());
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
         ActionListener<BulkResponse> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
@@ -272,7 +267,7 @@ public class BulkOperationTests extends ESTestCase {
         bulkRequest.add(new IndexRequest(indexName).id("1").source(Map.of("key", "val")));
         bulkRequest.add(new IndexRequest(indexName).id("3").source(Map.of("key", "val")));
 
-        NodeClient client = getNodeClient(this::acceptAllShardWrites);
+        NodeClient client = getNodeClient(acceptAllShardWrites());
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
         ActionListener<BulkResponse> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
@@ -293,7 +288,7 @@ public class BulkOperationTests extends ESTestCase {
         bulkRequest.add(new IndexRequest(indexName).id("3").source(Map.of("key", "val")));
 
         NodeClient client = getNodeClient(
-            failingShards(Map.of(new ShardId(indexMetadata.getIndex(), 0), () -> new MapperException("test")))
+            shardSpecificResponse(Map.of(new ShardId(indexMetadata.getIndex(), 0), failWithException(() -> new MapperException("test"))))
         );
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
@@ -320,7 +315,7 @@ public class BulkOperationTests extends ESTestCase {
         bulkRequest.add(new IndexRequest(dataStreamName).id("1").source(Map.of("key", "val")).opType(DocWriteRequest.OpType.CREATE));
         bulkRequest.add(new IndexRequest(dataStreamName).id("3").source(Map.of("key", "val")).opType(DocWriteRequest.OpType.CREATE));
 
-        NodeClient client = getNodeClient(this::acceptAllShardWrites);
+        NodeClient client = getNodeClient(acceptAllShardWrites());
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
         ActionListener<BulkResponse> listener = ActionListener.wrap(future::complete, future::completeExceptionally);
@@ -341,7 +336,7 @@ public class BulkOperationTests extends ESTestCase {
         bulkRequest.add(new IndexRequest(dataStreamName).id("3").source(Map.of("key", "val")).opType(DocWriteRequest.OpType.CREATE));
 
         NodeClient client = getNodeClient(
-            failingShards(Map.of(new ShardId(ds1BackingIndex2.getIndex(), 0), () -> new MapperException("test")))
+            shardSpecificResponse(Map.of(new ShardId(ds1BackingIndex2.getIndex(), 0), failWithException(() -> new MapperException("test"))))
         );
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
@@ -371,7 +366,7 @@ public class BulkOperationTests extends ESTestCase {
         bulkRequest.add(new IndexRequest(fsDataStreamName).id("3").source(Map.of("key", "val")).opType(DocWriteRequest.OpType.CREATE));
 
         NodeClient client = getNodeClient(
-            failingShards(Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), () -> new MapperException("test")))
+            shardSpecificResponse(Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), failWithException(() -> new MapperException("test"))))
         );
 
         CompletableFuture<BulkResponse> future = new CompletableFuture<>();
@@ -433,12 +428,12 @@ public class BulkOperationTests extends ESTestCase {
         // Mock client that rejects all shard requests on the first shard in the backing index, and all requests to the only shard of
         // the failure store index.
         NodeClient client = getNodeClient(
-            failingShards(
+            shardSpecificResponse(
                 Map.of(
                     new ShardId(ds2BackingIndex1.getIndex(), 0),
-                    () -> new MapperException("root cause"),
+                    failWithException(() -> new MapperException("root cause")),
                     new ShardId(ds2FailureStore1.getIndex(), 0),
-                    () -> new MapperException("failure store test failure")
+                    failWithException(() -> new MapperException("failure store test failure"))
                 )
             )
         );
@@ -500,6 +495,101 @@ public class BulkOperationTests extends ESTestCase {
         assertThat(failedItem.getFailure().getCause().getSuppressed()[0].getMessage(), is(equalTo("Could not serialize json")));
     }
 
+    /**
+     * A bulk operation to a data stream with a failure store enabled could still succeed if the cluster is experiencing a
+     * retryable block when the redirected documents would be sent to the shard-level action. If the cluster state observer
+     * returns an unblocked cluster, the redirection of failure documents should proceed and not return early.
+     */
+    public void testRetryableBlockAcceptsFailureStoreDocument() throws Exception {
+        Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled());
+
+        // Requests that go to two separate shards
+        BulkRequest bulkRequest = new BulkRequest();
+        bulkRequest.add(new IndexRequest(fsDataStreamName).id("1").source(Map.of("key", "val")).opType(DocWriteRequest.OpType.CREATE));
+        bulkRequest.add(new IndexRequest(fsDataStreamName).id("3").source(Map.of("key", "val")).opType(DocWriteRequest.OpType.CREATE));
+
+        // We want to make sure that going async during the write operation won't cause correctness
+        // issues, so use a real executor for the test
+        ExecutorService writeExecutor = threadPool.executor(ThreadPool.Names.WRITE);
+
+        // Create a pair of countdown latches to synchronize our test code and the write operation we're testing:
+        // One to notify the test that the write operation has been reached, and one for the test to signal that
+        // the write operation should proceed
+        CountDownLatch readyToPerformFailureStoreWrite = new CountDownLatch(1);
+        CountDownLatch beginFailureStoreWrite = new CountDownLatch(1);
+
+        // A mock client that:
+        // 1) Rejects an entire shard level request for the backing index and
+        // 2) When the followup write is submitted for the failure store, will go async and wait until the above latch is counted down
+        // before accepting the request.
+        NodeClient client = getNodeClient(
+            shardSpecificResponse(
+                Map.of(
+                    new ShardId(ds2BackingIndex1.getIndex(), 0),
+                    failWithException(() -> new MapperException("root cause")),
+                    new ShardId(ds2FailureStore1.getIndex(), 0),
+                    goAsyncAndWait(writeExecutor, readyToPerformFailureStoreWrite, beginFailureStoreWrite, acceptAllShardWrites())
+                )
+            )
+        );
+
+        // Create a new cluster state that has a retryable cluster block on it
+        ClusterState blockedState = ClusterState.builder(DEFAULT_STATE)
+            .blocks(ClusterBlocks.builder().addGlobalBlock(NoMasterBlockService.NO_MASTER_BLOCK_WRITES).build())
+            .build();
+
+        // Cluster state observer logic:
+        // First time we will return the normal cluster state (before normal writes) which skips any further interactions,
+        // Second time we will return a blocked cluster state (before the redirects) causing us to start observing the cluster
+        // Then, when waiting for next state change, we will emulate the observer receiving an unblocked state to continue the processing
+        // Finally, third time we will return the normal cluster state again since the cluster will be "unblocked" after waiting
+        ClusterStateObserver observer = mock(ClusterStateObserver.class);
+        when(observer.setAndGetObservedState()).thenReturn(DEFAULT_STATE).thenReturn(blockedState).thenReturn(DEFAULT_STATE);
+        when(observer.isTimedOut()).thenReturn(false);
+        doAnswer(invocation -> {
+            ClusterStateObserver.Listener l = invocation.getArgument(0);
+            l.onNewClusterState(DEFAULT_STATE);
+            return null;
+        }).when(observer).waitForNextChange(any());
+
+        CompletableFuture<BulkResponse> future = new CompletableFuture<>();
+        ActionListener<BulkResponse> listener = ActionListener.notifyOnce(
+            ActionListener.wrap(future::complete, future::completeExceptionally)
+        );
+
+        newBulkOperation(client, bulkRequest, DEFAULT_STATE, observer, listener).run();
+
+        // The operation will attempt to write the documents in the request, receive a failure, wait for a stable cluster state, and then
+        // redirect the failed documents to the failure store. Wait for that failure store write to start:
+        if (readyToPerformFailureStoreWrite.await(30, TimeUnit.SECONDS) == false) {
+            // we're going to fail the test, but be a good citizen and unblock the other thread first
+            beginFailureStoreWrite.countDown();
+            fail("timed out waiting for failure store write operation to begin");
+        }
+
+        // Check to make sure there is no response yet
+        if (future.isDone()) {
+            // we're going to fail the test, but be a good citizen and unblock the other thread first
+            beginFailureStoreWrite.countDown();
+            fail("bulk operation completed prematurely");
+        }
+
+        // Operation is still correctly in flight. Allow the write operation to continue
+        beginFailureStoreWrite.countDown();
+
+        // Await final result and verify
+        BulkResponse bulkItemResponses = future.get();
+        assertThat(bulkItemResponses.hasFailures(), is(false));
+        BulkItemResponse failedItem = Arrays.stream(bulkItemResponses.getItems())
+            .filter(item -> item.getIndex().equals(ds2FailureStore1.getIndex().getName()))
+            .findFirst()
+            .orElseThrow(() -> new AssertionError("Could not find redirected item"));
+        assertThat(failedItem, is(notNullValue()));
+
+        verify(observer, times(1)).isTimedOut();
+        verify(observer, times(1)).waitForNextChange(any());
+    }
+
     /**
      * A bulk operation to a data stream with a failure store enabled may still partially fail if the cluster is experiencing a
      * non-retryable block when the redirected documents would be sent to the shard-level action.
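
The new testRetryableBlockAcceptsFailureStoreDocument coordinates with the asynchronous failure store write through a pair of CountDownLatches: the worker counts down one latch to signal that it has reached the gated operation, the test asserts that the bulk response is still in flight, and then counts down the second latch to release the worker. A minimal standalone sketch of that handshake pattern, with illustrative names only:

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class LatchHandshakeSketch {
    public static void main(String[] args) throws Exception {
        CountDownLatch ready = new CountDownLatch(1);   // worker -> test: "operation reached"
        CountDownLatch proceed = new CountDownLatch(1); // test -> worker: "carry on"

        ExecutorService executor = Executors.newSingleThreadExecutor();
        executor.execute(() -> {
            ready.countDown(); // tell the test thread we are parked at the gate
            try {
                if (proceed.await(30, TimeUnit.SECONDS) == false) {
                    throw new IllegalStateException("timed out waiting for the test to signal continuation");
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            System.out.println("performing the gated write");
        });

        ready.await();       // block until the worker reaches the gate
        // ... assertions about in-flight state would go here ...
        proceed.countDown(); // release the worker
        executor.shutdown();
    }
}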
@@ -515,7 +605,9 @@ public class BulkOperationTests extends ESTestCase {
         // Mock client that rejects all shard requests on the first shard in the backing index, and all requests to the only shard of
         // the failure store index.
         NodeClient client = getNodeClient(
-            failingShards(Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), () -> new MapperException("root cause")))
+            shardSpecificResponse(
+                Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), failWithException(() -> new MapperException("root cause")))
+            )
         );
 
         // Create a new cluster state that has a non-retryable cluster block on it
@@ -570,7 +662,9 @@ public class BulkOperationTests extends ESTestCase {
         // Mock client that rejects all shard requests on the first shard in the backing index, and all requests to the only shard of
         // the failure store index.
         NodeClient client = getNodeClient(
-            failingShards(Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), () -> new MapperException("root cause")))
+            shardSpecificResponse(
+                Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), failWithException(() -> new MapperException("root cause")))
+            )
         );
 
         // Create a new cluster state that has a retryable cluster block on it
@@ -633,7 +727,9 @@ public class BulkOperationTests extends ESTestCase {
         // Mock client that rejects all shard requests on the first shard in the backing index, and all requests to the only shard of
         // the failure store index.
         NodeClient client = getNodeClient(
-            failingShards(Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), () -> new MapperException("root cause")))
+            shardSpecificResponse(
+                Map.of(new ShardId(ds2BackingIndex1.getIndex(), 0), failWithException(() -> new MapperException("root cause")))
+            )
         );
 
         // Create a new cluster state that has a retryable cluster block on it
@@ -663,29 +759,74 @@ public class BulkOperationTests extends ESTestCase {
         verify(observer, times(1)).waitForNextChange(any());
     }
 
+    /**
+     * Throws an assertion error if the client operation executes
+     */
+    private static BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> assertNoClientInteraction() {
+        return (r, l) -> fail("Should not have executed shard action on blocked cluster");
+    }
+
     /**
      * Accepts all write operations from the given request object when it is encountered in the mock shard bulk action
      */
-    private BulkShardResponse acceptAllShardWrites(BulkShardRequest request) {
-        return new BulkShardResponse(
-            request.shardId(),
-            Arrays.stream(request.items()).map(item -> requestToResponse(request.shardId(), item)).toArray(BulkItemResponse[]::new)
-        );
+    private static BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> acceptAllShardWrites() {
+        return (BulkShardRequest request, ActionListener<BulkShardResponse> listener) -> {
+            listener.onResponse(
+                new BulkShardResponse(
+                    request.shardId(),
+                    Arrays.stream(request.items()).map(item -> requestToResponse(request.shardId(), item)).toArray(BulkItemResponse[]::new)
+                )
+            );
+        };
+    }
+
+    /**
+     * When the request is received, it is marked as failed with an exception created by the supplier
+     */
+    private BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> failWithException(Supplier<Exception> exceptionSupplier) {
+        return (BulkShardRequest request, ActionListener<BulkShardResponse> listener) -> { listener.onFailure(exceptionSupplier.get()); };
     }
 
     /**
-     * Maps an entire shard id to an exception to throw when it is encountered in the mock shard bulk action
+     * Maps an entire shard id to a consumer to invoke when it is encountered in the mock shard bulk action
      */
-    private CheckedFunction<BulkShardRequest, BulkShardResponse, Exception> failingShards(Map<ShardId, Supplier<Exception>> shardsToFail) {
-        return (BulkShardRequest request) -> {
-            if (shardsToFail.containsKey(request.shardId())) {
-                throw shardsToFail.get(request.shardId()).get();
+    private BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> shardSpecificResponse(
+        Map<ShardId, BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>>> shardsToResponders
+    ) {
+        return (BulkShardRequest request, ActionListener<BulkShardResponse> listener) -> {
+            if (shardsToResponders.containsKey(request.shardId())) {
+                shardsToResponders.get(request.shardId()).accept(request, listener);
             } else {
-                return acceptAllShardWrites(request);
+                acceptAllShardWrites().accept(request, listener);
             }
         };
     }
 
+    /**
+     * When the consumer is called, it goes async on the given executor. It will signal that it has reached the operation by counting down
+     * the readyLatch, then wait on the provided continueLatch before executing the delegate consumer.
+     */
+    private BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> goAsyncAndWait(
+        Executor executor,
+        CountDownLatch readyLatch,
+        CountDownLatch continueLatch,
+        BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> delegate
+    ) {
+        return (final BulkShardRequest request, final ActionListener<BulkShardResponse> listener) -> {
+            executor.execute(() -> {
+                try {
+                    readyLatch.countDown();
+                    if (continueLatch.await(30, TimeUnit.SECONDS) == false) {
+                        listener.onFailure(new RuntimeException("Timeout in client operation waiting for test to signal a continuation"));
+                    }
+                } catch (InterruptedException e) {
+                    listener.onFailure(new RuntimeException(e));
+                }
+                delegate.accept(request, listener);
+            });
+        };
+    }
+
     /**
      * Index name / id tuple
      */
@@ -694,17 +835,19 @@ public class BulkOperationTests extends ESTestCase {
     /**
     * Maps a document to an exception to be thrown when it is encountered in the mock shard bulk action
      */
-    private CheckedFunction<BulkShardRequest, BulkShardResponse, Exception> thatFailsDocuments(
+    private BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> thatFailsDocuments(
         Map<IndexAndId, Supplier<Exception>> documentsToFail
     ) {
-        return (BulkShardRequest request) -> new BulkShardResponse(request.shardId(), Arrays.stream(request.items()).map(item -> {
-            IndexAndId key = new IndexAndId(request.index(), item.request().id());
-            if (documentsToFail.containsKey(key)) {
-                return requestToFailedResponse(item, documentsToFail.get(key).get());
-            } else {
-                return requestToResponse(request.shardId(), item);
-            }
-        }).toArray(BulkItemResponse[]::new));
+        return (BulkShardRequest request, ActionListener<BulkShardResponse> listener) -> {
+            listener.onResponse(new BulkShardResponse(request.shardId(), Arrays.stream(request.items()).map(item -> {
+                IndexAndId key = new IndexAndId(request.index(), item.request().id());
+                if (documentsToFail.containsKey(key)) {
+                    return requestToFailedResponse(item, documentsToFail.get(key).get());
+                } else {
+                    return requestToResponse(request.shardId(), item);
+                }
+            }).toArray(BulkItemResponse[]::new)));
+        };
     }
 
     /**
@@ -734,7 +877,7 @@ public class BulkOperationTests extends ESTestCase {
      * @param onShardAction Called when TransportShardBulkAction is executed.
      * @return A node client for the test.
      */
-    private NodeClient getNodeClient(CheckedFunction<BulkShardRequest, BulkShardResponse, Exception> onShardAction) {
+    private NodeClient getNodeClient(BiConsumer<BulkShardRequest, ActionListener<BulkShardResponse>> onShardAction) {
         return new NoOpNodeClient(threadPool) {
             @Override
             @SuppressWarnings("unchecked")
@@ -744,17 +887,13 @@ public class BulkOperationTests extends ESTestCase {
                 ActionListener<Response> listener
             ) {
                 if (TransportShardBulkAction.TYPE.equals(action)) {
-                    Response response = null;
-                    Exception exception = null;
+                    ActionListener<BulkShardResponse> notifyOnceListener = ActionListener.notifyOnce(
+                        (ActionListener<BulkShardResponse>) listener
+                    );
                     try {
-                        response = (Response) onShardAction.apply((BulkShardRequest) request);
+                        onShardAction.accept((BulkShardRequest) request, notifyOnceListener);
                     } catch (Exception responseException) {
-                        exception = responseException;
-                    }
-                    if (response != null) {
-                        listener.onResponse(response);
-                    } else {
-                        listener.onFailure(exception);
+                        notifyOnceListener.onFailure(responseException);
                     }
                 } else {
                     fail("Unexpected client call to " + action.name());