Преглед изворни кода

Revert "Revert "[RCI] Check blocks while having index shard permit in TransportReplicationAction (#35332)""

This reverts commit d3d7c01
Tanguy Leroux пре 7 година
родитељ
комит
f9f7261d60

+ 75 - 62
server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java

@@ -235,9 +235,39 @@ public abstract class TransportReplicationAction<
         return TransportRequestOptions.EMPTY;
     }
 
+    private String concreteIndex(final ClusterState state, final ReplicationRequest request) {
+        return resolveIndex() ? indexNameExpressionResolver.concreteSingleIndex(state, request).getName() : request.index();
+    }
+
+    private ClusterBlockException blockExceptions(final ClusterState state, final String indexName) {
+        ClusterBlockLevel globalBlockLevel = globalBlockLevel();
+        if (globalBlockLevel != null) {
+            ClusterBlockException blockException = state.blocks().globalBlockedException(globalBlockLevel);
+            if (blockException != null) {
+                return blockException;
+            }
+        }
+        ClusterBlockLevel indexBlockLevel = indexBlockLevel();
+        if (indexBlockLevel != null) {
+            ClusterBlockException blockException = state.blocks().indexBlockedException(indexBlockLevel, indexName);
+            if (blockException != null) {
+                return blockException;
+            }
+        }
+        return null;
+    }
+
     protected boolean retryPrimaryException(final Throwable e) {
         return e.getClass() == ReplicationOperation.RetryOnPrimaryException.class
-                || TransportActions.isShardNotAvailableException(e);
+                || TransportActions.isShardNotAvailableException(e)
+                || isRetryableClusterBlockException(e);
+    }
+
+    boolean isRetryableClusterBlockException(final Throwable e) {
+        if (e instanceof ClusterBlockException) {
+            return ((ClusterBlockException) e).retryable();
+        }
+        return false;
     }
 
     protected class OperationTransportHandler implements TransportRequestHandler<Request> {
@@ -310,6 +340,15 @@ public abstract class TransportReplicationAction<
         @Override
         public void onResponse(PrimaryShardReference primaryShardReference) {
             try {
+                final ClusterState clusterState = clusterService.state();
+                final IndexMetaData indexMetaData = clusterState.metaData().getIndexSafe(primaryShardReference.routingEntry().index());
+
+                final ClusterBlockException blockException = blockExceptions(clusterState, indexMetaData.getIndex().getName());
+                if (blockException != null) {
+                    logger.trace("cluster is blocked, action failed on primary", blockException);
+                    throw blockException;
+                }
+
                 if (primaryShardReference.isRelocated()) {
                     primaryShardReference.close(); // release shard operation lock as soon as possible
                     setPhase(replicationTask, "primary_delegation");
@@ -323,7 +362,7 @@ public abstract class TransportReplicationAction<
                         response.readFrom(in);
                         return response;
                     };
-                    DiscoveryNode relocatingNode = clusterService.state().nodes().get(primary.relocatingNodeId());
+                    DiscoveryNode relocatingNode = clusterState.nodes().get(primary.relocatingNodeId());
                     transportService.sendRequest(relocatingNode, transportPrimaryAction,
                         new ConcreteShardRequest<>(request, primary.allocationId().getRelocationId(), primaryTerm),
                         transportOptions,
@@ -696,35 +735,42 @@ public abstract class TransportReplicationAction<
         protected void doRun() {
             setPhase(task, "routing");
             final ClusterState state = observer.setAndGetObservedState();
-            if (handleBlockExceptions(state)) {
-                return;
-            }
-
-            // request does not have a shardId yet, we need to pass the concrete index to resolve shardId
-            final String concreteIndex = concreteIndex(state);
-            final IndexMetaData indexMetaData = state.metaData().index(concreteIndex);
-            if (indexMetaData == null) {
-                retry(new IndexNotFoundException(concreteIndex));
-                return;
-            }
-            if (indexMetaData.getState() == IndexMetaData.State.CLOSE) {
-                throw new IndexClosedException(indexMetaData.getIndex());
-            }
+            final String concreteIndex = concreteIndex(state, request);
+            final ClusterBlockException blockException = blockExceptions(state, concreteIndex);
+            if (blockException != null) {
+                if (blockException.retryable()) {
+                    logger.trace("cluster is blocked, scheduling a retry", blockException);
+                    retry(blockException);
+                } else {
+                    finishAsFailed(blockException);
+                }
+            } else {
+                // request does not have a shardId yet, we need to pass the concrete index to resolve shardId
+                final IndexMetaData indexMetaData = state.metaData().index(concreteIndex);
+                if (indexMetaData == null) {
+                    retry(new IndexNotFoundException(concreteIndex));
+                    return;
+                }
+                if (indexMetaData.getState() == IndexMetaData.State.CLOSE) {
+                    throw new IndexClosedException(indexMetaData.getIndex());
+                }
 
-            // resolve all derived request fields, so we can route and apply it
-            resolveRequest(indexMetaData, request);
-            assert request.shardId() != null : "request shardId must be set in resolveRequest";
-            assert request.waitForActiveShards() != ActiveShardCount.DEFAULT : "request waitForActiveShards must be set in resolveRequest";
+                // resolve all derived request fields, so we can route and apply it
+                resolveRequest(indexMetaData, request);
+                assert request.shardId() != null : "request shardId must be set in resolveRequest";
+                assert request.waitForActiveShards() != ActiveShardCount.DEFAULT :
+                    "request waitForActiveShards must be set in resolveRequest";
 
-            final ShardRouting primary = primary(state);
-            if (retryIfUnavailable(state, primary)) {
-                return;
-            }
-            final DiscoveryNode node = state.nodes().get(primary.currentNodeId());
-            if (primary.currentNodeId().equals(state.nodes().getLocalNodeId())) {
-                performLocalAction(state, primary, node, indexMetaData);
-            } else {
-                performRemoteAction(state, primary, node);
+                final ShardRouting primary = primary(state);
+                if (retryIfUnavailable(state, primary)) {
+                    return;
+                }
+                final DiscoveryNode node = state.nodes().get(primary.currentNodeId());
+                if (primary.currentNodeId().equals(state.nodes().getLocalNodeId())) {
+                    performLocalAction(state, primary, node, indexMetaData);
+                } else {
+                    performRemoteAction(state, primary, node);
+                }
             }
         }
 
@@ -776,44 +822,11 @@ public abstract class TransportReplicationAction<
             return false;
         }
 
-        private String concreteIndex(ClusterState state) {
-            return resolveIndex() ? indexNameExpressionResolver.concreteSingleIndex(state, request).getName() : request.index();
-        }
-
         private ShardRouting primary(ClusterState state) {
             IndexShardRoutingTable indexShard = state.getRoutingTable().shardRoutingTable(request.shardId());
             return indexShard.primaryShard();
         }
 
-        private boolean handleBlockExceptions(ClusterState state) {
-            ClusterBlockLevel globalBlockLevel = globalBlockLevel();
-            if (globalBlockLevel != null) {
-                ClusterBlockException blockException = state.blocks().globalBlockedException(globalBlockLevel);
-                if (blockException != null) {
-                    handleBlockException(blockException);
-                    return true;
-                }
-            }
-            ClusterBlockLevel indexBlockLevel = indexBlockLevel();
-            if (indexBlockLevel != null) {
-                ClusterBlockException blockException = state.blocks().indexBlockedException(indexBlockLevel, concreteIndex(state));
-                if (blockException != null) {
-                    handleBlockException(blockException);
-                    return true;
-                }
-            }
-            return false;
-        }
-
-        private void handleBlockException(ClusterBlockException blockException) {
-            if (blockException.retryable()) {
-                logger.trace("cluster is blocked, scheduling a retry", blockException);
-                retry(blockException);
-            } else {
-                finishAsFailed(blockException);
-            }
-        }
-
         private void performAction(final DiscoveryNode node, final String action, final boolean isPrimaryAction,
                                    final TransportRequest requestToPerform) {
             transportService.sendRequest(node, action, requestToPerform, transportOptions, new TransportResponseHandler<Response>() {

+ 159 - 74
server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

@@ -89,6 +89,7 @@ import org.junit.BeforeClass;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
@@ -100,6 +101,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 
+import static java.util.Collections.singleton;
 import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state;
 import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithActivePrimary;
 import static org.elasticsearch.cluster.metadata.IndexMetaData.SETTING_WAIT_FOR_ACTIVE_SHARDS;
@@ -108,9 +110,11 @@ import static org.elasticsearch.test.ClusterServiceUtils.setState;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasToString;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
+import static org.hamcrest.core.Is.is;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
 import static org.mockito.Matchers.anyLong;
@@ -182,70 +186,157 @@ public class TransportReplicationActionTests extends ESTestCase {
         threadPool = null;
     }
 
-    <T> void assertListenerThrows(String msg, PlainActionFuture<T> listener, Class<?> klass) throws InterruptedException {
-        try {
-            listener.get();
-            fail(msg);
-        } catch (ExecutionException ex) {
-            assertThat(ex.getCause(), instanceOf(klass));
+    private <T> T assertListenerThrows(String msg, PlainActionFuture<?> listener, Class<T> klass) {
+        ExecutionException exception = expectThrows(ExecutionException.class, msg, listener::get);
+        assertThat(exception.getCause(), instanceOf(klass));
+        @SuppressWarnings("unchecked")
+        final T cause = (T) exception.getCause();
+        return cause;
+    }
+
+    private void setStateWithBlock(final ClusterService clusterService, final ClusterBlock block, final boolean globalBlock) {
+        final ClusterBlocks.Builder blocks = ClusterBlocks.builder();
+        if (globalBlock) {
+            blocks.addGlobalBlock(block);
+        } else {
+            blocks.addIndexBlock("index", block);
         }
+        setState(clusterService, ClusterState.builder(clusterService.state()).blocks(blocks).build());
     }
 
-    public void testBlocks() throws ExecutionException, InterruptedException {
-        Request request = new Request();
-        PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
-        ReplicationTask task = maybeTask();
-        TestAction action = new TestAction(Settings.EMPTY, "internal:testActionWithBlocks",
-                transportService, clusterService, shardStateAction, threadPool) {
+    public void testBlocksInReroutePhase() throws Exception {
+        final ClusterBlock nonRetryableBlock =
+            new ClusterBlock(1, "non retryable", false, true, false, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL);
+        final ClusterBlock retryableBlock =
+            new ClusterBlock(1, "retryable", true, true, false, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL);
+
+        final boolean globalBlock = randomBoolean();
+        final TestAction action = new TestAction(Settings.EMPTY, "internal:testActionWithBlocks",
+            transportService, clusterService, shardStateAction, threadPool) {
             @Override
             protected ClusterBlockLevel globalBlockLevel() {
-                return ClusterBlockLevel.WRITE;
+                return globalBlock ? ClusterBlockLevel.WRITE : null;
+            }
+
+            @Override
+            protected ClusterBlockLevel indexBlockLevel() {
+                return globalBlock == false ? ClusterBlockLevel.WRITE : null;
             }
         };
 
-        ClusterBlocks.Builder block = ClusterBlocks.builder().addGlobalBlock(new ClusterBlock(1, "non retryable", false, true,
-            false, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
-        setState(clusterService, ClusterState.builder(clusterService.state()).blocks(block));
-        TestAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
-        reroutePhase.run();
-        assertListenerThrows("primary phase should fail operation", listener, ClusterBlockException.class);
-        assertPhase(task, "failed");
+        setState(clusterService, ClusterStateCreationUtils.stateWithActivePrimary("index", true, 0));
 
-        block = ClusterBlocks.builder()
-            .addGlobalBlock(new ClusterBlock(1, "retryable", true, true, false, RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
-        setState(clusterService, ClusterState.builder(clusterService.state()).blocks(block));
-        listener = new PlainActionFuture<>();
-        reroutePhase = action.new ReroutePhase(task, new Request().timeout("5ms"), listener);
-        reroutePhase.run();
-        assertListenerThrows("failed to timeout on retryable block", listener, ClusterBlockException.class);
-        assertPhase(task, "failed");
-        assertFalse(request.isRetrySet.get());
+        {
+            setStateWithBlock(clusterService, nonRetryableBlock, globalBlock);
 
-        listener = new PlainActionFuture<>();
-        reroutePhase = action.new ReroutePhase(task, request = new Request(), listener);
-        reroutePhase.run();
-        assertFalse("primary phase should wait on retryable block", listener.isDone());
-        assertPhase(task, "waiting_for_retry");
-        assertTrue(request.isRetrySet.get());
+            Request request = globalBlock ? new Request() : new Request().index("index");
+            PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+            ReplicationTask task = maybeTask();
+
+            TestAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
+            reroutePhase.run();
+
+            ClusterBlockException exception =
+                assertListenerThrows("primary action should fail operation", listener, ClusterBlockException.class);
+            assertThat(((ClusterBlockException) exception.unwrapCause()).blocks().iterator().next(), is(nonRetryableBlock));
+            assertPhase(task, "failed");
+        }
+        {
+            setStateWithBlock(clusterService, retryableBlock, globalBlock);
+
+            Request requestWithTimeout = (globalBlock ? new Request() : new Request().index("index")).timeout("5ms");
+            PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+            ReplicationTask task = maybeTask();
+
+            TestAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, requestWithTimeout, listener);
+            reroutePhase.run();
+
+            ClusterBlockException exception =
+                assertListenerThrows("failed to timeout on retryable block", listener, ClusterBlockException.class);
+            assertThat(((ClusterBlockException) exception.unwrapCause()).blocks().iterator().next(), is(retryableBlock));
+            assertPhase(task, "failed");
+            assertTrue(requestWithTimeout.isRetrySet.get());
+        }
+        {
+            setStateWithBlock(clusterService, retryableBlock, globalBlock);
+
+            Request request = globalBlock ? new Request() : new Request().index("index");
+            PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+            ReplicationTask task = maybeTask();
+
+            TestAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
+            reroutePhase.run();
+
+            assertFalse("primary phase should wait on retryable block", listener.isDone());
+            assertPhase(task, "waiting_for_retry");
+            assertTrue(request.isRetrySet.get());
+
+            setStateWithBlock(clusterService, nonRetryableBlock, globalBlock);
+
+            ClusterBlockException exception = assertListenerThrows("primary phase should fail operation when moving from a retryable " +
+                    "block to a non-retryable one", listener, ClusterBlockException.class);
+            assertThat(((ClusterBlockException) exception.unwrapCause()).blocks().iterator().next(), is(nonRetryableBlock));
+            assertIndexShardUninitialized();
+        }
+        {
+            Request requestWithTimeout = new Request().index("unknown").setShardId(new ShardId("unknown", "_na_", 0)).timeout("5ms");
+            PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+            ReplicationTask task = maybeTask();
+
+            TestAction testActionWithNoBlocks = new TestAction(Settings.EMPTY, "internal:testActionWithNoBlocks", transportService,
+                clusterService, shardStateAction, threadPool);
+            listener = new PlainActionFuture<>();
+            TestAction.ReroutePhase reroutePhase = testActionWithNoBlocks.new ReroutePhase(task, requestWithTimeout, listener);
+            reroutePhase.run();
+            assertListenerThrows("should fail with an IndexNotFoundException when no blocks", listener, IndexNotFoundException.class);
+        }
+    }
+
+    public void testBlocksInPrimaryAction() {
+        final boolean globalBlock = randomBoolean();
 
-        block = ClusterBlocks.builder().addGlobalBlock(new ClusterBlock(1, "non retryable", false, true, false,
-            RestStatus.SERVICE_UNAVAILABLE, ClusterBlockLevel.ALL));
+        final TestAction actionWithBlocks =
+            new TestAction(Settings.EMPTY, "internal:actionWithBlocks", transportService, clusterService, shardStateAction, threadPool) {
+                @Override
+                protected ClusterBlockLevel globalBlockLevel() {
+                    return globalBlock ? ClusterBlockLevel.WRITE : null;
+                }
+
+                @Override
+                protected ClusterBlockLevel indexBlockLevel() {
+                    return globalBlock == false ? ClusterBlockLevel.WRITE : null;
+                }
+            };
+
+        final String index = "index";
+        final ShardId shardId = new ShardId(index, "_na_", 0);
+        setState(clusterService, stateWithActivePrimary(index, true, randomInt(5)));
+
+        final ClusterBlocks.Builder block = ClusterBlocks.builder();
+        if (globalBlock) {
+            block.addGlobalBlock(new ClusterBlock(randomIntBetween(1, 16), "test global block", randomBoolean(), randomBoolean(),
+                randomBoolean(), RestStatus.BAD_REQUEST, ClusterBlockLevel.ALL));
+        } else {
+            block.addIndexBlock(index, new ClusterBlock(randomIntBetween(1, 16), "test index block", randomBoolean(), randomBoolean(),
+                randomBoolean(), RestStatus.FORBIDDEN, ClusterBlockLevel.READ_WRITE));
+        }
         setState(clusterService, ClusterState.builder(clusterService.state()).blocks(block));
-        assertListenerThrows("primary phase should fail operation when moving from a retryable block to a non-retryable one", listener,
-            ClusterBlockException.class);
-        assertIndexShardUninitialized();
 
-        action = new TestAction(Settings.EMPTY, "internal:testActionWithNoBlocks", transportService, clusterService, shardStateAction,
-            threadPool) {
-            @Override
-            protected ClusterBlockLevel globalBlockLevel() {
-                return null;
-            }
-        };
-        listener = new PlainActionFuture<>();
-        reroutePhase = action.new ReroutePhase(task, new Request().timeout("5ms"), listener);
-        reroutePhase.run();
-        assertListenerThrows("should fail with an IndexNotFoundException when no blocks checked", listener, IndexNotFoundException.class);
+        final ClusterState clusterState = clusterService.state();
+        final String targetAllocationID = clusterState.getRoutingTable().shardRoutingTable(shardId).primaryShard().allocationId().getId();
+        final long primaryTerm = clusterState.metaData().index(index).primaryTerm(shardId.id());
+        final Request request = new Request(shardId);
+        final ReplicationTask task = maybeTask();
+        final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+
+        final TransportReplicationAction.AsyncPrimaryAction asyncPrimaryActionWithBlocks =
+            actionWithBlocks.new AsyncPrimaryAction(request, targetAllocationID, primaryTerm, createTransportChannel(listener), task);
+        asyncPrimaryActionWithBlocks.run();
+
+        final ExecutionException exception = expectThrows(ExecutionException.class, listener::get);
+        assertThat(exception.getCause(), instanceOf(ClusterBlockException.class));
+        assertThat(exception.getCause(), hasToString(containsString("test " + (globalBlock ? "global" : "index") + " block")));
+        assertPhase(task, "finished");
     }
 
     public void assertIndexShardUninitialized() {
@@ -377,21 +468,12 @@ public class TransportReplicationActionTests extends ESTestCase {
         PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
         ReplicationTask task = maybeTask();
 
-        ClusterBlockLevel indexBlockLevel = randomBoolean() ? ClusterBlockLevel.WRITE : null;
         TestAction action = new TestAction(Settings.EMPTY, "internal:testActionWithBlocks", transportService,
-                clusterService, shardStateAction, threadPool) {
-            @Override
-            protected ClusterBlockLevel indexBlockLevel() {
-                return indexBlockLevel;
-            }
-        };
+                clusterService, shardStateAction, threadPool);
         TestAction.ReroutePhase reroutePhase = action.new ReroutePhase(task, request, listener);
         reroutePhase.run();
-        if (indexBlockLevel == ClusterBlockLevel.WRITE) {
-            assertListenerThrows("must throw block exception", listener, ClusterBlockException.class);
-        } else {
-            assertListenerThrows("must throw index closed exception", listener, IndexClosedException.class);
-        }
+        assertListenerThrows("must throw index closed exception", listener, IndexClosedException.class);
+
         assertPhase(task, "failed");
         assertFalse(request.isRetrySet.get());
     }
@@ -682,12 +764,12 @@ public class TransportReplicationActionTests extends ESTestCase {
         PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
 
 
-        final IndexShard shard = mock(IndexShard.class);
+        final IndexShard shard = mockIndexShard(shardId, clusterService);
         when(shard.getPendingPrimaryTerm()).thenReturn(primaryTerm);
         when(shard.routingEntry()).thenReturn(routingEntry);
         when(shard.isRelocatedPrimary()).thenReturn(false);
         IndexShardRoutingTable shardRoutingTable = clusterService.state().routingTable().shardRoutingTable(shardId);
-        Set<String> inSyncIds = randomBoolean() ? Collections.singleton(routingEntry.allocationId().getId()) :
+        Set<String> inSyncIds = randomBoolean() ? singleton(routingEntry.allocationId().getId()) :
             clusterService.state().metaData().index(index).inSyncAllocationIds(0);
         when(shard.getReplicationGroup()).thenReturn(
             new ReplicationGroup(shardRoutingTable,
@@ -1022,6 +1104,17 @@ public class TransportReplicationActionTests extends ESTestCase {
         transportService.stop();
     }
 
+    public void testIsRetryableClusterBlockException() {
+        final TestAction action = new TestAction(Settings.EMPTY, "internal:testIsRetryableClusterBlockException", transportService,
+            clusterService, shardStateAction, threadPool);
+        assertFalse(action.isRetryableClusterBlockException(randomRetryPrimaryException(new ShardId("index", "_na_", 0))));
+
+        final boolean retryable = randomBoolean();
+        ClusterBlock randomBlock = new ClusterBlock(randomIntBetween(1, 16), "test", retryable, randomBoolean(),
+            randomBoolean(), randomFrom(RestStatus.values()), EnumSet.of(randomFrom(ClusterBlockLevel.values())));
+        assertEquals(retryable, action.isRetryableClusterBlockException(new ClusterBlockException(singleton(randomBlock))));
+    }
+
     private void assertConcreteShardRequest(TransportRequest capturedRequest, Request expectedRequest, AllocationId expectedAllocationId) {
         final TransportReplicationAction.ConcreteShardRequest<?> concreteShardRequest =
             (TransportReplicationAction.ConcreteShardRequest<?>) capturedRequest;
@@ -1115,15 +1208,6 @@ public class TransportReplicationActionTests extends ESTestCase {
                 Request::new, Request::new, ThreadPool.Names.SAME);
         }
 
-        TestAction(Settings settings, String actionName, TransportService transportService,
-                   ClusterService clusterService, ShardStateAction shardStateAction,
-                   ThreadPool threadPool, boolean withDocumentFailureOnPrimary, boolean withDocumentFailureOnReplica) {
-            super(settings, actionName, transportService, clusterService, mockIndicesService(clusterService), threadPool,
-                shardStateAction,
-                new ActionFilters(new HashSet<>()), new IndexNameExpressionResolver(),
-                Request::new, Request::new, ThreadPool.Names.SAME);
-        }
-
         @Override
         protected TestResponse newResponseInstance() {
             return new TestResponse();
@@ -1183,6 +1267,7 @@ public class TransportReplicationActionTests extends ESTestCase {
 
     private IndexShard mockIndexShard(ShardId shardId, ClusterService clusterService) {
         final IndexShard indexShard = mock(IndexShard.class);
+        when(indexShard.shardId()).thenReturn(shardId);
         doAnswer(invocation -> {
             ActionListener<Releasable> callback = (ActionListener<Releasable>) invocation.getArguments()[0];
             count.incrementAndGet();