Explorar o código

Merge pull request #15748 from jasontedor/shard-failure-no-master-retry

Wait for new master when failing shard

Relates #14252
Jason Tedor hai 9 anos
pai
achega
69b21feb3b

+ 23 - 41
core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java

@@ -64,7 +64,6 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.BaseTransportResponseHandler;
 import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.EmptyTransportResponseHandler;
-import org.elasticsearch.transport.ReceiveTimeoutTransportException;
 import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportChannelResponseHandler;
 import org.elasticsearch.transport.TransportException;
@@ -76,6 +75,7 @@ import org.elasticsearch.transport.TransportService;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -92,8 +92,6 @@ import java.util.function.Supplier;
  */
 public abstract class TransportReplicationAction<Request extends ReplicationRequest, ReplicaRequest extends ReplicationRequest, Response extends ReplicationResponse> extends TransportAction<Request, Response> {
 
-    public static final String SHARD_FAILURE_TIMEOUT = "action.support.replication.shard.failure_timeout";
-
     protected final TransportService transportService;
     protected final ClusterService clusterService;
     protected final IndicesService indicesService;
@@ -101,7 +99,6 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
     protected final WriteConsistencyLevel defaultWriteConsistencyLevel;
     protected final TransportRequestOptions transportOptions;
     protected final MappingUpdatedAction mappingUpdatedAction;
-    private final TimeValue shardFailedTimeout;
 
     final String transportReplicaAction;
     final String transportPrimaryAction;
@@ -133,8 +130,6 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         this.transportOptions = transportOptions();
 
         this.defaultWriteConsistencyLevel = WriteConsistencyLevel.fromString(settings.get("action.write_consistency", "quorum"));
-        // TODO: set a default timeout
-        shardFailedTimeout = settings.getAsTime(SHARD_FAILURE_TIMEOUT, null);
     }
 
     @Override
@@ -608,7 +603,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                 if (logger.isTraceEnabled()) {
                     logger.trace("action [{}] completed on shard [{}] for request [{}] with cluster state version [{}]", transportPrimaryAction, shardId, request, state.version());
                 }
-                replicationPhase = new ReplicationPhase(primaryResponse.v2(), primaryResponse.v1(), shardId, channel, indexShardReference, shardFailedTimeout);
+                replicationPhase = new ReplicationPhase(primaryResponse.v2(), primaryResponse.v1(), shardId, channel, indexShardReference);
             } catch (Throwable e) {
                 if (ExceptionsHelper.status(e) == RestStatus.CONFLICT) {
                     if (logger.isTraceEnabled()) {
@@ -732,15 +727,13 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         private final AtomicInteger pending;
         private final int totalShards;
         private final Releasable indexShardReference;
-        private final TimeValue shardFailedTimeout;
 
         public ReplicationPhase(ReplicaRequest replicaRequest, Response finalResponse, ShardId shardId,
-                                TransportChannel channel, Releasable indexShardReference, TimeValue shardFailedTimeout) {
+                                TransportChannel channel, Releasable indexShardReference) {
             this.replicaRequest = replicaRequest;
             this.channel = channel;
             this.finalResponse = finalResponse;
             this.indexShardReference = indexShardReference;
-            this.shardFailedTimeout = shardFailedTimeout;
             this.shardId = shardId;
 
             // we have to get a new state after successfully indexing into the primary in order to honour recovery semantics.
@@ -882,15 +875,32 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                             if (ignoreReplicaException(exp)) {
                                 onReplicaFailure(nodeId, exp);
                             } else {
-                                logger.warn("{} failed to perform {} on node {}", exp, shardId, transportReplicaAction, node);
-                                shardStateAction.shardFailed(clusterService.state(), shard, indexUUID, "failed to perform " + transportReplicaAction + " on replica on node " + node, exp, shardFailedTimeout, new ReplicationFailedShardStateListener(nodeId, exp));
+                                String message = String.format(Locale.ROOT, "failed to perform %s on replica on node %s", transportReplicaAction, node);
+                                logger.warn("{} {}", exp, shardId, message);
+                                shardStateAction.shardFailed(
+                                    shard,
+                                    indexUUID,
+                                    message,
+                                    exp,
+                                    new ShardStateAction.Listener() {
+                                        @Override
+                                        public void onSuccess() {
+                                            onReplicaFailure(nodeId, exp);
+                                        }
+
+                                        @Override
+                                        public void onShardFailedFailure(Exception e) {
+                                            // TODO: handle catastrophic non-channel failures
+                                            onReplicaFailure(nodeId, exp);
+                                        }
+                                    }
+                                );
                             }
                         }
                     }
             );
         }
 
-
         void onReplicaFailure(String nodeId, @Nullable Throwable e) {
             // Only version conflict should be ignored from being put into the _shards header?
             if (e != null && ignoreReplicaException(e) == false) {
@@ -955,34 +965,6 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                 }
             }
         }
-
-        public class ReplicationFailedShardStateListener implements ShardStateAction.Listener {
-            private final String nodeId;
-            private Throwable failure;
-
-            public ReplicationFailedShardStateListener(String nodeId, Throwable failure) {
-                this.nodeId = nodeId;
-                this.failure = failure;
-            }
-
-            @Override
-            public void onSuccess() {
-                onReplicaFailure(nodeId, failure);
-            }
-
-            @Override
-            public void onShardFailedNoMaster() {
-                onReplicaFailure(nodeId, failure);
-            }
-
-            @Override
-            public void onShardFailedFailure(DiscoveryNode master, TransportException e) {
-                if (e instanceof ReceiveTimeoutTransportException) {
-                    logger.trace("timeout sending shard failure to master [{}]", e, master);
-                }
-                onReplicaFailure(nodeId, failure);
-            }
-        }
     }
 
     /**

+ 89 - 30
core/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java

@@ -22,9 +22,11 @@ package org.elasticsearch.cluster.action.shard;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.cluster.ClusterService;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateObserver;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterStateTaskListener;
+import org.elasticsearch.cluster.MasterNodeChangePredicate;
 import org.elasticsearch.cluster.NotMasterException;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -42,73 +44,118 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.logging.ESLogger;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.discovery.Discovery;
+import org.elasticsearch.node.NodeClosedException;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.EmptyTransportResponseHandler;
+import org.elasticsearch.transport.NodeDisconnectedException;
 import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportRequest;
 import org.elasticsearch.transport.TransportRequestHandler;
-import org.elasticsearch.transport.TransportRequestOptions;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Set;
 
 import static org.elasticsearch.cluster.routing.ShardRouting.readShardRoutingEntry;
 
 public class ShardStateAction extends AbstractComponent {
+
     public static final String SHARD_STARTED_ACTION_NAME = "internal:cluster/shard/started";
     public static final String SHARD_FAILED_ACTION_NAME = "internal:cluster/shard/failure";
 
     private final TransportService transportService;
+    private final ClusterService clusterService;
 
     @Inject
     public ShardStateAction(Settings settings, ClusterService clusterService, TransportService transportService,
                             AllocationService allocationService, RoutingService routingService) {
         super(settings);
         this.transportService = transportService;
+        this.clusterService = clusterService;
 
         transportService.registerRequestHandler(SHARD_STARTED_ACTION_NAME, ShardRoutingEntry::new, ThreadPool.Names.SAME, new ShardStartedTransportHandler(clusterService, new ShardStartedClusterStateTaskExecutor(allocationService, logger), logger));
         transportService.registerRequestHandler(SHARD_FAILED_ACTION_NAME, ShardRoutingEntry::new, ThreadPool.Names.SAME, new ShardFailedTransportHandler(clusterService, new ShardFailedClusterStateTaskExecutor(allocationService, routingService, logger), logger));
     }
 
-    public void shardFailed(final ClusterState clusterState, final ShardRouting shardRouting, final String indexUUID, final String message, @Nullable final Throwable failure, Listener listener) {
-        shardFailed(clusterState, shardRouting, indexUUID, message, failure, null, listener);
+    public void shardFailed(final ShardRouting shardRouting, final String indexUUID, final String message, @Nullable final Throwable failure, Listener listener) {
+        ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger);
+        ShardRoutingEntry shardRoutingEntry = new ShardRoutingEntry(shardRouting, indexUUID, message, failure);
+        sendShardFailed(observer, shardRoutingEntry, listener);
     }
 
-    public void resendShardFailed(final ClusterState clusterState, final ShardRouting shardRouting, final String indexUUID, final String message, @Nullable final Throwable failure, Listener listener) {
+    public void resendShardFailed(final ShardRouting shardRouting, final String indexUUID, final String message, @Nullable final Throwable failure, Listener listener) {
         logger.trace("{} re-sending failed shard [{}], index UUID [{}], reason [{}]", shardRouting.shardId(), failure, shardRouting, indexUUID, message);
-        shardFailed(clusterState, shardRouting, indexUUID, message, failure, listener);
+        shardFailed(shardRouting, indexUUID, message, failure, listener);
     }
 
-    public void shardFailed(final ClusterState clusterState, final ShardRouting shardRouting, final String indexUUID, final String message, @Nullable final Throwable failure, TimeValue timeout, Listener listener) {
-        DiscoveryNode masterNode = clusterState.nodes().masterNode();
+    private void sendShardFailed(ClusterStateObserver observer, ShardRoutingEntry shardRoutingEntry, Listener listener) {
+        DiscoveryNode masterNode = observer.observedState().nodes().masterNode();
         if (masterNode == null) {
-            logger.warn("{} no master known to fail shard [{}]", shardRouting.shardId(), shardRouting);
-            listener.onShardFailedNoMaster();
-            return;
-        }
-        ShardRoutingEntry shardRoutingEntry = new ShardRoutingEntry(shardRouting, indexUUID, message, failure);
-        TransportRequestOptions options = TransportRequestOptions.EMPTY;
-        if (timeout != null) {
-            options = TransportRequestOptions.builder().withTimeout(timeout).build();
+            logger.warn("{} no master known to fail shard [{}]", shardRoutingEntry.getShardRouting().shardId(), shardRoutingEntry.getShardRouting());
+            waitForNewMasterAndRetry(observer, shardRoutingEntry, listener);
+        } else {
+            transportService.sendRequest(masterNode,
+                SHARD_FAILED_ACTION_NAME, shardRoutingEntry, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
+                    @Override
+                    public void handleResponse(TransportResponse.Empty response) {
+                        listener.onSuccess();
+                    }
+
+                    @Override
+                    public void handleException(TransportException exp) {
+                        assert exp.getCause() != null : exp;
+                        if (isMasterChannelException(exp.getCause())) {
+                            waitForNewMasterAndRetry(observer, shardRoutingEntry, listener);
+                        } else {
+                            logger.warn("{} unexpected failure while sending request to [{}] to fail shard [{}]", exp, shardRoutingEntry.getShardRouting().shardId(), masterNode, shardRoutingEntry);
+                            listener.onShardFailedFailure(exp);
+                        }
+                    }
+                });
         }
-        transportService.sendRequest(masterNode,
-            SHARD_FAILED_ACTION_NAME, shardRoutingEntry, options, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
-                @Override
-                public void handleResponse(TransportResponse.Empty response) {
-                    listener.onSuccess();
-                }
+    }
 
-                @Override
-                public void handleException(TransportException exp) {
-                    logger.warn("{} unexpected failure while sending request to [{}] to fail shard [{}]", exp, shardRoutingEntry.shardRouting.shardId(), masterNode, shardRoutingEntry);
-                    listener.onShardFailedFailure(masterNode, exp);
+    private static Set<Class<?>> MASTER_CHANNEL_EXCEPTIONS =
+        new HashSet<>(Arrays.asList(
+            NotMasterException.class,
+            NodeDisconnectedException.class,
+            Discovery.FailedToCommitClusterStateException.class
+        ));
+    private static boolean isMasterChannelException(Throwable cause) {
+        return MASTER_CHANNEL_EXCEPTIONS.contains(cause.getClass());
+    }
+
+    // visible for testing
+    protected void waitForNewMasterAndRetry(ClusterStateObserver observer, ShardRoutingEntry shardRoutingEntry, Listener listener) {
+        observer.waitForNextChange(new ClusterStateObserver.Listener() {
+            @Override
+            public void onNewClusterState(ClusterState state) {
+                if (logger.isTraceEnabled()) {
+                    logger.trace("new cluster state [{}] after waiting for master election to fail shard [{}]", shardRoutingEntry.getShardRouting().shardId(), state.prettyPrint(), shardRoutingEntry);
                 }
-            });
+                sendShardFailed(observer, shardRoutingEntry, listener);
+            }
+
+            @Override
+            public void onClusterServiceClose() {
+                logger.warn("{} node closed while handling failed shard [{}]", shardRoutingEntry.failure, shardRoutingEntry.getShardRouting().getId(), shardRoutingEntry.getShardRouting());
+                listener.onShardFailedFailure(new NodeClosedException(clusterService.localNode()));
+            }
+
+            @Override
+            public void onTimeout(TimeValue timeout) {
+                // we wait indefinitely for a new master
+                assert false;
+            }
+        }, MasterNodeChangePredicate.INSTANCE);
     }
 
     private static class ShardFailedTransportHandler implements TransportRequestHandler<ShardRoutingEntry> {
@@ -334,10 +381,22 @@ public class ShardStateAction extends AbstractComponent {
         default void onSuccess() {
         }
 
-        default void onShardFailedNoMaster() {
-        }
-
-        default void onShardFailedFailure(final DiscoveryNode master, final TransportException e) {
+        /**
+         * Notification for non-channel exceptions that are not handled
+         * by {@link ShardStateAction}.
+         *
+         * The exceptions that are handled by {@link ShardStateAction}
+         * are:
+         *  - {@link NotMasterException}
+         *  - {@link NodeDisconnectedException}
+         *  - {@link Discovery.FailedToCommitClusterStateException}
+         *
+         * Any other exception is communicated to the requester via
+         * this notification.
+         *
+         * @param e the unexpected cause of the failure on the master
+         */
+        default void onShardFailedFailure(final Exception e) {
         }
     }
 }

+ 3 - 3
core/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java

@@ -458,7 +458,7 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent<Indic
             if (!indexService.hasShard(shardId) && shardRouting.started()) {
                 if (failedShards.containsKey(shardRouting.shardId())) {
                     if (nodes.masterNode() != null) {
-                        shardStateAction.resendShardFailed(event.state(), shardRouting, indexMetaData.getIndexUUID(),
+                        shardStateAction.resendShardFailed(shardRouting, indexMetaData.getIndexUUID(),
                                 "master " + nodes.masterNode() + " marked shard as started, but shard has previous failed. resending shard failure.", null, SHARD_STATE_ACTION_LISTENER);
                     }
                 } else {
@@ -590,7 +590,7 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent<Indic
         if (!indexService.hasShard(shardId)) {
             if (failedShards.containsKey(shardRouting.shardId())) {
                 if (nodes.masterNode() != null) {
-                    shardStateAction.resendShardFailed(state, shardRouting, indexMetaData.getIndexUUID(),
+                    shardStateAction.resendShardFailed(shardRouting, indexMetaData.getIndexUUID(),
                             "master " + nodes.masterNode() + " marked shard as initializing, but shard is marked as failed, resend shard failure", null, SHARD_STATE_ACTION_LISTENER);
                 }
                 return;
@@ -788,7 +788,7 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent<Indic
         try {
             logger.warn("[{}] marking and sending shard failed due to [{}]", failure, shardRouting.shardId(), message);
             failedShards.put(shardRouting.shardId(), new FailedShard(shardRouting.version()));
-            shardStateAction.shardFailed(clusterService.state(), shardRouting, indexUUID, message, failure, SHARD_STATE_ACTION_LISTENER);
+            shardStateAction.shardFailed(shardRouting, indexUUID, message, failure, SHARD_STATE_ACTION_LISTENER);
         } catch (Throwable e1) {
             logger.warn("[{}][{}] failed to mark shard as failed (because of [{}])", e1, shardRouting.getIndex(), shardRouting.getId(), message);
         }

+ 19 - 2
core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

@@ -28,6 +28,7 @@ import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterService;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.NotMasterException;
 import org.elasticsearch.cluster.action.shard.ShardStateAction;
 import org.elasticsearch.cluster.block.ClusterBlock;
 import org.elasticsearch.cluster.block.ClusterBlockException;
@@ -488,7 +489,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         TransportReplicationAction<Request, Request, Response>.ReplicationPhase replicationPhase =
                 action.new ReplicationPhase(request,
                         new Response(),
-                        request.shardId(), createTransportChannel(listener), reference, null);
+                        request.shardId(), createTransportChannel(listener), reference);
 
         assertThat(replicationPhase.totalShards(), equalTo(totalShards));
         assertThat(replicationPhase.pending(), equalTo(assignedReplicas));
@@ -557,7 +558,23 @@ public class TransportReplicationActionTests extends ESTestCase {
                     // the shard the request was sent to and the shard to be failed should be the same
                     assertEquals(shardRoutingEntry.getShardRouting(), routing);
                     failures.add(shardFailedRequest);
-                    transport.handleResponse(shardFailedRequest.requestId, TransportResponse.Empty.INSTANCE);
+                    if (randomBoolean()) {
+                        // simulate master left and test that the shard failure is retried
+                        int numberOfRetries = randomIntBetween(1, 4);
+                        CapturingTransport.CapturedRequest currentRequest = shardFailedRequest;
+                        for (int retryNumber = 0; retryNumber < numberOfRetries; retryNumber++) {
+                            // force a new cluster state to simulate a new master having been elected
+                            clusterService.setState(ClusterState.builder(clusterService.state()));
+                            transport.handleResponse(currentRequest.requestId, new NotMasterException("shard-failed-test"));
+                            CapturingTransport.CapturedRequest[] retryRequests = transport.getCapturedRequestsAndClear();
+                            assertEquals(1, retryRequests.length);
+                            currentRequest = retryRequests[0];
+                        }
+                        // now simulate that the last retry succeeded
+                        transport.handleResponse(currentRequest.requestId, TransportResponse.Empty.INSTANCE);
+                    } else {
+                        transport.handleResponse(shardFailedRequest.requestId, TransportResponse.Empty.INSTANCE);
+                    }
                 }
             } else {
                 successful++;

+ 182 - 44
core/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java

@@ -20,41 +20,78 @@
 package org.elasticsearch.cluster.action.shard;
 
 import org.apache.lucene.index.CorruptIndexException;
+import org.elasticsearch.cluster.ClusterService;
 import org.elasticsearch.cluster.ClusterState;
-import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.ClusterStateObserver;
+import org.elasticsearch.cluster.NotMasterException;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
+import org.elasticsearch.cluster.routing.RoutingService;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardsIterator;
+import org.elasticsearch.cluster.routing.allocation.AllocationService;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.discovery.Discovery;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.cluster.TestClusterService;
 import org.elasticsearch.test.transport.CapturingTransport;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.ReceiveTimeoutTransportException;
+import org.elasticsearch.transport.NodeDisconnectedException;
 import org.elasticsearch.transport.TransportException;
+import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
 import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.LongConsumer;
 
 import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithStartedPrimary;
 import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.instanceOf;
+import static org.hamcrest.Matchers.is;
 
 public class ShardStateActionTests extends ESTestCase {
     private static ThreadPool THREAD_POOL;
 
-    private ShardStateAction shardStateAction;
+    private TestShardStateAction shardStateAction;
     private CapturingTransport transport;
     private TransportService transportService;
     private TestClusterService clusterService;
 
+    private static class TestShardStateAction extends ShardStateAction {
+        public TestShardStateAction(Settings settings, ClusterService clusterService, TransportService transportService, AllocationService allocationService, RoutingService routingService) {
+            super(settings, clusterService, transportService, allocationService, routingService);
+        }
+
+        private Runnable onBeforeWaitForNewMasterAndRetry;
+
+        public void setOnBeforeWaitForNewMasterAndRetry(Runnable onBeforeWaitForNewMasterAndRetry) {
+            this.onBeforeWaitForNewMasterAndRetry = onBeforeWaitForNewMasterAndRetry;
+        }
+
+        private Runnable onAfterWaitForNewMasterAndRetry;
+
+        public void setOnAfterWaitForNewMasterAndRetry(Runnable onAfterWaitForNewMasterAndRetry) {
+            this.onAfterWaitForNewMasterAndRetry = onAfterWaitForNewMasterAndRetry;
+        }
+
+        @Override
+        protected void waitForNewMasterAndRetry(ClusterStateObserver observer, ShardRoutingEntry shardRoutingEntry, Listener listener) {
+            onBeforeWaitForNewMasterAndRetry.run();
+            super.waitForNewMasterAndRetry(observer, shardRoutingEntry, listener);
+            onAfterWaitForNewMasterAndRetry.run();
+        }
+    }
+
     @BeforeClass
     public static void startThreadPool() {
         THREAD_POOL = new ThreadPool("ShardStateActionTest");
@@ -68,7 +105,9 @@ public class ShardStateActionTests extends ESTestCase {
         clusterService = new TestClusterService(THREAD_POOL);
         transportService = new TransportService(transport, THREAD_POOL);
         transportService.start();
-        shardStateAction = new ShardStateAction(Settings.EMPTY, clusterService, transportService, null, null);
+        shardStateAction = new TestShardStateAction(Settings.EMPTY, clusterService, transportService, null, null);
+        shardStateAction.setOnBeforeWaitForNewMasterAndRetry(() -> {});
+        shardStateAction.setOnAfterWaitForNewMasterAndRetry(() -> {});
     }
 
     @Override
@@ -84,94 +123,165 @@ public class ShardStateActionTests extends ESTestCase {
         THREAD_POOL = null;
     }
 
-    public void testNoMaster() {
+    public void testSuccess() throws InterruptedException {
         // Fails a random shard, checks that exactly one shard-failed request for that shard was sent
         // to the current master, answers it with an empty response and expects onSuccess to fire.
         final String index = "test";
 
         clusterService.setState(stateWithStartedPrimary(index, true, randomInt(5)));
 
-        DiscoveryNodes.Builder builder = DiscoveryNodes.builder(clusterService.state().nodes());
-        builder.masterNodeId(null);
-        clusterService.setState(ClusterState.builder(clusterService.state()).nodes(builder));
-
         String indexUUID = clusterService.state().metaData().index(index).getIndexUUID();
 
-        AtomicBoolean noMaster = new AtomicBoolean();
-        assert !noMaster.get();
+        AtomicBoolean success = new AtomicBoolean();
+        CountDownLatch latch = new CountDownLatch(1);
 
-        shardStateAction.shardFailed(clusterService.state(), getRandomShardRouting(index), indexUUID, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
+        ShardRouting shardRouting = getRandomShardRouting(index);
+        shardStateAction.shardFailed(shardRouting, indexUUID, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
             @Override
-            public void onShardFailedNoMaster() {
-                noMaster.set(true);
+            public void onSuccess() {
+                success.set(true);
+                latch.countDown();
             }
 
             @Override
-            public void onShardFailedFailure(DiscoveryNode master, TransportException e) {
-
+            public void onShardFailedFailure(Exception e) {
                 // The failure callback must not fire here. The latch is released before the
                 // assert so that the awaiting test thread is unblocked and fails on
                 // assertTrue(success.get()) instead of waiting indefinitely.
+                success.set(false);
+                latch.countDown();
+                assert false;
             }
         });
 
-        assertTrue(noMaster.get());
+        CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
+        assertEquals(1, capturedRequests.length);
+        // the request is a shard failed request
+        assertThat(capturedRequests[0].request, is(instanceOf(ShardStateAction.ShardRoutingEntry.class)));
+        ShardStateAction.ShardRoutingEntry shardRoutingEntry = (ShardStateAction.ShardRoutingEntry)capturedRequests[0].request;
+        // for the right shard
+        assertEquals(shardRouting, shardRoutingEntry.getShardRouting());
+        // sent to the master
+        assertEquals(clusterService.state().nodes().masterNode().getId(), capturedRequests[0].node.getId());
+
         // deliver an empty response for the captured request; this is what should trigger onSuccess
+        transport.handleResponse(capturedRequests[0].requestId, TransportResponse.Empty.INSTANCE);
+
         // no timeout: the test blocks here until one of the listener callbacks counts the latch down
+        latch.await();
+        assertTrue(success.get());
     }
 
-    public void testFailure() {
+    public void testNoMaster() throws InterruptedException {
         // Starts from a cluster state whose master node id has been cleared, fails a random shard and
         // expects exactly one retry (counted by the hook installed below) followed by onSuccess.
         final String index = "test";
 
         clusterService.setState(stateWithStartedPrimary(index, true, randomInt(5)));
 
         // publish a copy of the current nodes with the master node id set to null
+        DiscoveryNodes.Builder noMasterBuilder = DiscoveryNodes.builder(clusterService.state().nodes());
+        noMasterBuilder.masterNodeId(null);
+        clusterService.setState(ClusterState.builder(clusterService.state()).nodes(noMasterBuilder));
+
         String indexUUID = clusterService.state().metaData().index(index).getIndexUUID();
 
-        AtomicBoolean failure = new AtomicBoolean();
-        assert !failure.get();
+        CountDownLatch latch = new CountDownLatch(1);
+        AtomicInteger retries = new AtomicInteger();
+        AtomicBoolean success = new AtomicBoolean();
 
-        shardStateAction.shardFailed(clusterService.state(), getRandomShardRouting(index), indexUUID, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
-            @Override
-            public void onShardFailedNoMaster() {
         // one retry is expected; the retry loop is a no-op because the first retry already gets
         // answered with an empty response by verifyRetry
+        setUpMasterRetryVerification(1, retries, latch, requestId -> {});
 
+        shardStateAction.shardFailed(getRandomShardRouting(index), indexUUID, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
+            @Override
+            public void onSuccess() {
+                success.set(true);
+                latch.countDown();
             }
 
             @Override
-            public void onShardFailedFailure(DiscoveryNode master, TransportException e) {
-                failure.set(true);
+            public void onShardFailedFailure(Exception e) {
                 // must not be reached in this scenario; the latch is released before the assert so
                 // the awaiting test thread is unblocked and fails on its own assertions
+                success.set(false);
+                latch.countDown();
+                assert false;
             }
         });
 
-        final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
-        assertThat(capturedRequests.length, equalTo(1));
-        assert !failure.get();
-        transport.handleResponse(capturedRequests[0].requestId, new TransportException("simulated"));
         // no timeout: blocks until a listener callback or verifyRetry counts the latch down
+        latch.await();
 
-        assertTrue(failure.get());
+        assertThat(retries.get(), equalTo(1));
+        assertTrue(success.get());
     }
 
-    public void testTimeout() throws InterruptedException {
+    public void testMasterChannelException() throws InterruptedException {
         // Answers the shard-failed request, and every retry but the last, with a randomly chosen
         // exception from the list built in retryLoop, then expects onSuccess after exactly
         // numberOfRetries retries.
         final String index = "test";
 
         clusterService.setState(stateWithStartedPrimary(index, true, randomInt(5)));
 
         String indexUUID = clusterService.state().metaData().index(index).getIndexUUID();
 
-        AtomicBoolean progress = new AtomicBoolean();
-        AtomicBoolean timedOut = new AtomicBoolean();
-
-        TimeValue timeout = new TimeValue(1, TimeUnit.MILLISECONDS);
         CountDownLatch latch = new CountDownLatch(1);
-        shardStateAction.shardFailed(clusterService.state(), getRandomShardRouting(index), indexUUID, "test", getSimulatedFailure(), timeout, new ShardStateAction.Listener() {
+        AtomicInteger retries = new AtomicInteger();
+        AtomicBoolean success = new AtomicBoolean();
+        AtomicReference<Exception> exception = new AtomicReference<>();
+
         // responds to the given request id with one of three exceptions picked at random;
         // presumably each of them makes ShardStateAction wait for a new master and retry
         // (that handling lives outside this test class - confirm against ShardStateAction)
+        LongConsumer retryLoop = requestId -> {
+            List<Exception> possibleExceptions = new ArrayList<>();
+            possibleExceptions.add(new NotMasterException("simulated"));
+            possibleExceptions.add(new NodeDisconnectedException(clusterService.state().nodes().masterNode(), ShardStateAction.SHARD_FAILED_ACTION_NAME));
+            possibleExceptions.add(new Discovery.FailedToCommitClusterStateException("simulated"));
+            transport.handleResponse(requestId, randomFrom(possibleExceptions));
+        };
+
+        final int numberOfRetries = randomIntBetween(1, 256);
+        setUpMasterRetryVerification(numberOfRetries, retries, latch, retryLoop);
+
+        shardStateAction.shardFailed(getRandomShardRouting(index), indexUUID, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
+            @Override
+            public void onSuccess() {
+                success.set(true);
+                latch.countDown();
+            }
+
             @Override
-            public void onShardFailedFailure(DiscoveryNode master, TransportException e) {
-                if (e instanceof ReceiveTimeoutTransportException) {
-                    assertFalse(progress.get());
-                    timedOut.set(true);
-                }
+            public void onShardFailedFailure(Exception e) {
                 // must not be reached; the exception is recorded so that assertNull below reports it
+                success.set(false);
+                exception.set(e);
                 latch.countDown();
+                assert false;
             }
         });
 
         // the initial request is answered by hand; later requests are answered by verifyRetry
+        final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
+        assertThat(capturedRequests.length, equalTo(1));
+        assertFalse(success.get());
+        assertThat(retries.get(), equalTo(0));
+        retryLoop.accept(capturedRequests[0].requestId);
+
         latch.await();
-        progress.set(true);
-        assertTrue(timedOut.get());
+        assertNull(exception.get());
+        assertThat(retries.get(), equalTo(numberOfRetries));
+        assertTrue(success.get());
+    }
+
+    public void testUnhandledFailure() {
         // Answers the shard-failed request with a plain TransportException and expects the listener's
         // onShardFailedFailure to have been called by the time handleResponse returns (no latch is used).
+        final String index = "test";
+
+        clusterService.setState(stateWithStartedPrimary(index, true, randomInt(5)));
+
+        String indexUUID = clusterService.state().metaData().index(index).getIndexUUID();
+
+        AtomicBoolean failure = new AtomicBoolean();
+
+        shardStateAction.shardFailed(getRandomShardRouting(index), indexUUID, "test", getSimulatedFailure(), new ShardStateAction.Listener() {
+            @Override
+            public void onSuccess() {
                 // success is not expected for this response
+                failure.set(false);
+                assert false;
+            }
+
+            @Override
+            public void onShardFailedFailure(Exception e) {
+                failure.set(true);
+            }
+        });
 
         final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
         assertThat(capturedRequests.length, equalTo(1));
         // nothing has been reported to the listener before the response is delivered
+        assertFalse(failure.get());
+        transport.handleResponse(capturedRequests[0].requestId, new TransportException("simulated"));
+
+        assertTrue(failure.get());
     }
 
     private ShardRouting getRandomShardRouting(String index) {
@@ -182,6 +292,34 @@ public class ShardStateActionTests extends ESTestCase {
         return shardRouting;
     }
 
+    private void setUpMasterRetryVerification(int numberOfRetries, AtomicInteger retries, CountDownLatch latch, LongConsumer retryLoop) {
+        shardStateAction.setOnBeforeWaitForNewMasterAndRetry(() -> {
+            DiscoveryNodes.Builder masterBuilder = DiscoveryNodes.builder(clusterService.state().nodes());
+            masterBuilder.masterNodeId(clusterService.state().nodes().masterNodes().iterator().next().value.id());
+            clusterService.setState(ClusterState.builder(clusterService.state()).nodes(masterBuilder));
+        });
+
+        shardStateAction.setOnAfterWaitForNewMasterAndRetry(() -> verifyRetry(numberOfRetries, retries, latch, retryLoop));
+    }
+
+    /**
+     * Checks that exactly one retry request was captured and reacts to it.
+     * When the retry just counted is the {@code numberOfRetries}-th one, the request is answered
+     * with an empty response; otherwise its id is passed to {@code retryLoop}.
+     * When the number of captured requests is not one, only {@code latch} is counted down so that
+     * the waiting test thread is released and can fail on its own assertions.
+     *
+     * @param numberOfRetries retry count at which the pending request is answered successfully
+     * @param retries         counter of retries observed so far; incremented once per captured retry
+     * @param latch           latch the driving test thread is waiting on
+     * @param retryLoop       receives the request id of every retry that is not the final one
+     */
+    private void verifyRetry(int numberOfRetries, AtomicInteger retries, CountDownLatch latch, LongConsumer retryLoop) {
+        // assert a retry request was sent
+        final CapturingTransport.CapturedRequest[] capturedRequests = transport.getCapturedRequestsAndClear();
+        if (capturedRequests.length != 1) {
+            // there failed to be a retry request
+            // release the driver thread to fail the test
+            latch.countDown();
+            return;
+        }
+
+        final long requestId = capturedRequests[0].requestId;
+        // compare against the value returned by the increment itself rather than a second read,
+        // so the decision is based on this call's own count
+        if (retries.incrementAndGet() == numberOfRetries) {
+            // finish the request
+            transport.handleResponse(requestId, TransportResponse.Empty.INSTANCE);
+        } else {
+            retryLoop.accept(requestId);
+        }
+    }
+
     /**
      * Builds the failure that the tests report for a shard: a {@code CorruptIndexException} with the
      * message "simulated" and a null second (String) argument.
      */
     private Throwable getSimulatedFailure() {
         // typed local keeps the String overload selected, as the explicit cast did
         final String noResourceDescription = null;
         return new CorruptIndexException("simulated", noResourceDescription);
     }