Browse Source

Fail an unpromotable replica shard if its refresh fails (#95049)

In case a refresh for an unpromotable shard fails, we need to fail the search shard(s).

This is an alternative implementation of #94433 in which we fail the actual unpromotable replica shards in TransportBroadcastUnpromotableAction instead of failing the local indexing shard in PostWriteRefresh.
Artem Prigoda 2 years ago
parent
commit
e80ccdeabe

+ 3 - 2
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java

@@ -67,7 +67,7 @@ public class TransportShardRefreshAction extends TransportReplicationAction<
             ThreadPool.Names.REFRESH
         );
         // registers the unpromotable version of shard refresh action
-        new TransportUnpromotableShardRefreshAction(clusterService, transportService, actionFilters, indicesService);
+        new TransportUnpromotableShardRefreshAction(clusterService, transportService, shardStateAction, actionFilters, indicesService);
     }
 
     @Override
@@ -114,7 +114,8 @@ public class TransportShardRefreshAction extends TransportReplicationAction<
             assert replicaRequest.primaryRefreshResult.refreshed() : "primary has not refreshed";
             UnpromotableShardRefreshRequest unpromotableReplicaRequest = new UnpromotableShardRefreshRequest(
                 indexShardRoutingTable,
-                replicaRequest.primaryRefreshResult.generation()
+                replicaRequest.primaryRefreshResult.generation(),
+                false
             );
             transportService.sendRequest(
                 transportService.getLocalNode(),

+ 11 - 1
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java

@@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.broadcast.unpromotable.TransportBroadcastUnpromotableAction;
+import org.elasticsearch.cluster.action.shard.ShardStateAction;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.index.shard.IndexShard;
@@ -30,10 +31,19 @@ public class TransportUnpromotableShardRefreshAction extends TransportBroadcastU
     public TransportUnpromotableShardRefreshAction(
         ClusterService clusterService,
         TransportService transportService,
+        ShardStateAction shardStateAction,
         ActionFilters actionFilters,
         IndicesService indicesService
     ) {
-        super(NAME, clusterService, transportService, actionFilters, UnpromotableShardRefreshRequest::new, ThreadPool.Names.REFRESH);
+        super(
+            NAME,
+            clusterService,
+            transportService,
+            shardStateAction,
+            actionFilters,
+            UnpromotableShardRefreshRequest::new,
+            ThreadPool.Names.REFRESH
+        );
         this.indicesService = indicesService;
     }
 

+ 6 - 2
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/UnpromotableShardRefreshRequest.java

@@ -23,8 +23,12 @@ public class UnpromotableShardRefreshRequest extends BroadcastUnpromotableReques
 
     private final long segmentGeneration;
 
-    public UnpromotableShardRefreshRequest(IndexShardRoutingTable indexShardRoutingTable, long segmentGeneration) {
-        super(indexShardRoutingTable);
+    public UnpromotableShardRefreshRequest(
+        IndexShardRoutingTable indexShardRoutingTable,
+        long segmentGeneration,
+        boolean failShardOnError
+    ) {
+        super(indexShardRoutingTable, failShardOnError);
         this.segmentGeneration = segmentGeneration;
     }
 

+ 15 - 0
server/src/main/java/org/elasticsearch/action/support/broadcast/unpromotable/BroadcastUnpromotableRequest.java

@@ -8,6 +8,7 @@
 
 package org.elasticsearch.action.support.broadcast.unpromotable;
 
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.IndicesRequest;
@@ -38,18 +39,25 @@ public class BroadcastUnpromotableRequest extends ActionRequest implements Indic
 
     protected final ShardId shardId;
     protected final String[] indices;
+    protected final boolean failShardOnError;
 
     public BroadcastUnpromotableRequest(StreamInput in) throws IOException {
         super(in);
         indexShardRoutingTable = null;
         shardId = new ShardId(in);
         indices = new String[] { shardId.getIndex().getName() };
+        failShardOnError = in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0) && in.readBoolean(); // same gate as writeTo
     }
 
     public BroadcastUnpromotableRequest(IndexShardRoutingTable indexShardRoutingTable) {
+        this(indexShardRoutingTable, false);
+    }
+
+    public BroadcastUnpromotableRequest(IndexShardRoutingTable indexShardRoutingTable, boolean failShardOnError) {
         this.indexShardRoutingTable = Objects.requireNonNull(indexShardRoutingTable, "index shard routing table is null");
         this.shardId = indexShardRoutingTable.shardId();
         this.indices = new String[] { this.shardId.getIndex().getName() };
+        this.failShardOnError = failShardOnError;
     }
 
     public ShardId shardId() {
@@ -69,6 +77,9 @@ public class BroadcastUnpromotableRequest extends ActionRequest implements Indic
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
         out.writeWriteable(shardId);
+        if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
+            out.writeBoolean(failShardOnError);
+        }
     }
 
     @Override
@@ -86,6 +97,10 @@ public class BroadcastUnpromotableRequest extends ActionRequest implements Indic
         return indices;
     }
 
+    public boolean failShardOnError() {
+        return failShardOnError;
+    }
+
     @Override
     public IndicesOptions indicesOptions() {
         return strictSingleIndexNoExpandForbidClosed();

+ 34 - 1
server/src/main/java/org/elasticsearch/action/support/broadcast/unpromotable/TransportBroadcastUnpromotableAction.java

@@ -16,9 +16,12 @@ import org.elasticsearch.action.support.ChannelActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.action.shard.ShardStateAction;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Strings;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportRequestHandler;
@@ -32,6 +35,7 @@ public abstract class TransportBroadcastUnpromotableAction<Request extends Broad
 
     protected final ClusterService clusterService;
     protected final TransportService transportService;
+    protected final ShardStateAction shardStateAction;
 
     protected final String transportUnpromotableAction;
     protected final String executor;
@@ -40,12 +44,14 @@ public abstract class TransportBroadcastUnpromotableAction<Request extends Broad
         String actionName,
         ClusterService clusterService,
         TransportService transportService,
+        ShardStateAction shardStateAction,
         ActionFilters actionFilters,
         Writeable.Reader<Request> requestReader,
         String executor
     ) {
         super(actionName, transportService, actionFilters, requestReader);
         this.clusterService = clusterService;
+        this.shardStateAction = shardStateAction;
         this.transportService = transportService;
         this.transportUnpromotableAction = actionName + "[u]";
         this.executor = executor;
@@ -65,13 +71,16 @@ public abstract class TransportBroadcastUnpromotableAction<Request extends Broad
                 }
                 request.indexShardRoutingTable.unpromotableShards().forEach(shardRouting -> {
                     final DiscoveryNode node = clusterState.nodes().get(shardRouting.currentNodeId());
+                    final ActionListener<TransportResponse.Empty> acquired = listeners.acquire(ignored -> {});
                     transportService.sendRequest(
                         node,
                         transportUnpromotableAction,
                         request,
                         TransportRequestOptions.EMPTY,
                         new ActionListenerResponseHandler<>(
-                            listeners.acquire(ignored -> {}),
+                            request.failShardOnError()
+                                ? acquired.delegateResponse((l, e) -> failShard(shardRouting, clusterState, l, e))
+                                : acquired,
                             (in) -> TransportResponse.Empty.INSTANCE,
                             executor
                         )
@@ -82,6 +91,30 @@ public abstract class TransportBroadcastUnpromotableAction<Request extends Broad
         }
     }
 
+    private void failShard(ShardRouting shardRouting, ClusterState clusterState, ActionListener<TransportResponse.Empty> l, Exception e) {
+        shardStateAction.remoteShardFailed(
+            shardRouting.shardId(),
+            shardRouting.allocationId().getId(),
+            clusterState.metadata().index(shardRouting.getIndexName()).primaryTerm(shardRouting.shardId().getId()),
+            true,
+            "mark unpromotable copy as stale after refresh failure",
+            e,
+            new ActionListener<>() {
+                @Override
+                public void onResponse(Void unused) {
+                    logger.debug("Marked shard {} as failed", shardRouting.shardId());
+                    l.onResponse(TransportResponse.Empty.INSTANCE);
+                }
+
+                @Override
+                public void onFailure(Exception sfe) {
+                    logger.error(Strings.format("Unable to mark shard [%s] as failed", shardRouting.shardId()), sfe);
+                    l.onFailure(e);
+                }
+            }
+        );
+    }
+
     class UnpromotableTransportHandler implements TransportRequestHandler<Request> {
 
         @Override

+ 2 - 1
server/src/main/java/org/elasticsearch/action/support/replication/PostWriteRefresh.java

@@ -138,7 +138,8 @@ public class PostWriteRefresh {
     ) {
         UnpromotableShardRefreshRequest unpromotableReplicaRequest = new UnpromotableShardRefreshRequest(
             indexShard.getReplicationGroup().getRoutingTable(),
-            generation
+            generation,
+            true
         );
         transportService.sendRequest(
             transportService.getLocalNode(),

+ 73 - 16
server/src/test/java/org/elasticsearch/action/support/broadcast/unpromotable/TransportBroadcastUnpromotableActionTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.ActionTestUtils;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.action.shard.ShardStateAction;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
@@ -33,6 +34,7 @@ import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;
+import org.mockito.Mockito;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -51,6 +53,12 @@ import static org.elasticsearch.test.ClusterServiceUtils.setState;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
 
 public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
 
@@ -59,6 +67,7 @@ public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
     private TransportService transportService;
     private CapturingTransport transport;
     private TestTransportBroadcastUnpromotableAction broadcastUnpromotableAction;
+    private ShardStateAction shardStateAction;
 
     @BeforeClass
     public static void beforeClass() {
@@ -81,7 +90,16 @@ public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
         );
         transportService.start();
         transportService.acceptIncomingRequests();
-        broadcastUnpromotableAction = new TestTransportBroadcastUnpromotableAction();
+
+        shardStateAction = mock(ShardStateAction.class);
+        Mockito.doAnswer(invocation -> {
+            ActionListener<Void> argument = invocation.getArgument(6);
+            argument.onResponse(null);
+            return null;
+        })
+            .when(shardStateAction)
+            .remoteShardFailed(any(ShardId.class), anyString(), anyLong(), anyBoolean(), anyString(), any(Exception.class), any());
+        broadcastUnpromotableAction = new TestTransportBroadcastUnpromotableAction(shardStateAction);
     }
 
     @Override
@@ -99,11 +117,12 @@ public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
 
     private class TestTransportBroadcastUnpromotableAction extends TransportBroadcastUnpromotableAction<TestBroadcastUnpromotableRequest> {
 
-        TestTransportBroadcastUnpromotableAction() {
+        TestTransportBroadcastUnpromotableAction(ShardStateAction shardStateAction) {
             super(
                 "indices:admin/test",
                 TransportBroadcastUnpromotableActionTests.this.clusterService,
                 TransportBroadcastUnpromotableActionTests.this.transportService,
+                shardStateAction,
                 new ActionFilters(Set.of()),
                 TestBroadcastUnpromotableRequest::new,
                 ThreadPool.Names.SAME
@@ -131,6 +150,9 @@ public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
             super(indexShardRoutingTable);
         }
 
+        TestBroadcastUnpromotableRequest(IndexShardRoutingTable indexShardRoutingTable, boolean failShardOnError) {
+            super(indexShardRoutingTable, failShardOnError);
+        }
     }
 
     private static List<ShardRouting.Role> getReplicaRoles(int numPromotableReplicas, int numSearchReplicas) {
@@ -254,7 +276,7 @@ public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
         assertThat(countRequestsForIndex(state, index), is(equalTo(numReachableUnpromotables)));
     }
 
-    public void testInvalidNodes() {
+    public void testInvalidNodes() throws Exception {
         final String index = "test";
         ClusterState state = stateWithAssignedPrimariesAndReplicas(
             new String[] { index },
@@ -284,24 +306,59 @@ public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
 
         PlainActionFuture<ActionResponse.Empty> response = PlainActionFuture.newFuture();
         logger.debug("--> executing for wrong shard routing table: {}", wrongRoutingTable);
+
+        // The request fails if we don't mark shards as stale
         assertThat(
-            expectThrows(
-                NodeNotConnectedException.class,
-                () -> PlainActionFuture.<ActionResponse.Empty, Exception>get(
-                    f -> ActionTestUtils.execute(
-                        broadcastUnpromotableAction,
-                        null,
-                        new TestBroadcastUnpromotableRequest(wrongRoutingTable),
-                        f
-                    ),
-                    10,
-                    TimeUnit.SECONDS
-                )
-            ).toString(),
+            expectThrows(NodeNotConnectedException.class, () -> brodcastUnpromotableRequest(wrongRoutingTable, false)).toString(),
+            containsString("discovery node must not be null")
+        );
+        Mockito.verifyNoInteractions(shardStateAction);
+
+        // We were able to mark shards as stale, so the request finishes successfully
+        assertThat(brodcastUnpromotableRequest(wrongRoutingTable, true), equalTo(ActionResponse.Empty.INSTANCE));
+        for (var shardRouting : wrongRoutingTable.unpromotableShards()) {
+            Mockito.verify(shardStateAction)
+                .remoteShardFailed(
+                    eq(shardRouting.shardId()),
+                    eq(shardRouting.allocationId().getId()),
+                    eq(state.metadata().index(index).primaryTerm(shardRouting.shardId().getId())),
+                    eq(true),
+                    eq("mark unpromotable copy as stale after refresh failure"),
+                    any(Exception.class),
+                    any()
+                );
+        }
+
+        Mockito.reset(shardStateAction);
+        // If we are unable to mark a shard as stale, then the request fails
+        Mockito.doAnswer(invocation -> {
+            Exception exception = invocation.getArgument(5);
+            ActionListener<Void> argument = invocation.getArgument(6);
+            argument.onFailure(exception);
+            return null;
+        })
+            .when(shardStateAction)
+            .remoteShardFailed(any(ShardId.class), anyString(), anyLong(), anyBoolean(), anyString(), any(Exception.class), any());
+        assertThat(
+            expectThrows(NodeNotConnectedException.class, () -> brodcastUnpromotableRequest(wrongRoutingTable, true)).toString(),
             containsString("discovery node must not be null")
         );
     }
 
+    private ActionResponse brodcastUnpromotableRequest(IndexShardRoutingTable wrongRoutingTable, boolean failShardOnError)
+        throws Exception {
+        return PlainActionFuture.<ActionResponse.Empty, Exception>get(
+            f -> ActionTestUtils.execute(
+                broadcastUnpromotableAction,
+                null,
+                new TestBroadcastUnpromotableRequest(wrongRoutingTable, failShardOnError),
+                f
+            ),
+            10,
+            TimeUnit.SECONDS
+        );
+    }
+
     public void testNullIndexShardRoutingTable() {
         PlainActionFuture<ActionResponse.Empty> response = PlainActionFuture.newFuture();
         IndexShardRoutingTable shardRoutingTable = null;