Browse Source

New TransportBroadcastUnpromotableAction action (#93600)

Introduces:

* New action that can be used to broadcast to unpromotable shards of a given IndexShardRoutingTable.
* New hook in ReplicationOperation for custom logic when the primary operation completes. If the hook fails, the failure is added to the shard failures of the replication operation and the total shard count is incremented.
* Refresh action now uses the new hook to broadcast the unpromotable refresh action to all unpromotable shards.

Fixes ES-5454
Fixes ES-5212
Iraklis Psaroudakis 2 years ago
parent
commit
e144bef8e9
15 changed files with 885 additions and 93 deletions
  1. 5 0
      docs/changelog/93600.yaml
  2. 60 1
      server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/ShardRoutingRoleIT.java
  3. 1 0
      server/src/main/java/module-info.java
  4. 46 0
      server/src/main/java/org/elasticsearch/action/admin/indices/refresh/ShardRefreshReplicaRequest.java
  5. 45 50
      server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java
  6. 14 9
      server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java
  7. 14 14
      server/src/main/java/org/elasticsearch/action/admin/indices/refresh/UnpromotableShardRefreshRequest.java
  8. 93 0
      server/src/main/java/org/elasticsearch/action/support/broadcast/unpromotable/BroadcastUnpromotableRequest.java
  9. 94 0
      server/src/main/java/org/elasticsearch/action/support/broadcast/unpromotable/TransportBroadcastUnpromotableAction.java
  10. 37 1
      server/src/main/java/org/elasticsearch/action/support/replication/ReplicationOperation.java
  11. 18 0
      server/src/main/java/org/elasticsearch/cluster/routing/IndexShardRoutingTable.java
  12. 4 1
      server/src/main/java/org/elasticsearch/transport/TransportService.java
  13. 322 0
      server/src/test/java/org/elasticsearch/action/support/broadcast/unpromotable/TransportBroadcastUnpromotableActionTests.java
  14. 71 13
      test/framework/src/main/java/org/elasticsearch/action/support/replication/ClusterStateCreationUtils.java
  15. 61 4
      test/framework/src/main/java/org/elasticsearch/cluster/routing/TestShardRouting.java

+ 5 - 0
docs/changelog/93600.yaml

@@ -0,0 +1,5 @@
+pr: 93600
+summary: New `TransportBroadcastUnpromotableAction` action
+area: CRUD
+type: feature
+issues: []

+ 60 - 1
server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/ShardRoutingRoleIT.java

@@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
+import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
 import org.elasticsearch.action.admin.indices.refresh.TransportUnpromotableShardRefreshAction;
 import org.elasticsearch.action.search.ClosePointInTimeAction;
 import org.elasticsearch.action.search.ClosePointInTimeRequest;
@@ -51,6 +52,7 @@ import org.elasticsearch.snapshots.SnapshotState;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.XContentTestUtils;
 import org.elasticsearch.test.transport.MockTransportService;
+import org.elasticsearch.transport.ConnectTransportException;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
@@ -269,7 +271,7 @@ public class ShardRoutingRoleIT extends ESIntegTestCase {
                 connection.sendRequest(requestId, action, request, options);
             });
             mockTransportService.addRequestHandlingBehavior(
-                TransportUnpromotableShardRefreshAction.NAME,
+                TransportUnpromotableShardRefreshAction.NAME + "[u]",
                 (handler, request, channel, task) -> {
                     // Skip handling the request and send an immediate empty response
                     channel.sendResponse(ActionResponse.Empty.INSTANCE);
@@ -690,6 +692,63 @@ public class ShardRoutingRoleIT extends ESIntegTestCase {
         }
     }
 
+    public void testRefreshFailsIfUnpromotableDisconnects() throws Exception { // refresh must report failures when the unpromotable-only node is unreachable
+        var routingTableWatcher = new RoutingTableWatcher();
+        var additionalNumberOfNodesWithUnpromotableShards = 1;
+        routingTableWatcher.numReplicas = routingTableWatcher.numIndexingCopies + additionalNumberOfNodesWithUnpromotableShards - 1;
+        internalCluster().ensureAtLeastNumDataNodes(routingTableWatcher.numIndexingCopies + 1);
+        final String nodeWithUnpromotableOnly = internalCluster().startDataOnlyNode(
+            Settings.builder().put("node.attr." + TestPlugin.NODE_ATTR_UNPROMOTABLE_ONLY, "true").build()
+        );
+        installMockTransportVerifications(routingTableWatcher);
+        getMasterNodePlugin().numIndexingCopies = routingTableWatcher.numIndexingCopies;
+
+        final var masterClusterService = internalCluster().getCurrentMasterNodeInstance(ClusterService.class);
+        try {
+            // verify the correct number of shard copies of each role as the routing table evolves
+            masterClusterService.addListener(routingTableWatcher);
+
+            createIndex(
+                INDEX_NAME,
+                Settings.builder()
+                    .put(routingTableWatcher.getIndexSettings())
+                    .put(IndexSettings.INDEX_CHECK_ON_STARTUP.getKey(), false)
+                    .put(IndexSettings.INDEX_REFRESH_INTERVAL_SETTING.getKey(), -1)
+                    .build()
+            );
+            ensureGreen(INDEX_NAME);
+            assertEngineTypes();
+
+            indexRandom(false, INDEX_NAME, randomIntBetween(1, 10));
+
+            for (var transportService : internalCluster().getInstances(TransportService.class)) { // install the fault on every node's transport
+                MockTransportService mockTransportService = (MockTransportService) transportService;
+                mockTransportService.addSendBehavior((connection, requestId, action, request, options) -> {
+                    if (action.equals(TransportUnpromotableShardRefreshAction.NAME + "[u]") // the per-shard "[u]" transport action
+                        && nodeWithUnpromotableOnly.equals(connection.getNode().getName())) { // only when addressed to the unpromotable-only node
+                        logger.info("--> preventing {} request by throwing ConnectTransportException", action);
+                        throw new ConnectTransportException(connection.getNode(), "DISCONNECT: prevented " + action + " request");
+                    }
+                    connection.sendRequest(requestId, action, request, options); // every other request is sent unchanged
+                });
+            }
+
+            RefreshResponse response = client().admin().indices().prepareRefresh(INDEX_NAME).execute().actionGet();
+            assertThat(
+                "each unpromotable replica shard should be added to the shard failures",
+                response.getFailedShards(),
+                equalTo((routingTableWatcher.numReplicas - (routingTableWatcher.numIndexingCopies - 1)) * routingTableWatcher.numShards) // unpromotable copies per shard, times the number of shards
+            );
+            assertThat(
+                "the total shards is incremented with the unpromotable shard failures",
+                response.getTotalShards(),
+                equalTo(response.getSuccessfulShards() + response.getFailedShards())
+            );
+        } finally {
+            masterClusterService.removeListener(routingTableWatcher); // always detach the watcher, even when an assertion fails
+        }
+    }
+
     public void testNodesWithUnpromotableShardsNeverGetReplicationActions() throws Exception {
         var routingTableWatcher = new RoutingTableWatcher();
         var additionalNumberOfNodesWithUnpromotableShards = randomIntBetween(1, 3);

+ 1 - 0
server/src/main/java/module-info.java

@@ -139,6 +139,7 @@ module org.elasticsearch.server {
     exports org.elasticsearch.action.support;
     exports org.elasticsearch.action.support.broadcast;
     exports org.elasticsearch.action.support.broadcast.node;
+    exports org.elasticsearch.action.support.broadcast.unpromotable;
     exports org.elasticsearch.action.support.master;
     exports org.elasticsearch.action.support.master.info;
     exports org.elasticsearch.action.support.nodes;

+ 46 - 0
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/ShardRefreshReplicaRequest.java

@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.admin.indices.refresh;
+
+import org.elasticsearch.action.support.replication.ReplicationRequest;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.index.engine.Engine;
+import org.elasticsearch.index.shard.ShardId;
+
+import java.io.IOException;
+
+/**
+ * A request that is sent to the promotable replicas of a primary shard
+ * <p>
+ * Besides the shard id it holds the refresh result of the primary, which exists only on the instance built with
+ * {@link #ShardRefreshReplicaRequest(ShardId, Engine.RefreshResult)}; an instance read from a stream holds
+ * {@code Engine.RefreshResult.NO_REFRESH} instead.
+ */
+public class ShardRefreshReplicaRequest extends ReplicationRequest<ShardRefreshReplicaRequest> {
+
+    /**
+     * Holds the refresh result of the primary shard. This will be used by {@link TransportShardRefreshAction} to construct a
+     * {@link UnpromotableShardRefreshRequest} to broadcast to the unpromotable replicas. The refresh result is not serialized to maintain
+     * backwards compatibility for the refresh requests to promotable replicas which do not need the refresh result. For this reason, the
+     * field is package-private.
+     */
+    final Engine.RefreshResult primaryRefreshResult;
+
+    public ShardRefreshReplicaRequest(StreamInput in) throws IOException {
+        super(in);
+        primaryRefreshResult = Engine.RefreshResult.NO_REFRESH; // the result is not read from the stream, so use the NO_REFRESH placeholder
+    }
+
+    public ShardRefreshReplicaRequest(ShardId shardId, Engine.RefreshResult primaryRefreshResult) {
+        super(shardId);
+        this.primaryRefreshResult = primaryRefreshResult;
+    }
+
+    @Override
+    public String toString() {
+        return "ShardRefreshReplicaRequest{" + shardId + "}"; // the refresh result is intentionally not part of the description
+    }
+
+}

+ 45 - 50
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportShardRefreshAction.java

@@ -10,16 +10,15 @@ package org.elasticsearch.action.admin.indices.refresh;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.ActionFilters;
-import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.action.support.replication.BasicReplicationRequest;
+import org.elasticsearch.action.support.replication.ReplicationOperation;
 import org.elasticsearch.action.support.replication.ReplicationResponse;
 import org.elasticsearch.action.support.replication.TransportReplicationAction;
-import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.action.shard.ShardStateAction;
-import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -28,19 +27,14 @@ import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.Logger;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.TransportRequestOptions;
-import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
-import java.util.function.Predicate;
-import java.util.stream.Collectors;
 
 public class TransportShardRefreshAction extends TransportReplicationAction<
     BasicReplicationRequest,
-    BasicReplicationRequest,
+    ShardRefreshReplicaRequest,
     ReplicationResponse> {
 
     private static final Logger logger = LogManager.getLogger(TransportShardRefreshAction.class);
@@ -69,10 +63,11 @@ public class TransportShardRefreshAction extends TransportReplicationAction<
             shardStateAction,
             actionFilters,
             BasicReplicationRequest::new,
-            BasicReplicationRequest::new,
+            ShardRefreshReplicaRequest::new,
             ThreadPool.Names.REFRESH
         );
-        new TransportUnpromotableShardRefreshAction(transportService, actionFilters, indicesService);
+        // registers the unpromotable version of shard refresh action
+        new TransportUnpromotableShardRefreshAction(clusterService, transportService, actionFilters, indicesService);
     }
 
     @Override
@@ -84,53 +79,53 @@ public class TransportShardRefreshAction extends TransportReplicationAction<
     protected void shardOperationOnPrimary(
         BasicReplicationRequest shardRequest,
         IndexShard primary,
-        ActionListener<PrimaryResult<BasicReplicationRequest, ReplicationResponse>> listener
+        ActionListener<PrimaryResult<ShardRefreshReplicaRequest, ReplicationResponse>> listener
     ) {
-        try (var listeners = new RefCountingListener(listener.map(v -> new PrimaryResult<>(shardRequest, new ReplicationResponse())))) {
-            var refreshResult = primary.refresh(SOURCE_API);
+        ActionListener.completeWith(listener, () -> {
+            ShardRefreshReplicaRequest replicaRequest = new ShardRefreshReplicaRequest(shardRequest.shardId(), primary.refresh(SOURCE_API));
+            replicaRequest.setParentTask(shardRequest.getParentTask());
             logger.trace("{} refresh request executed on primary", primary.shardId());
-
-            // Forward the request to all nodes that hold unpromotable replica shards
-            final ClusterState clusterState = clusterService.state();
-            final Task parentTaskId = taskManager.getTask(shardRequest.getParentTask().getId());
-            clusterState.routingTable()
-                .shardRoutingTable(shardRequest.shardId())
-                .assignedShards()
-                .stream()
-                .filter(Predicate.not(ShardRouting::isPromotableToPrimary))
-                .map(ShardRouting::currentNodeId)
-                .collect(Collectors.toUnmodifiableSet())
-                .forEach(nodeId -> {
-                    final DiscoveryNode node = clusterState.nodes().get(nodeId);
-                    UnpromotableShardRefreshRequest request = new UnpromotableShardRefreshRequest(
-                        primary.shardId(),
-                        refreshResult.generation()
-                    );
-                    logger.trace("forwarding refresh request [{}] to node [{}]", request, node);
-                    transportService.sendChildRequest(
-                        node,
-                        TransportUnpromotableShardRefreshAction.NAME,
-                        request,
-                        parentTaskId,
-                        TransportRequestOptions.EMPTY,
-                        new ActionListenerResponseHandler<>(
-                            listeners.acquire(ignored -> {}),
-                            (in) -> TransportResponse.Empty.INSTANCE,
-                            ThreadPool.Names.REFRESH
-                        )
-                    );
-                });
-        } catch (Exception e) {
-            listener.onFailure(e);
-        }
+            return new PrimaryResult<>(replicaRequest, new ReplicationResponse());
+        });
     }
 
     @Override
-    protected void shardOperationOnReplica(BasicReplicationRequest request, IndexShard replica, ActionListener<ReplicaResult> listener) {
+    protected void shardOperationOnReplica(ShardRefreshReplicaRequest request, IndexShard replica, ActionListener<ReplicaResult> listener) {
         ActionListener.completeWith(listener, () -> {
             replica.refresh(SOURCE_API);
             logger.trace("{} refresh request executed on replica", replica.shardId());
             return new ReplicaResult();
         });
     }
+
+    @Override
+    protected ReplicationOperation.Replicas<ShardRefreshReplicaRequest> newReplicasProxy() {
+        return new UnpromotableReplicasRefreshProxy(); // proxy whose primary-completion hook sends the unpromotable refresh request
+    }
+
+    protected class UnpromotableReplicasRefreshProxy extends ReplicasProxy { // adds the unpromotable refresh step after the primary operation
+
+        @Override
+        public void onPrimaryOperationComplete(
+            ShardRefreshReplicaRequest replicaRequest,
+            IndexShardRoutingTable indexShardRoutingTable,
+            ActionListener<Void> listener
+        ) {
+            assert replicaRequest.primaryRefreshResult.refreshed() : "primary has not refreshed";
+            UnpromotableShardRefreshRequest unpromotableReplicaRequest = new UnpromotableShardRefreshRequest(
+                indexShardRoutingTable,
+                replicaRequest.primaryRefreshResult.generation() // the generation taken from the primary's refresh result
+            );
+            transportService.sendRequest(
+                transportService.getLocalNode(), // NOTE(review): targets the local node; presumably the action itself fans out to remote nodes — confirm
+                TransportUnpromotableShardRefreshAction.NAME,
+                unpromotableReplicaRequest,
+                new ActionListenerResponseHandler<>(
+                    listener.delegateFailure((l, r) -> l.onResponse(null)), // the empty response is mapped to a null (Void) result
+                    (in) -> ActionResponse.Empty.INSTANCE, // nothing is read from the response stream
+                    ThreadPool.Names.REFRESH
+                )
+            );
+        }
+    }
 }

+ 14 - 9
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportUnpromotableShardRefreshAction.java

@@ -11,37 +11,42 @@ package org.elasticsearch.action.admin.indices.refresh;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.support.ActionFilters;
-import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.action.support.broadcast.unpromotable.TransportBroadcastUnpromotableAction;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
-import org.elasticsearch.index.engine.Engine;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 
-public class TransportUnpromotableShardRefreshAction extends HandledTransportAction<UnpromotableShardRefreshRequest, ActionResponse.Empty> {
-    public static final String NAME = RefreshAction.NAME + "[u]";
+public class TransportUnpromotableShardRefreshAction extends TransportBroadcastUnpromotableAction<UnpromotableShardRefreshRequest> {
+
+    public static final String NAME = RefreshAction.NAME + "/unpromotable";
 
     private final IndicesService indicesService;
 
     @Inject
     public TransportUnpromotableShardRefreshAction(
+        ClusterService clusterService,
         TransportService transportService,
         ActionFilters actionFilters,
         IndicesService indicesService
     ) {
-        super(NAME, transportService, actionFilters, UnpromotableShardRefreshRequest::new, ThreadPool.Names.REFRESH);
+        super(NAME, clusterService, transportService, actionFilters, UnpromotableShardRefreshRequest::new, ThreadPool.Names.REFRESH);
         this.indicesService = indicesService;
     }
 
     @Override
-    protected void doExecute(Task task, UnpromotableShardRefreshRequest request, ActionListener<ActionResponse.Empty> responseListener) {
+    protected void unpromotableShardOperation(
+        Task task,
+        UnpromotableShardRefreshRequest request,
+        ActionListener<ActionResponse.Empty> responseListener
+    ) {
         ActionListener.run(responseListener, listener -> {
-            assert request.getSegmentGeneration() != Engine.RefreshResult.UNKNOWN_GENERATION
-                : "The request segment is " + request.getSegmentGeneration();
-            IndexShard shard = indicesService.indexServiceSafe(request.getShardId().getIndex()).getShard(request.getShardId().id());
+            IndexShard shard = indicesService.indexServiceSafe(request.shardId().getIndex()).getShard(request.shardId().id());
             shard.waitForSegmentGeneration(request.getSegmentGeneration(), listener.map(l -> ActionResponse.Empty.INSTANCE));
         });
     }
+
 }

+ 14 - 14
server/src/main/java/org/elasticsearch/action/admin/indices/refresh/UnpromotableShardRefreshRequest.java

@@ -8,52 +8,52 @@
 
 package org.elasticsearch.action.admin.indices.refresh;
 
-import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.support.broadcast.unpromotable.BroadcastUnpromotableRequest;
+import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.index.engine.Engine;
 
 import java.io.IOException;
 
-public class UnpromotableShardRefreshRequest extends ActionRequest {
+import static org.elasticsearch.action.ValidateActions.addValidationError;
+
+public class UnpromotableShardRefreshRequest extends BroadcastUnpromotableRequest {
 
-    private final ShardId shardId;
     private final long segmentGeneration;
 
-    public UnpromotableShardRefreshRequest(final ShardId shardId, long segmentGeneration) {
-        this.shardId = shardId;
+    public UnpromotableShardRefreshRequest(IndexShardRoutingTable indexShardRoutingTable, long segmentGeneration) {
+        super(indexShardRoutingTable);
         this.segmentGeneration = segmentGeneration;
     }
 
     public UnpromotableShardRefreshRequest(StreamInput in) throws IOException {
         super(in);
-        shardId = new ShardId(in);
         segmentGeneration = in.readVLong();
     }
 
     @Override
     public ActionRequestValidationException validate() {
-        return null;
+        ActionRequestValidationException validationException = super.validate();
+        if (segmentGeneration == Engine.RefreshResult.UNKNOWN_GENERATION) {
+            validationException = addValidationError("segment generation is unknown", validationException);
+        }
+        return validationException;
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
-        shardId.writeTo(out);
         out.writeVLong(segmentGeneration);
     }
 
-    public ShardId getShardId() {
-        return shardId;
-    }
-
     public long getSegmentGeneration() {
         return segmentGeneration;
     }
 
     @Override
     public String toString() {
-        return "UnpromotableShardRefreshRequest{" + "shardId=" + shardId + ", segmentGeneration=" + segmentGeneration + '}';
+        return "UnpromotableShardRefreshRequest{" + "shardId=" + shardId() + ", segmentGeneration=" + segmentGeneration + '}';
     }
 }

+ 93 - 0
server/src/main/java/org/elasticsearch/action/support/broadcast/unpromotable/BroadcastUnpromotableRequest.java

@@ -0,0 +1,93 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support.broadcast.unpromotable;
+
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.IndicesRequest;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.shard.ShardId;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.action.ValidateActions.addValidationError;
+import static org.elasticsearch.action.support.IndicesOptions.strictSingleIndexNoExpandForbidClosed;
+
+/**
+ * A request that is broadcast to the unpromotable assigned replicas of a primary.
+ * <p>
+ * Only the shard id is written to the wire (see {@link #writeTo(StreamOutput)}); the index shard routing table is kept solely on the
+ * instance created through {@link #BroadcastUnpromotableRequest(IndexShardRoutingTable)}.
+ */
+public class BroadcastUnpromotableRequest extends ActionRequest implements IndicesRequest {
+
+    /**
+     * Holds the index shard routing table that will be used by {@link TransportBroadcastUnpromotableAction} to broadcast the requests to
+     * the unpromotable replicas. The routing table is not serialized over the wire, and will be null on the other end of the wire.
+     * For this reason, the field is package-private.
+     */
+    final @Nullable IndexShardRoutingTable indexShardRoutingTable;
+
+    protected final ShardId shardId;
+    protected final String[] indices; // single-element array holding the name of the shard's index
+
+    public BroadcastUnpromotableRequest(StreamInput in) throws IOException {
+        super(in);
+        indexShardRoutingTable = null; // not part of the wire format, so unavailable on a deserialized request
+        shardId = new ShardId(in);
+        indices = new String[] { shardId.getIndex().getName() };
+    }
+
+    public BroadcastUnpromotableRequest(IndexShardRoutingTable indexShardRoutingTable) {
+        this.indexShardRoutingTable = Objects.requireNonNull(indexShardRoutingTable, "index shard routing table is null");
+        this.shardId = indexShardRoutingTable.shardId(); // the shard id is derived from the routing table
+        this.indices = new String[] { this.shardId.getIndex().getName() };
+    }
+
+    public ShardId shardId() {
+        return shardId;
+    }
+
+    @Override
+    public ActionRequestValidationException validate() {
+        ActionRequestValidationException validationException = null;
+        if (shardId == null) {
+            validationException = addValidationError("shard id is missing", validationException);
+        }
+        return validationException; // null when no validation error was recorded
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        out.writeWriteable(shardId); // mirrors the StreamInput constructor, which reads the shard id right after super(in)
+    }
+
+    @Override
+    public String toString() {
+        return "BroadcastUnpromotableRequest{shardId=" + shardId() + '}';
+    }
+
+    @Override
+    public String getDescription() {
+        return toString();
+    }
+
+    @Override
+    public String[] indices() {
+        return indices;
+    }
+
+    @Override
+    public IndicesOptions indicesOptions() {
+        return strictSingleIndexNoExpandForbidClosed();
+    }
+}

+ 94 - 0
server/src/main/java/org/elasticsearch/action/support/broadcast/unpromotable/TransportBroadcastUnpromotableAction.java

@@ -0,0 +1,94 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support.broadcast.unpromotable;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ChannelActionListener;
+import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.action.support.RefCountingListener;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.transport.TransportChannel;
+import org.elasticsearch.transport.TransportRequestHandler;
+import org.elasticsearch.transport.TransportRequestOptions;
+import org.elasticsearch.transport.TransportResponse;
+import org.elasticsearch.transport.TransportService;
+
+public abstract class TransportBroadcastUnpromotableAction<Request extends BroadcastUnpromotableRequest> extends HandledTransportAction<
+    Request,
+    ActionResponse.Empty> { // base action that sends a request to the node of every unpromotable shard in the request's routing table
+
+    protected final ClusterService clusterService;
+    protected final TransportService transportService;
+
+    protected final String transportUnpromotableAction; // per-shard transport action name: the action name suffixed with "[u]"
+    protected final String executor; // passed to the "[u]" handler registration and to the response handlers in doExecute
+
+    protected TransportBroadcastUnpromotableAction(
+        String actionName,
+        ClusterService clusterService,
+        TransportService transportService,
+        ActionFilters actionFilters,
+        Writeable.Reader<Request> requestReader,
+        String executor
+    ) {
+        super(actionName, transportService, actionFilters, requestReader);
+        this.clusterService = clusterService;
+        this.transportService = transportService;
+        this.transportUnpromotableAction = actionName + "[u]";
+        this.executor = executor;
+
+        transportService.registerRequestHandler(transportUnpromotableAction, executor, requestReader, new UnpromotableTransportHandler()); // receiving side of the broadcast
+    }
+
+    protected abstract void unpromotableShardOperation(Task task, Request request, ActionListener<ActionResponse.Empty> listener); // invoked for each request received on the "[u]" action
+
+    @Override
+    protected void doExecute(Task task, Request request, ActionListener<ActionResponse.Empty> listener) {
+        try (var listeners = new RefCountingListener(listener.map(v -> ActionResponse.Empty.INSTANCE))) { // NOTE(review): presumably completes once every acquired reference is released — confirm
+            ActionListener.completeWith(listeners.acquire(), () -> {
+                final ClusterState clusterState = clusterService.state();
+                if (task != null) {
+                    request.setParentTask(clusterService.localNode().getId(), task.getId()); // marks this task as the parent on the outgoing request
+                }
+                request.indexShardRoutingTable.unpromotableShards().forEach(shardRouting -> { // one request per unpromotable shard copy
+                    final DiscoveryNode node = clusterState.nodes().get(shardRouting.currentNodeId()); // NOTE(review): result is not null-checked — confirm sendRequest tolerates a missing node
+                    transportService.sendRequest(
+                        node,
+                        transportUnpromotableAction,
+                        request,
+                        TransportRequestOptions.EMPTY,
+                        new ActionListenerResponseHandler<>(
+                            listeners.acquire(ignored -> {}), // one reference per outgoing request; the response itself is ignored
+                            (in) -> TransportResponse.Empty.INSTANCE,
+                            executor
+                        )
+                    );
+                });
+                return null;
+            });
+        }
+    }
+
+    class UnpromotableTransportHandler implements TransportRequestHandler<Request> { // handler registered for the "[u]" transport action
+
+        @Override
+        public void messageReceived(Request request, TransportChannel channel, Task task) throws Exception {
+            final ActionListener<ActionResponse.Empty> listener = new ChannelActionListener<>(channel); // listener backed by the transport channel
+            ActionListener.run(listener, (l) -> unpromotableShardOperation(task, request, l)); // delegates to the subclass-provided operation
+        }
+
+    }
+}

+ 37 - 1
server/src/main/java/org/elasticsearch/action/support/replication/ReplicationOperation.java

@@ -132,6 +132,27 @@ public class ReplicationOperation<
             if (logger.isTraceEnabled()) {
                 logger.trace("[{}] op [{}] completed on primary for request [{}]", primary.routingEntry().shardId(), opType, request);
             }
+            final ReplicationGroup replicationGroup = primary.getReplicationGroup();
+
+            pendingActions.incrementAndGet();
+            replicasProxy.onPrimaryOperationComplete(
+                replicaRequest,
+                replicationGroup.getRoutingTable(),
+                ActionListener.wrap(ignored -> decPendingAndFinishIfNeeded(), exception -> {
+                    totalShards.incrementAndGet();
+                    shardReplicaFailures.add(
+                        new ReplicationResponse.ShardInfo.Failure(
+                            primary.routingEntry().shardId(),
+                            null,
+                            exception,
+                            ExceptionsHelper.status(exception),
+                            false
+                        )
+                    );
+                    decPendingAndFinishIfNeeded();
+                })
+            );
+
             // we have to get the replication group after successfully indexing into the primary in order to honour recovery semantics.
             // we have to make sure that every operation indexed into the primary after recovery start will also be replicated
             // to the recovery target. If we used an old replication group, we may miss a recovery that has started since then.
@@ -145,7 +166,6 @@ public class ReplicationOperation<
             // on.
             final long maxSeqNoOfUpdatesOrDeletes = primary.maxSeqNoOfUpdatesOrDeletes();
             assert maxSeqNoOfUpdatesOrDeletes != SequenceNumbers.UNASSIGNED_SEQ_NO : "seqno_of_updates still uninitialized";
-            final ReplicationGroup replicationGroup = primary.getReplicationGroup();
             final PendingReplicationActions pendingReplicationActions = primary.getPendingReplicationActions();
             markUnavailableShardsAsStale(replicaRequest, replicationGroup);
             performOnReplicas(replicaRequest, globalCheckpoint, maxSeqNoOfUpdatesOrDeletes, replicationGroup, pendingReplicationActions);
@@ -601,6 +621,21 @@ public class ReplicationOperation<
          * @param listener     a listener that will be notified when the failing shard has been removed from the in-sync set
          */
         void markShardCopyAsStaleIfNeeded(ShardId shardId, String allocationId, long primaryTerm, ActionListener<Void> listener);
+
+        /**
+         * Optional custom logic to execute when the primary operation is complete, before sending the replica requests.
+         * The replication operation counts this hook as a pending action until {@code listener} is completed; a failure
+         * reported to {@code listener} is recorded as an additional shard failure of the replication operation.
+         * <p>
+         * The default implementation does nothing and completes {@code listener} immediately.
+         *
+         * @param replicaRequest             the operation that will be performed on replicas
+         * @param indexShardRoutingTable     the index shard routing table of the replication group
+         * @param listener                   callback for handling the response or failure
+         */
+        default void onPrimaryOperationComplete(
+            RequestT replicaRequest,
+            IndexShardRoutingTable indexShardRoutingTable,
+            ActionListener<Void> listener
+        ) {
+            listener.onResponse(null);
+        }
     }
 
     /**
@@ -656,4 +691,5 @@ public class ReplicationOperation<
          * */
         void runPostReplicationActions(ActionListener<Void> listener);
     }
+
 }

+ 18 - 0
server/src/main/java/org/elasticsearch/cluster/routing/IndexShardRoutingTable.java

@@ -50,6 +50,7 @@ public class IndexShardRoutingTable {
     final List<ShardRouting> replicas;
     final List<ShardRouting> activeShards;
     final List<ShardRouting> assignedShards;
+    private final List<ShardRouting> unpromotableShards;
     /**
      * The initializing list, including ones that are initializing on a target node because of relocation.
      * If we can come up with a better variable name, it would be nice...
@@ -68,6 +69,7 @@ public class IndexShardRoutingTable {
         List<ShardRouting> replicas = new ArrayList<>();
         List<ShardRouting> activeShards = new ArrayList<>();
         List<ShardRouting> assignedShards = new ArrayList<>();
+        List<ShardRouting> unpromotableShards = new ArrayList<>();
         List<ShardRouting> allInitializingShards = new ArrayList<>();
         boolean allShardsStarted = true;
         int activeSearchShardCount = 0;
@@ -97,9 +99,15 @@ public class IndexShardRoutingTable {
                 assert shard.assignedToNode() : "relocating from unassigned " + shard;
                 assert shard.getTargetRelocatingShard().assignedToNode() : "relocating to unassigned " + shard.getTargetRelocatingShard();
                 assignedShards.add(shard.getTargetRelocatingShard());
+                if (shard.getTargetRelocatingShard().isPromotableToPrimary() == false) {
+                    unpromotableShards.add(shard.getTargetRelocatingShard());
+                }
             }
             if (shard.assignedToNode()) {
                 assignedShards.add(shard);
+                if (shard.isPromotableToPrimary() == false) {
+                    unpromotableShards.add(shard);
+                }
             }
             if (shard.state() != ShardRoutingState.STARTED) {
                 allShardsStarted = false;
@@ -109,6 +117,7 @@ public class IndexShardRoutingTable {
         this.replicas = CollectionUtils.wrapUnmodifiableOrEmptySingleton(replicas);
         this.activeShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(activeShards);
         this.assignedShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(assignedShards);
+        this.unpromotableShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(unpromotableShards);
         this.allInitializingShards = CollectionUtils.wrapUnmodifiableOrEmptySingleton(allInitializingShards);
         this.allShardsStarted = allShardsStarted;
         this.activeSearchShardCount = activeSearchShardCount;
@@ -162,6 +171,15 @@ public class IndexShardRoutingTable {
         return this.assignedShards;
     }
 
+    /**
+     * Returns the assigned shard copies for which {@link ShardRouting#isPromotableToPrimary()} is {@code false}. For a
+     * relocating copy, its relocation target is included as well when the target is unpromotable.
+     *
+     * @return a {@link List} of shards, never {@code null}
+     */
+    public List<ShardRouting> unpromotableShards() {
+        return this.unpromotableShards;
+    }
+
     public ShardIterator shardsRandomIt() {
         return new PlainShardIterator(shardId, shuffler.shuffle(Arrays.asList(shards)));
     }

+ 4 - 1
server/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -1540,7 +1540,10 @@ public class TransportService extends AbstractLifecycleComponent
     }
 
     private boolean isLocalNode(DiscoveryNode discoveryNode) {
-        return Objects.requireNonNull(discoveryNode, "discovery node must not be null").equals(localNode);
+        // A missing node is reported as a NodeNotConnectedException (with the same message as before) rather than as a
+        // NullPointerException.
+        if (discoveryNode == null) {
+            throw new NodeNotConnectedException(discoveryNode, "discovery node must not be null");
+        }
+        return discoveryNode.equals(localNode);
     }
 
     private static final class DelegatingTransportMessageListener implements TransportMessageListener {

+ 322 - 0
server/src/test/java/org/elasticsearch/action/support/broadcast/unpromotable/TransportBroadcastUnpromotableActionTests.java

@@ -0,0 +1,322 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support.broadcast.unpromotable;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ActionTestUtils;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.core.Tuple;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.transport.CapturingTransport;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.NodeNotConnectedException;
+import org.elasticsearch.transport.TransportService;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.state;
+import static org.elasticsearch.action.support.replication.ClusterStateCreationUtils.stateWithAssignedPrimariesAndReplicas;
+import static org.elasticsearch.cluster.routing.TestShardRouting.newShardRouting;
+import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
+import static org.elasticsearch.test.ClusterServiceUtils.setState;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+public class TransportBroadcastUnpromotableActionTests extends ESTestCase {
+
+    private static ThreadPool THREAD_POOL;
+    private ClusterService clusterService;
+    private TransportService transportService;
+    private CapturingTransport transport;
+    private TestTransportBroadcastUnpromotableAction broadcastUnpromotableAction;
+
+    // Creates the thread pool shared by all tests of this class; it is terminated in afterClass().
+    @BeforeClass
+    public static void beforeClass() {
+        THREAD_POOL = new TestThreadPool(TransportBroadcastUnpromotableActionTests.class.getSimpleName());
+    }
+
+    // Builds a fresh capturing transport, cluster service, transport service and action under test for every test method.
+    @Override
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        // requests sent through this transport are captured so that tests can count them
+        transport = new CapturingTransport();
+        clusterService = createClusterService(THREAD_POOL);
+        transportService = transport.createTransportService(
+            clusterService.getSettings(),
+            THREAD_POOL,
+            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+            // the transport service reports the cluster service's local node as its own
+            x -> clusterService.localNode(),
+            null,
+            Collections.emptySet()
+        );
+        transportService.start();
+        transportService.acceptIncomingRequests();
+        // created last: the action's constructor reads the clusterService and transportService fields set above
+        broadcastUnpromotableAction = new TestTransportBroadcastUnpromotableAction();
+    }
+
+    // Closes the per-test cluster service and transport service created in setUp().
+    @Override
+    @After
+    public void tearDown() throws Exception {
+        super.tearDown();
+        IOUtils.close(clusterService, transportService);
+    }
+
+    // Terminates the shared thread pool (waiting up to 30 seconds) and clears the static reference.
+    @AfterClass
+    public static void afterClass() {
+        ThreadPool.terminate(THREAD_POOL, 30, TimeUnit.SECONDS);
+        THREAD_POOL = null;
+    }
+
+    /**
+     * Minimal concrete action used to exercise the broadcasting side only. It is wired to the enclosing test's cluster and
+     * transport services; the receiving side is never expected to run.
+     */
+    private class TestTransportBroadcastUnpromotableAction extends TransportBroadcastUnpromotableAction<TestBroadcastUnpromotableRequest> {
+
+        TestTransportBroadcastUnpromotableAction() {
+            super(
+                "indices:admin/test",
+                TransportBroadcastUnpromotableActionTests.this.clusterService,
+                TransportBroadcastUnpromotableActionTests.this.transportService,
+                new ActionFilters(Set.of()),
+                TestBroadcastUnpromotableRequest::new,
+                ThreadPool.Names.SAME
+            );
+        }
+
+        @Override
+        protected void unpromotableShardOperation(
+            Task task,
+            TestBroadcastUnpromotableRequest request,
+            ActionListener<ActionResponse.Empty> listener
+        ) {
+            // the tests only inspect the requests captured by the transport, so this must never be invoked
+            assert false : "not reachable in these tests";
+        }
+
+    }
+
+    /**
+     * Concrete request type for the tests; both constructors simply delegate to {@link BroadcastUnpromotableRequest}.
+     */
+    private static class TestBroadcastUnpromotableRequest extends BroadcastUnpromotableRequest {
+
+        // deserializing constructor, referenced as the request reader of the test action
+        TestBroadcastUnpromotableRequest(StreamInput in) throws IOException {
+            super(in);
+        }
+
+        TestBroadcastUnpromotableRequest(IndexShardRoutingTable indexShardRoutingTable) {
+            super(indexShardRoutingTable);
+        }
+
+    }
+
+    private static List<ShardRouting.Role> getReplicaRoles(int numPromotableReplicas, int numSearchReplicas) {
+        List<ShardRouting.Role> replicaRoles = Stream.concat(
+            Collections.nCopies(numPromotableReplicas, randomBoolean() ? ShardRouting.Role.DEFAULT : ShardRouting.Role.INDEX_ONLY).stream(),
+            Collections.nCopies(numSearchReplicas, ShardRouting.Role.SEARCH_ONLY).stream()
+        ).collect(Collectors.toList());
+        Collections.shuffle(replicaRoles, random());
+        return replicaRoles;
+    }
+
+    private static List<Tuple<ShardRoutingState, ShardRouting.Role>> getReplicaRolesWithRandomStates(
+        int numPromotableReplicas,
+        int numSearchReplicas,
+        List<ShardRoutingState> possibleStates
+    ) {
+        return getReplicaRoles(numPromotableReplicas, numSearchReplicas).stream()
+            .map(role -> new Tuple<>(randomFrom(possibleStates), role))
+            .collect(Collectors.toList());
+    }
+
+    private static List<Tuple<ShardRoutingState, ShardRouting.Role>> getReplicaRolesWithRandomStates(
+        int numPromotableReplicas,
+        int numSearchReplicas
+    ) {
+        return getReplicaRolesWithRandomStates(
+            numPromotableReplicas,
+            numSearchReplicas,
+            Arrays.stream(ShardRoutingState.values()).toList()
+        );
+    }
+
+    private static List<Tuple<ShardRoutingState, ShardRouting.Role>> getReplicaRolesWithState(
+        int numPromotableReplicas,
+        int numSearchReplicas,
+        ShardRoutingState state
+    ) {
+        return getReplicaRolesWithRandomStates(numPromotableReplicas, numSearchReplicas, List.of(state));
+    }
+
+    private int countRequestsForIndex(ClusterState state, String index) {
+        PlainActionFuture<ActionResponse.Empty> response = PlainActionFuture.newFuture();
+        state.routingTable().activePrimaryShardsGrouped(new String[] { index }, true).iterator().forEachRemaining(shardId -> {
+            logger.debug("--> executing for primary shard id: {}", shardId.shardId());
+            ActionTestUtils.execute(
+                broadcastUnpromotableAction,
+                null,
+                new TestBroadcastUnpromotableRequest(state.routingTable().shardRoutingTable(shardId.shardId())),
+                response
+            );
+        });
+
+        Map<String, List<CapturingTransport.CapturedRequest>> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
+        int totalRequests = 0;
+        for (Map.Entry<String, List<CapturingTransport.CapturedRequest>> entry : capturedRequests.entrySet()) {
+            logger.debug("Captured requests for node [{}] are: [{}]", entry.getKey(), entry.getValue());
+            totalRequests += entry.getValue().size();
+        }
+        return totalRequests;
+    }
+
+    /**
+     * With a primary that is INITIALIZING or UNASSIGNED and all replicas UNASSIGNED, no request must be sent.
+     */
+    public void testNotStartedPrimary() throws Exception {
+        final String index = "test";
+        final int numPromotableReplicas = randomInt(2);
+        final int numSearchReplicas = randomInt(2);
+        final boolean activePrimaryLocal = randomBoolean();
+        final ShardRoutingState primaryState = randomBoolean() ? ShardRoutingState.INITIALIZING : ShardRoutingState.UNASSIGNED;
+        final List<Tuple<ShardRoutingState, ShardRouting.Role>> unassignedReplicas = getReplicaRolesWithState(
+            numPromotableReplicas,
+            numSearchReplicas,
+            ShardRoutingState.UNASSIGNED
+        );
+        final ClusterState state = state(index, activePrimaryLocal, primaryState, unassignedReplicas);
+        setState(clusterService, state);
+        logger.debug("--> using initial state:\n{}", clusterService.state());
+        final int requestCount = countRequestsForIndex(state, index);
+        assertThat(requestCount, is(equalTo(0)));
+    }
+
+    /**
+     * With all copies assigned, exactly one request per search-only replica of every shard is expected.
+     */
+    public void testMixOfStartedPromotableAndSearchReplicas() throws Exception {
+        final String index = "test";
+        final int numShards = 1 + randomInt(3);
+        final int numPromotableReplicas = randomInt(2);
+        final int numSearchReplicas = randomInt(2);
+        final List<ShardRouting.Role> replicaRoles = getReplicaRoles(numPromotableReplicas, numSearchReplicas);
+
+        final ClusterState state = stateWithAssignedPrimariesAndReplicas(new String[] { index }, numShards, replicaRoles);
+        setState(clusterService, state);
+        logger.debug("--> using initial state:\n{}", clusterService.state());
+
+        final int expectedRequests = numShards * numSearchReplicas;
+        assertThat(countRequestsForIndex(state, index), is(equalTo(expectedRequests)));
+    }
+
+    /**
+     * With replicas in random states, one request is expected for every assigned search-only copy, relocation targets included.
+     */
+    public void testSearchReplicasWithRandomStates() throws Exception {
+        final String index = "test";
+        final int numPromotableReplicas = randomInt(2);
+        final int numSearchReplicas = randomInt(6);
+
+        final List<Tuple<ShardRoutingState, ShardRouting.Role>> replicas = getReplicaRolesWithRandomStates(
+            numPromotableReplicas,
+            numSearchReplicas
+        );
+        final int numReachableUnpromotables = replicas.stream()
+            .filter(replica -> replica.v2() == ShardRouting.Role.SEARCH_ONLY && replica.v1() != ShardRoutingState.UNASSIGNED)
+            // a RELOCATING replica accounts for both the RELOCATING and the INITIALIZING copies
+            .mapToInt(replica -> replica.v1() == ShardRoutingState.RELOCATING ? 2 : 1)
+            .sum();
+
+        final ClusterState state = state(index, true, ShardRoutingState.STARTED, replicas);
+        setState(clusterService, state);
+        logger.debug("--> using initial state:\n{}", clusterService.state());
+
+        assertThat(countRequestsForIndex(state, index), is(equalTo(numReachableUnpromotables)));
+    }
+
+    /**
+     * Executes the action with a routing table whose node ids do not exist in the cluster state and expects the failure to
+     * surface through the listener as a {@link NodeNotConnectedException}.
+     */
+    public void testInvalidNodes() throws Exception {
+        final String index = "test";
+        ClusterState state = stateWithAssignedPrimariesAndReplicas(
+            new String[] { index },
+            randomIntBetween(1, 3),
+            getReplicaRoles(randomInt(2), randomIntBetween(1, 2))
+        );
+        setState(clusterService, state);
+        logger.debug("--> using initial state:\n{}", clusterService.state());
+
+        ShardId shardId = state.routingTable().activePrimaryShardsGrouped(new String[] { index }, true).get(0).shardId();
+        IndexShardRoutingTable routingTable = state.routingTable().shardRoutingTable(shardId);
+        // copy the routing table, appending a random number to every node id so that none of them resolves to a node
+        IndexShardRoutingTable.Builder wrongRoutingTableBuilder = new IndexShardRoutingTable.Builder(shardId);
+        for (int i = 0; i < routingTable.size(); i++) {
+            ShardRouting shardRouting = routingTable.shard(i);
+            ShardRouting wrongShardRouting = newShardRouting(
+                shardId,
+                shardRouting.currentNodeId() + randomIntBetween(10, 100),
+                shardRouting.relocatingNodeId(),
+                shardRouting.primary(),
+                shardRouting.state(),
+                shardRouting.unassignedInfo(),
+                shardRouting.role()
+            );
+            wrongRoutingTableBuilder.addShard(wrongShardRouting);
+        }
+        IndexShardRoutingTable wrongRoutingTable = wrongRoutingTableBuilder.build();
+
+        logger.debug("--> executing for wrong shard routing table: {}", wrongRoutingTable);
+        assertThat(
+            expectThrows(
+                NodeNotConnectedException.class,
+                () -> PlainActionFuture.<ActionResponse.Empty, Exception>get(
+                    f -> ActionTestUtils.execute(
+                        broadcastUnpromotableAction,
+                        null,
+                        new TestBroadcastUnpromotableRequest(wrongRoutingTable),
+                        f
+                    ),
+                    10,
+                    TimeUnit.SECONDS
+                )
+            ).toString(),
+            containsString("discovery node must not be null")
+        );
+    }
+
+    /**
+     * Executes the action with a request built from a {@code null} routing table and expects a
+     * {@link NullPointerException} mentioning the missing routing table.
+     */
+    public void testNullIndexShardRoutingTable() throws Exception {
+        // a typed null is needed to select the IndexShardRoutingTable constructor over the StreamInput one
+        IndexShardRoutingTable shardRoutingTable = null;
+        assertThat(
+            expectThrows(
+                NullPointerException.class,
+                () -> PlainActionFuture.<ActionResponse.Empty, Exception>get(
+                    f -> ActionTestUtils.execute(
+                        broadcastUnpromotableAction,
+                        null,
+                        new TestBroadcastUnpromotableRequest(shardRoutingTable),
+                        f
+                    ),
+                    10,
+                    TimeUnit.SECONDS
+                )
+            ).toString(),
+            containsString("index shard routing table is null")
+        );
+    }
+
+}

+ 71 - 13
test/framework/src/main/java/org/elasticsearch/action/support/replication/ClusterStateCreationUtils.java

@@ -26,6 +26,7 @@ import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.health.node.selection.HealthNode;
 import org.elasticsearch.health.node.selection.HealthNodeTaskParams;
 import org.elasticsearch.index.shard.IndexLongFieldRange;
@@ -67,20 +68,46 @@ public class ClusterStateCreationUtils {
         boolean activePrimaryLocal,
         ShardRoutingState primaryState,
         ShardRoutingState... replicaStates
+    ) {
+        return state(
+            index,
+            activePrimaryLocal,
+            primaryState,
+            Arrays.stream(replicaStates).map(shardRoutingState -> new Tuple<>(shardRoutingState, ShardRouting.Role.DEFAULT)).toList()
+        );
+    }
+
+    /**
+     * Creates cluster state with an index that has one shard and #(replicaStates) replicas with given roles
+     *
+     * @param index              name of the index
+     * @param activePrimaryLocal if active primary should coincide with the local node in the cluster state
+     * @param primaryState       state of primary
+     * @param replicaStates      states and roles of the replicas. length of this collection determines also the number of replicas
+     */
+    public static ClusterState state(
+        String index,
+        boolean activePrimaryLocal,
+        ShardRoutingState primaryState,
+        List<Tuple<ShardRoutingState, ShardRouting.Role>> replicaStates
     ) {
         assert primaryState == ShardRoutingState.STARTED
             || primaryState == ShardRoutingState.RELOCATING
-            || Arrays.stream(replicaStates).allMatch(s -> s == ShardRoutingState.UNASSIGNED)
-            : "invalid shard states [" + primaryState + "] vs [" + Arrays.toString(replicaStates) + "]";
+            || replicaStates.stream().allMatch(s -> s.v1() == ShardRoutingState.UNASSIGNED)
+            : "invalid shard states ["
+                + primaryState
+                + "] vs ["
+                + Arrays.toString(replicaStates.stream().map(t -> t.v1()).toArray(String[]::new))
+                + "]";
 
-        final int numberOfReplicas = replicaStates.length;
+        final int numberOfReplicas = replicaStates.size();
 
         int numberOfNodes = numberOfReplicas + 1;
         if (primaryState == ShardRoutingState.RELOCATING) {
             numberOfNodes++;
         }
-        for (ShardRoutingState state : replicaStates) {
-            if (state == ShardRoutingState.RELOCATING) {
+        for (var state : replicaStates) {
+            if (state.v1() == ShardRoutingState.RELOCATING) {
                 numberOfNodes++;
             }
         }
@@ -139,21 +166,30 @@ public class ClusterStateCreationUtils {
             TestShardRouting.newShardRouting(index, 0, primaryNode, relocatingNode, true, primaryState, unassignedInfo)
         );
 
-        for (ShardRoutingState replicaState : replicaStates) {
+        for (var replicaState : replicaStates) {
             String replicaNode = null;
             relocatingNode = null;
             unassignedInfo = null;
-            if (replicaState != ShardRoutingState.UNASSIGNED) {
+            if (replicaState.v1() != ShardRoutingState.UNASSIGNED) {
                 assert primaryNode != null : "a replica is assigned but the primary isn't";
                 replicaNode = selectAndRemove(unassignedNodes);
-                if (replicaState == ShardRoutingState.RELOCATING) {
+                if (replicaState.v1() == ShardRoutingState.RELOCATING) {
                     relocatingNode = selectAndRemove(unassignedNodes);
                 }
             } else {
                 unassignedInfo = new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, null);
             }
             indexShardRoutingBuilder.addShard(
-                TestShardRouting.newShardRouting(index, shardId.id(), replicaNode, relocatingNode, false, replicaState, unassignedInfo)
+                TestShardRouting.newShardRouting(
+                    index,
+                    shardId.id(),
+                    replicaNode,
+                    relocatingNode,
+                    false,
+                    replicaState.v1(),
+                    unassignedInfo,
+                    replicaState.v2()
+                )
             );
         }
         final IndexShardRoutingTable indexShardRoutingTable = indexShardRoutingBuilder.build();
@@ -316,8 +352,22 @@ public class ClusterStateCreationUtils {
      * Creates cluster state with several indexes, shards and replicas and all shards STARTED.
      */
     public static ClusterState stateWithAssignedPrimariesAndReplicas(String[] indices, int numberOfShards, int numberOfReplicas) {
+        // delegates to the role-aware overload, giving each of the numberOfReplicas replicas the DEFAULT role
+        return stateWithAssignedPrimariesAndReplicas(
+            indices,
+            numberOfShards,
+            Collections.nCopies(numberOfReplicas, ShardRouting.Role.DEFAULT)
+        );
+    }
 
-        int numberOfDataNodes = numberOfReplicas + 1;
+    /**
+     * Creates cluster state with several indexes, shards and replicas (with given roles) and all shards STARTED.
+     */
+    public static ClusterState stateWithAssignedPrimariesAndReplicas(
+        String[] indices,
+        int numberOfShards,
+        List<ShardRouting.Role> replicaRoles
+    ) {
+        int numberOfDataNodes = replicaRoles.size() + 1;
         DiscoveryNodes.Builder discoBuilder = DiscoveryNodes.builder();
         for (int i = 0; i < numberOfDataNodes + 1; i++) {
             final DiscoveryNode node = newNode(i);
@@ -340,7 +390,7 @@ public class ClusterStateCreationUtils {
                     Settings.builder()
                         .put(SETTING_VERSION_CREATED, Version.CURRENT)
                         .put(SETTING_NUMBER_OF_SHARDS, numberOfShards)
-                        .put(SETTING_NUMBER_OF_REPLICAS, numberOfReplicas)
+                        .put(SETTING_NUMBER_OF_REPLICAS, replicaRoles.size())
                         .put(SETTING_CREATION_DATE, System.currentTimeMillis())
                 )
                 .timestampRange(IndexLongFieldRange.UNKNOWN)
@@ -353,9 +403,17 @@ public class ClusterStateCreationUtils {
                 indexShardRoutingBuilder.addShard(
                     TestShardRouting.newShardRouting(index, i, newNode(0).getId(), null, true, ShardRoutingState.STARTED)
                 );
-                for (int replica = 0; replica < numberOfReplicas; replica++) {
+                for (int replica = 0; replica < replicaRoles.size(); replica++) {
                     indexShardRoutingBuilder.addShard(
-                        TestShardRouting.newShardRouting(index, i, newNode(replica + 1).getId(), null, false, ShardRoutingState.STARTED)
+                        TestShardRouting.newShardRouting(
+                            index,
+                            i,
+                            newNode(replica + 1).getId(),
+                            null,
+                            false,
+                            ShardRoutingState.STARTED,
+                            replicaRoles.get(replica)
+                        )
                     );
                 }
                 indexRoutingTableBuilder.addIndexShard(indexShardRoutingBuilder);

+ 61 - 4
test/framework/src/main/java/org/elasticsearch/cluster/routing/TestShardRouting.java

@@ -85,7 +85,27 @@ public class TestShardRouting {
             currentNodeId,
             relocatingNodeId,
             primary,
-            state
+            state,
+            ShardRouting.Role.DEFAULT
+        );
+    }
+
+    /**
+     * Creates a shard routing with an explicit role for the given index name and shard number. The {@link ShardId} is built
+     * with {@code IndexMetadata.INDEX_UUID_NA_VALUE} as the index UUID and the call is forwarded to the {@link ShardId}-based
+     * overload.
+     */
+    public static ShardRouting newShardRouting(
+        String index,
+        int shardId,
+        String currentNodeId,
+        String relocatingNodeId,
+        boolean primary,
+        ShardRoutingState state,
+        ShardRouting.Role role
+    ) {
+        return newShardRouting(
+            new ShardId(index, IndexMetadata.INDEX_UUID_NA_VALUE, shardId),
+            currentNodeId,
+            relocatingNodeId,
+            primary,
+            state,
+            role
         );
     }
 
@@ -95,6 +115,17 @@ public class TestShardRouting {
         String relocatingNodeId,
         boolean primary,
         ShardRoutingState state
+    ) {
+        return newShardRouting(shardId, currentNodeId, relocatingNodeId, primary, state, ShardRouting.Role.DEFAULT);
+    }
+
+    public static ShardRouting newShardRouting(
+        ShardId shardId,
+        String currentNodeId,
+        String relocatingNodeId,
+        boolean primary,
+        ShardRoutingState state,
+        ShardRouting.Role role
     ) {
         return new ShardRouting(
             shardId,
@@ -107,7 +138,7 @@ public class TestShardRouting {
             buildRelocationFailureInfo(state),
             buildAllocationId(state),
             -1,
-            ShardRouting.Role.DEFAULT
+            role
         );
     }
 
@@ -161,6 +192,19 @@ public class TestShardRouting {
         boolean primary,
         ShardRoutingState state,
         UnassignedInfo unassignedInfo
+    ) {
+        return newShardRouting(index, shardId, currentNodeId, relocatingNodeId, primary, state, unassignedInfo, ShardRouting.Role.DEFAULT);
+    }
+
+    public static ShardRouting newShardRouting(
+        String index,
+        int shardId,
+        String currentNodeId,
+        String relocatingNodeId,
+        boolean primary,
+        ShardRoutingState state,
+        UnassignedInfo unassignedInfo,
+        ShardRouting.Role role
     ) {
         return newShardRouting(
             new ShardId(index, IndexMetadata.INDEX_UUID_NA_VALUE, shardId),
@@ -168,7 +212,8 @@ public class TestShardRouting {
             relocatingNodeId,
             primary,
             state,
-            unassignedInfo
+            unassignedInfo,
+            role
         );
     }
 
@@ -179,6 +224,18 @@ public class TestShardRouting {
         boolean primary,
         ShardRoutingState state,
         UnassignedInfo unassignedInfo
+    ) {
+        return newShardRouting(shardId, currentNodeId, relocatingNodeId, primary, state, unassignedInfo, ShardRouting.Role.DEFAULT);
+    }
+
+    public static ShardRouting newShardRouting(
+        ShardId shardId,
+        String currentNodeId,
+        String relocatingNodeId,
+        boolean primary,
+        ShardRoutingState state,
+        UnassignedInfo unassignedInfo,
+        ShardRouting.Role role
     ) {
         return new ShardRouting(
             shardId,
@@ -191,7 +248,7 @@ public class TestShardRouting {
             buildRelocationFailureInfo(state),
             buildAllocationId(state),
             -1,
-            ShardRouting.Role.DEFAULT
+            role
         );
     }