Browse Source

Merge pull request #14707 from jasontedor/shard-failure-timeout

Add timeout mechanism for sending shard failures
Jason Tedor 10 years ago
parent
commit
4f3eec99bb

+ 42 - 19
core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java

@@ -36,7 +36,6 @@ import org.elasticsearch.cluster.ClusterService;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateObserver;
 import org.elasticsearch.cluster.action.index.MappingUpdatedAction;
-import org.elasticsearch.cluster.action.shard.NoOpShardStateActionListener;
 import org.elasticsearch.cluster.action.shard.ShardStateAction;
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
@@ -81,6 +80,8 @@ import java.util.function.Supplier;
  */
 public abstract class TransportReplicationAction<Request extends ReplicationRequest, ReplicaRequest extends ReplicationRequest, Response extends ActionWriteResponse> extends TransportAction<Request, Response> {
 
+    public static final String SHARD_FAILURE_TIMEOUT = "action.support.replication.shard.failure_timeout";
+
     protected final TransportService transportService;
     protected final ClusterService clusterService;
     protected final IndicesService indicesService;
@@ -88,6 +89,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
     protected final WriteConsistencyLevel defaultWriteConsistencyLevel;
     protected final TransportRequestOptions transportOptions;
     protected final MappingUpdatedAction mappingUpdatedAction;
+    private final TimeValue shardFailedTimeout;
 
     final String transportReplicaAction;
     final String executor;
@@ -117,6 +119,8 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         this.transportOptions = transportOptions();
 
         this.defaultWriteConsistencyLevel = WriteConsistencyLevel.fromString(settings.get("action.write_consistency", "quorum"));
+        // TODO: set a default timeout
+        shardFailedTimeout = settings.getAsTime(SHARD_FAILURE_TIMEOUT, null);
     }
 
     @Override
@@ -351,7 +355,6 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         private final AtomicBoolean finished = new AtomicBoolean(false);
         private volatile Releasable indexShardReference;
 
-
         PrimaryPhase(Request request, ActionListener<Response> listener) {
             this.internalRequest = new InternalRequest(request);
             this.listener = listener;
@@ -578,7 +581,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
                 PrimaryOperationRequest por = new PrimaryOperationRequest(primary.id(), internalRequest.concreteIndex(), internalRequest.request());
                 Tuple<Response, ReplicaRequest> primaryResponse = shardOperationOnPrimary(observer.observedState(), por);
                 logger.trace("operation completed on primary [{}]", primary);
-                replicationPhase = new ReplicationPhase(shardsIt, primaryResponse.v2(), primaryResponse.v1(), observer, primary, internalRequest, listener, indexShardReference);
+                replicationPhase = new ReplicationPhase(shardsIt, primaryResponse.v2(), primaryResponse.v1(), observer, primary, internalRequest, listener, indexShardReference, shardFailedTimeout);
             } catch (Throwable e) {
                 // shard has not been allocated yet, retry it here
                 if (retryPrimaryException(e)) {
@@ -687,7 +690,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
     /**
      * inner class is responsible for send the requests to all replica shards and manage the responses
      */
-    final class ReplicationPhase extends AbstractRunnable implements ShardStateAction.Listener {
+    final class ReplicationPhase extends AbstractRunnable {
 
         private final ReplicaRequest replicaRequest;
         private final Response finalResponse;
@@ -702,6 +705,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
         private final int totalShards;
         private final ClusterStateObserver observer;
         private final Releasable indexShardReference;
+        private final TimeValue shardFailedTimeout;
 
         /**
          * the constructor doesn't take any action, just calculates state. Call {@link #run()} to start
@@ -709,7 +713,8 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
          */
         public ReplicationPhase(ShardIterator originalShardIt, ReplicaRequest replicaRequest, Response finalResponse,
                                 ClusterStateObserver observer, ShardRouting originalPrimaryShard,
-                                InternalRequest internalRequest, ActionListener<Response> listener, Releasable indexShardReference) {
+                                InternalRequest internalRequest, ActionListener<Response> listener, Releasable indexShardReference,
+                                TimeValue shardFailedTimeout) {
             this.replicaRequest = replicaRequest;
             this.listener = listener;
             this.finalResponse = finalResponse;
@@ -717,6 +722,7 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
             this.observer = observer;
             indexMetaData = observer.observedState().metaData().index(internalRequest.concreteIndex());
             this.indexShardReference = indexShardReference;
+            this.shardFailedTimeout = shardFailedTimeout;
 
             ShardRouting shard;
             // we double check on the state, if it got changed we need to make sure we take the latest one cause
@@ -822,16 +828,6 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
             forceFinishAsFailed(t);
         }
 
-        @Override
-        public void onShardFailedNoMaster() {
-
-        }
-
-        @Override
-        public void onShardFailedFailure(DiscoveryNode master, TransportException e) {
-
-        }
-
         /**
          * start sending current requests to replicas
          */
@@ -893,14 +889,14 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
 
                             @Override
                             public void handleException(TransportException exp) {
-                                onReplicaFailure(nodeId, exp);
                                 logger.trace("[{}] transport failure during replica request [{}] ", exp, node, replicaRequest);
-                                if (ignoreReplicaException(exp) == false) {
+                                if (ignoreReplicaException(exp)) {
+                                    onReplicaFailure(nodeId, exp);
+                                } else {
                                     logger.warn("{} failed to perform {} on node {}", exp, shardIt.shardId(), actionName, node);
-                                    shardStateAction.shardFailed(shard, indexMetaData.getIndexUUID(), "failed to perform " + actionName + " on replica on node " + node, exp, ReplicationPhase.this);
+                                    shardStateAction.shardFailed(shard, indexMetaData.getIndexUUID(), "failed to perform " + actionName + " on replica on node " + node, exp, shardFailedTimeout, new ReplicationFailedShardStateListener(nodeId, exp));
                                 }
                             }
-
                         });
             } else {
                 try {
@@ -989,6 +985,33 @@ public abstract class TransportReplicationAction<Request extends ReplicationRequ
             }
         }
 
+        public class ReplicationFailedShardStateListener implements ShardStateAction.Listener {
+            private final String nodeId;
+            private Throwable failure;
+
+            public ReplicationFailedShardStateListener(String nodeId, Throwable failure) {
+                this.nodeId = nodeId;
+                this.failure = failure;
+            }
+
+            @Override
+            public void onSuccess() {
+                onReplicaFailure(nodeId, failure);
+            }
+
+            @Override
+            public void onShardFailedNoMaster() {
+                onReplicaFailure(nodeId, failure);
+            }
+
+            @Override
+            public void onShardFailedFailure(DiscoveryNode master, TransportException e) {
+                if (e instanceof ReceiveTimeoutTransportException) {
+                    logger.trace("timeout sending shard failure to master [{}]", e, master);
+                }
+                onReplicaFailure(nodeId, failure);
+            }
+        }
     }
 
     /**

+ 2 - 1
core/src/main/java/org/elasticsearch/cluster/ClusterModule.java

@@ -21,6 +21,7 @@ package org.elasticsearch.cluster;
 
 import org.elasticsearch.action.admin.indices.close.TransportCloseIndexAction;
 import org.elasticsearch.action.support.DestructiveOperations;
+import org.elasticsearch.action.support.replication.TransportReplicationAction;
 import org.elasticsearch.cluster.action.index.MappingUpdatedAction;
 import org.elasticsearch.cluster.action.index.NodeIndexDeletedAction;
 import org.elasticsearch.cluster.action.index.NodeMappingRefreshAction;
@@ -86,7 +87,6 @@ import org.elasticsearch.indices.IndicesWarmer;
 import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService;
 import org.elasticsearch.indices.cache.request.IndicesRequestCache;
 import org.elasticsearch.indices.recovery.RecoverySettings;
-import org.elasticsearch.indices.store.IndicesStore;
 import org.elasticsearch.indices.ttl.IndicesTTLService;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.internal.DefaultSearchContext;
@@ -206,6 +206,7 @@ public class ClusterModule extends AbstractModule {
         registerClusterDynamicSetting(TransportService.SETTING_TRACE_LOG_EXCLUDE + ".*", Validator.EMPTY);
         registerClusterDynamicSetting(TransportCloseIndexAction.SETTING_CLUSTER_INDICES_CLOSE_ENABLE, Validator.BOOLEAN);
         registerClusterDynamicSetting(ShardsLimitAllocationDecider.CLUSTER_TOTAL_SHARDS_PER_NODE, Validator.INTEGER);
+        registerClusterDynamicSetting(TransportReplicationAction.SHARD_FAILURE_TIMEOUT, Validator.TIME_NON_NEGATIVE);
     }
 
     private void registerBuiltinIndexSettings() {

+ 20 - 4
core/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java

@@ -37,6 +37,7 @@ import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.*;
@@ -45,6 +46,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CountDownLatch;
 
 import static org.elasticsearch.cluster.routing.ShardRouting.readShardRoutingEntry;
 
@@ -78,24 +80,37 @@ public class ShardStateAction extends AbstractComponent {
     }
 
     public void shardFailed(final ShardRouting shardRouting, final String indexUUID, final String message, @Nullable final Throwable failure, Listener listener) {
+        shardFailed(shardRouting, indexUUID, message, failure, null, listener);
+    }
+
+    public void shardFailed(final ShardRouting shardRouting, final String indexUUID, final String message, @Nullable final Throwable failure, TimeValue timeout, Listener listener) {
         DiscoveryNode masterNode = clusterService.state().nodes().masterNode();
         if (masterNode == null) {
             logger.warn("can't send shard failed for {}, no master known.", shardRouting);
             listener.onShardFailedNoMaster();
             return;
         }
-        innerShardFailed(shardRouting, indexUUID, masterNode, message, failure, listener);
+        innerShardFailed(shardRouting, indexUUID, masterNode, message, failure, timeout, listener);
     }
 
     public void resendShardFailed(final ShardRouting shardRouting, final String indexUUID, final DiscoveryNode masterNode, final String message, @Nullable final Throwable failure, Listener listener) {
         logger.trace("{} re-sending failed shard for {}, indexUUID [{}], reason [{}]", failure, shardRouting.shardId(), shardRouting, indexUUID, message);
-        innerShardFailed(shardRouting, indexUUID, masterNode, message, failure, listener);
+        innerShardFailed(shardRouting, indexUUID, masterNode, message, failure, null, listener);
     }
 
-    private void innerShardFailed(final ShardRouting shardRouting, final String indexUUID, final DiscoveryNode masterNode, final String message, final Throwable failure, Listener listener) {
+    private void innerShardFailed(final ShardRouting shardRouting, final String indexUUID, final DiscoveryNode masterNode, final String message, final Throwable failure, TimeValue timeout, Listener listener) {
         ShardRoutingEntry shardRoutingEntry = new ShardRoutingEntry(shardRouting, indexUUID, message, failure);
+        TransportRequestOptions options = TransportRequestOptions.EMPTY;
+        if (timeout != null) {
+            options = TransportRequestOptions.builder().withTimeout(timeout).build();
+        }
         transportService.sendRequest(masterNode,
-                SHARD_FAILED_ACTION_NAME, shardRoutingEntry, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
+                SHARD_FAILED_ACTION_NAME, shardRoutingEntry, options, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
+                    @Override
+                    public void handleResponse(TransportResponse.Empty response) {
+                        listener.onSuccess();
+                    }
+
                     @Override
                     public void handleException(TransportException exp) {
                         logger.warn("failed to send failed shard to {}", exp, masterNode);
@@ -288,6 +303,7 @@ public class ShardStateAction extends AbstractComponent {
     }
 
     public interface Listener {
+        default void onSuccess() {}
         default void onShardFailedNoMaster() {}
         default void onShardFailedFailure(final DiscoveryNode master, final TransportException e) {}
     }

+ 23 - 5
core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

@@ -63,6 +63,7 @@ import org.junit.Before;
 import org.junit.BeforeClass;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.concurrent.CountDownLatch;
@@ -351,10 +352,11 @@ public class TransportReplicationActionTests extends ESTestCase {
         internalRequest.concreteIndex(shardId.index().name());
         Releasable reference = getOrCreateIndexShardOperationsCounter();
         assertIndexShardCounter(2);
+        // TODO: set a default timeout
         TransportReplicationAction<Request, Request, Response>.ReplicationPhase replicationPhase =
                 action.new ReplicationPhase(shardIt, request,
                         new Response(), new ClusterStateObserver(clusterService, logger),
-                        primaryShard, internalRequest, listener, reference);
+                        primaryShard, internalRequest, listener, reference, null);
 
         assertThat(replicationPhase.totalShards(), equalTo(totalShards));
         assertThat(replicationPhase.pending(), equalTo(assignedReplicas));
@@ -368,10 +370,12 @@ public class TransportReplicationActionTests extends ESTestCase {
         int pending = replicationPhase.pending();
         int criticalFailures = 0; // failures that should fail the shard
         int successful = 1;
+        List<CapturingTransport.CapturedRequest> failures = new ArrayList<>();
         for (CapturingTransport.CapturedRequest capturedRequest : capturedRequests) {
             if (randomBoolean()) {
                 Throwable t;
-                if (randomBoolean()) {
+                boolean criticalFailure = randomBoolean();
+                if (criticalFailure) {
                     t = new CorruptIndexException("simulated", (String) null);
                     criticalFailures++;
                 } else {
@@ -379,6 +383,14 @@ public class TransportReplicationActionTests extends ESTestCase {
                 }
                 logger.debug("--> simulating failure on {} with [{}]", capturedRequest.node, t.getClass().getSimpleName());
                 transport.handleResponse(capturedRequest.requestId, t);
+                if (criticalFailure) {
+                    CapturingTransport.CapturedRequest[] shardFailedRequests = transport.capturedRequests();
+                    transport.clear();
+                    assertEquals(1, shardFailedRequests.length);
+                    CapturingTransport.CapturedRequest shardFailedRequest = shardFailedRequests[0];
+                    failures.add(shardFailedRequest);
+                    transport.handleResponse(shardFailedRequest.requestId, TransportResponse.Empty.INSTANCE);
+                }
             } else {
                 successful++;
                 transport.handleResponse(capturedRequest.requestId, TransportResponse.Empty.INSTANCE);
@@ -395,7 +407,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         assertThat(shardInfo.getSuccessful(), equalTo(successful));
         assertThat(shardInfo.getTotal(), equalTo(totalShards));
 
-        assertThat("failed to see enough shard failures", transport.capturedRequests().length, equalTo(criticalFailures));
+        assertThat("failed to see enough shard failures", failures.size(), equalTo(criticalFailures));
         for (CapturingTransport.CapturedRequest capturedRequest : transport.capturedRequests()) {
             assertThat(capturedRequest.action, equalTo(ShardStateAction.SHARD_FAILED_ACTION_NAME));
         }
@@ -464,9 +476,15 @@ public class TransportReplicationActionTests extends ESTestCase {
         primaryPhase = action.new PrimaryPhase(request, listener);
         primaryPhase.run();
         assertIndexShardCounter(2);
-        assertThat(transport.capturedRequests().length, equalTo(1));
+        CapturingTransport.CapturedRequest[] replicationRequests = transport.capturedRequests();
+        transport.clear();
+        assertThat(replicationRequests.length, equalTo(1));
         // try with failure response
-        transport.handleResponse(transport.capturedRequests()[0].requestId, new CorruptIndexException("simulated", (String) null));
+        transport.handleResponse(replicationRequests[0].requestId, new CorruptIndexException("simulated", (String) null));
+        CapturingTransport.CapturedRequest[] shardFailedRequests = transport.capturedRequests();
+        transport.clear();
+        assertEquals(1, shardFailedRequests.length);
+        transport.handleResponse(shardFailedRequests[0].requestId, TransportResponse.Empty.INSTANCE);
         assertIndexShardCounter(1);
     }
 

+ 35 - 0
core/src/test/java/org/elasticsearch/cluster/action/shard/ShardStateActionTests.java

@@ -27,10 +27,12 @@ import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardsIterator;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.cluster.TestClusterService;
 import org.elasticsearch.test.transport.CapturingTransport;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.ReceiveTimeoutTransportException;
 import org.elasticsearch.transport.TransportException;
 import org.elasticsearch.transport.TransportService;
 import org.junit.After;
@@ -38,6 +40,7 @@ import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.BeforeClass;
 
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 
@@ -141,6 +144,38 @@ public class ShardStateActionTests extends ESTestCase {
         assertTrue(failure.get());
     }
 
+    public void testTimeout() throws InterruptedException {
+        final String index = "test";
+
+        clusterService.setState(stateWithStartedPrimary(index, true, randomInt(5)));
+
+        String indexUUID = clusterService.state().metaData().index(index).getIndexUUID();
+
+        AtomicBoolean progress = new AtomicBoolean();
+        AtomicBoolean timedOut = new AtomicBoolean();
+
+        TimeValue timeout = new TimeValue(1, TimeUnit.MILLISECONDS);
+        CountDownLatch latch = new CountDownLatch(1);
+        shardStateAction.shardFailed(getRandomShardRouting(index), indexUUID, "test", getSimulatedFailure(), timeout, new ShardStateAction.Listener() {
+            @Override
+            public void onShardFailedFailure(DiscoveryNode master, TransportException e) {
+                if (e instanceof ReceiveTimeoutTransportException) {
+                    assertFalse(progress.get());
+                    timedOut.set(true);
+                }
+                latch.countDown();
+            }
+        });
+
+        latch.await();
+        progress.set(true);
+        assertTrue(timedOut.get());
+
+        final CapturingTransport.CapturedRequest[] capturedRequests = transport.capturedRequests();
+        transport.clear();
+        assertThat(capturedRequests.length, equalTo(1));
+    }
+
     private ShardRouting getRandomShardRouting(String index) {
         IndexRoutingTable indexRoutingTable = clusterService.state().routingTable().index(index);
         ShardsIterator shardsIterator = indexRoutingTable.randomAllActiveShardsIt();