1
0
Эх сурвалжийг харах

Introduce BatchExecutionContext (#89323)

Replaces the two arguments to `ClusterStateTaskExecutor#execute` with a
parameter object called `BatchExecutionContext` so that #85525 can add a
new and rarely-used parameter without generating tons of noise.
David Turner 3 жил өмнө
parent
commit
4779893b25
36 өөрчлөгдсөн 268 нэмэгдсэн , 264 устгасан
  1. 3 4
      modules/data-streams/src/main/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeService.java
  2. 3 5
      server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportDeleteDesiredNodesAction.java
  3. 7 7
      server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java
  4. 4 3
      server/src/main/java/org/elasticsearch/action/admin/indices/create/AutoCreateAction.java
  5. 5 5
      server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
  6. 11 5
      server/src/main/java/org/elasticsearch/cluster/ClusterStateTaskExecutor.java
  7. 1 4
      server/src/main/java/org/elasticsearch/cluster/ClusterStateTaskListener.java
  8. 5 5
      server/src/main/java/org/elasticsearch/cluster/LocalMasterServiceTask.java
  9. 16 18
      server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
  10. 21 20
      server/src/main/java/org/elasticsearch/cluster/coordination/JoinTaskExecutor.java
  11. 9 10
      server/src/main/java/org/elasticsearch/cluster/coordination/NodeRemovalClusterStateTaskExecutor.java
  12. 18 18
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java
  13. 3 3
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java
  14. 3 3
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java
  15. 4 4
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataUpdateSettingsService.java
  16. 10 6
      server/src/main/java/org/elasticsearch/cluster/service/MasterService.java
  17. 3 4
      server/src/main/java/org/elasticsearch/health/metadata/HealthMetadataService.java
  18. 6 6
      server/src/main/java/org/elasticsearch/ingest/IngestService.java
  19. 5 6
      server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateErrorTaskExecutor.java
  20. 5 6
      server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateUpdateTaskExecutor.java
  21. 20 24
      server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java
  22. 1 3
      server/src/test/java/org/elasticsearch/cluster/ClusterStateTaskExecutorTests.java
  23. 3 3
      server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexStateServiceBatchingTests.java
  24. 47 41
      server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java
  25. 2 2
      server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java
  26. 1 1
      test/framework/src/main/java/org/elasticsearch/cluster/service/ClusterStateTaskExecutorUtils.java
  27. 9 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartBasicClusterTask.java
  28. 9 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartTrialClusterTask.java
  29. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/license/LicenseServiceTests.java
  30. 6 5
      x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/IndexLifecycleRunner.java
  31. 1 1
      x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java
  32. 3 5
      x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/v2/TransportRollupAction.java
  33. 6 7
      x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java
  34. 8 8
      x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java
  35. 3 1
      x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java
  36. 5 2
      x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java

+ 3 - 4
modules/data-streams/src/main/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeService.java

@@ -31,7 +31,6 @@ import org.elasticsearch.threadpool.ThreadPool;
 import java.io.IOException;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
-import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
 
@@ -200,9 +199,9 @@ public class UpdateTimeSeriesRangeService extends AbstractLifecycleComponent imp
 
     private class UpdateTimeSeriesExecutor implements ClusterStateTaskExecutor<UpdateTimeSeriesTask> {
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<UpdateTimeSeriesTask>> taskContexts) throws Exception {
-            var result = updateTimeSeriesTemporalRange(currentState, Instant.now());
-            for (final var taskContext : taskContexts) {
+        public ClusterState execute(BatchExecutionContext<UpdateTimeSeriesTask> batchExecutionContext) throws Exception {
+            var result = updateTimeSeriesTemporalRange(batchExecutionContext.initialState(), Instant.now());
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 taskContext.success(() -> taskContext.getTask().listener().accept(null));
             }
             return result;

+ 3 - 5
server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportDeleteDesiredNodesAction.java

@@ -27,8 +27,6 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 
-import java.util.List;
-
 public class TransportDeleteDesiredNodesAction extends TransportMasterNodeAction<DeleteDesiredNodesAction.Request, ActionResponse.Empty> {
 
     private final ClusterStateTaskExecutor<DeleteDesiredNodesTask> taskExecutor = new DeleteDesiredNodesExecutor();
@@ -83,11 +81,11 @@ public class TransportDeleteDesiredNodesAction extends TransportMasterNodeAction
 
     private static class DeleteDesiredNodesExecutor implements ClusterStateTaskExecutor<DeleteDesiredNodesTask> {
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<DeleteDesiredNodesTask>> taskContexts) throws Exception {
-            for (final var taskContext : taskContexts) {
+        public ClusterState execute(BatchExecutionContext<DeleteDesiredNodesTask> batchExecutionContext) throws Exception {
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 taskContext.success(() -> taskContext.getTask().listener().onResponse(ActionResponse.Empty.INSTANCE));
             }
-            return currentState.copyAndUpdateMetadata(metadata -> metadata.removeCustom(DesiredNodesMetadata.TYPE));
+            return batchExecutionContext.initialState().copyAndUpdateMetadata(metadata -> metadata.removeCustom(DesiredNodesMetadata.TYPE));
         }
     }
 }

+ 7 - 7
server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java

@@ -33,7 +33,6 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 
-import java.util.List;
 import java.util.Locale;
 
 import static java.lang.String.format;
@@ -177,10 +176,11 @@ public class TransportUpdateDesiredNodesAction extends TransportMasterNodeAction
         }
 
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<UpdateDesiredNodesTask>> taskContexts) throws Exception {
-            final var initialDesiredNodes = DesiredNodesMetadata.fromClusterState(currentState).getLatestDesiredNodes();
+        public ClusterState execute(BatchExecutionContext<UpdateDesiredNodesTask> batchExecutionContext) throws Exception {
+            final var initialState = batchExecutionContext.initialState();
+            final var initialDesiredNodes = DesiredNodesMetadata.fromClusterState(initialState).getLatestDesiredNodes();
             var desiredNodes = initialDesiredNodes;
-            for (final var taskContext : taskContexts) {
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 final UpdateDesiredNodesRequest request = taskContext.getTask().request();
                 if (request.isDryRun()) {
                     try {
@@ -205,12 +205,12 @@ public class TransportUpdateDesiredNodesAction extends TransportMasterNodeAction
                 );
             }
 
-            desiredNodes = DesiredNodes.updateDesiredNodesStatusIfNeeded(currentState.nodes(), desiredNodes);
+            desiredNodes = DesiredNodes.updateDesiredNodesStatusIfNeeded(initialState.nodes(), desiredNodes);
 
             if (desiredNodes == initialDesiredNodes) {
-                return currentState;
+                return initialState;
             } else {
-                final ClusterState withUpdatedDesiredNodes = replaceDesiredNodes(currentState, desiredNodes);
+                final ClusterState withUpdatedDesiredNodes = replaceDesiredNodes(initialState, desiredNodes);
                 return allocationService.adaptAutoExpandReplicas(withUpdatedDesiredNodes);
             }
         }

+ 4 - 3
server/src/main/java/org/elasticsearch/action/admin/indices/create/AutoCreateAction.java

@@ -107,9 +107,10 @@ public final class AutoCreateAction extends ActionType<CreateIndexResponse> {
             this.createIndexService = createIndexService;
             this.metadataCreateDataStreamService = metadataCreateDataStreamService;
             this.autoCreateIndex = autoCreateIndex;
-            this.executor = (currentState, taskContexts) -> {
-                ClusterState state = currentState;
+            this.executor = batchExecutionContext -> {
+                final var taskContexts = batchExecutionContext.taskContexts();
                 final Map<CreateIndexRequest, String> successfulRequests = Maps.newMapWithExpectedSize(taskContexts.size());
+                ClusterState state = batchExecutionContext.initialState();
                 for (final var taskContext : taskContexts) {
                     final var task = taskContext.getTask();
                     try {
@@ -119,7 +120,7 @@ public final class AutoCreateAction extends ActionType<CreateIndexResponse> {
                         taskContext.onFailure(e);
                     }
                 }
-                if (state != currentState) {
+                if (state != batchExecutionContext.initialState()) {
                     state = allocationService.reroute(state, "auto-create");
                 }
                 return state;

+ 5 - 5
server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java

@@ -259,10 +259,10 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
         ActiveShardsObserver activeShardsObserver
     ) implements ClusterStateTaskExecutor<RolloverTask> {
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<RolloverTask>> taskContexts) throws Exception {
-            final var results = new ArrayList<MetadataRolloverService.RolloverResult>(taskContexts.size());
-            var state = currentState;
-            for (final var taskContext : taskContexts) {
+        public ClusterState execute(BatchExecutionContext<RolloverTask> batchExecutionContext) throws Exception {
+            final var results = new ArrayList<MetadataRolloverService.RolloverResult>(batchExecutionContext.taskContexts().size());
+            var state = batchExecutionContext.initialState();
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 try {
                     state = executeTask(state, results, taskContext);
                 } catch (Exception e) {
@@ -270,7 +270,7 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
                 }
             }
 
-            if (state != currentState) {
+            if (state != batchExecutionContext.initialState()) {
                 var reason = new StringBuilder();
                 Strings.collectionToDelimitedStringWithLimit(
                     (Iterable<String>) () -> results.stream().map(t -> t.sourceIndexName() + "->" + t.rolloverIndexName()).iterator(),

+ 11 - 5
server/src/main/java/org/elasticsearch/cluster/ClusterStateTaskExecutor.java

@@ -29,11 +29,9 @@ public interface ClusterStateTaskExecutor<T extends ClusterStateTaskListener> {
      * surprisingly many tasks to process in the batch. If it's possible to accumulate the effects of the tasks at a lower level then you
      * should do that instead.
      *
-     * @param currentState The initial cluster state on which the tasks should be executed.
-     * @param taskContexts A {@link TaskContext} for each task in the batch. Implementations must complete every context in the list.
-     * @return The resulting cluster state after executing all the tasks. If {code currentState} is returned then no update is published.
+     * @return The resulting cluster state after executing all the tasks. If {code initialState} is returned then no update is published.
      */
-    ClusterState execute(ClusterState currentState, List<TaskContext<T>> taskContexts) throws Exception;
+    ClusterState execute(BatchExecutionContext<T> batchExecutionContext) throws Exception;
 
     /**
      * @return {@code true} iff this executor should only run on the elected master.
@@ -54,7 +52,7 @@ public interface ClusterStateTaskExecutor<T extends ClusterStateTaskListener> {
     /**
      * Builds a concise description of a list of tasks (to be used in logging etc.).
      *
-     * Note that the tasks given are not necessarily the same as those that will be passed to {@link #execute(ClusterState, List)}.
+     * Note that the tasks given are not necessarily the same as those that will be passed to {@link #execute(BatchExecutionContext)}.
      * but are guaranteed to be a subset of them. This method can be called multiple times with different lists before execution.
      *
      * @param tasks the tasks to describe.
@@ -203,4 +201,12 @@ public interface ClusterStateTaskExecutor<T extends ClusterStateTaskListener> {
          */
         void onFailure(Exception failure);
     }
+
+    /**
+     * Encapsulates the context in which a batch of tasks executes.
+     *
+     * @param initialState The initial cluster state on which the tasks should be executed.
+     * @param taskContexts A {@link TaskContext} for each task in the batch. Implementations must complete every context in the list.
+     */
+    record BatchExecutionContext<T extends ClusterStateTaskListener> (ClusterState initialState, List<TaskContext<T>> taskContexts) {}
 }

+ 1 - 4
server/src/main/java/org/elasticsearch/cluster/ClusterStateTaskListener.java

@@ -11,8 +11,6 @@ import org.elasticsearch.cluster.coordination.FailedToCommitClusterStateExceptio
 import org.elasticsearch.cluster.metadata.ProcessClusterEventTimeoutException;
 import org.elasticsearch.cluster.service.MasterService;
 
-import java.util.List;
-
 public interface ClusterStateTaskListener {
 
     /**
@@ -32,8 +30,7 @@ public interface ClusterStateTaskListener {
     void onFailure(Exception e);
 
     /**
-     * Called when the result of the {@link ClusterStateTaskExecutor#execute(ClusterState, List)} method have been processed properly by all
-     * listeners.
+     * Called when the result of the {@link ClusterStateTaskExecutor#execute} method has been processed properly by all listeners.
      *
      * The {@param newState} parameter is the state that was ultimately published. This can lead to surprising behaviour if tasks are
      * batched together: a later task in the batch may undo or overwrite the changes made by an earlier task. In general you should prefer

+ 5 - 5
server/src/main/java/org/elasticsearch/cluster/LocalMasterServiceTask.java

@@ -52,14 +52,14 @@ public abstract class LocalMasterServiceTask implements ClusterStateTaskListener
                 }
 
                 @Override
-                public ClusterState execute(ClusterState currentState, List<TaskContext<LocalMasterServiceTask>> taskContexts)
-                    throws Exception {
-                    final LocalMasterServiceTask thisTask = LocalMasterServiceTask.this;
+                public ClusterState execute(BatchExecutionContext<LocalMasterServiceTask> batchExecutionContext) throws Exception {
+                    final var thisTask = LocalMasterServiceTask.this;
+                    final var taskContexts = batchExecutionContext.taskContexts();
                     assert taskContexts.size() == 1 && taskContexts.get(0).getTask() == thisTask
                         : "expected one-element task list containing current object but was " + taskContexts;
-                    thisTask.execute(currentState);
+                    thisTask.execute(batchExecutionContext.initialState());
                     taskContexts.get(0).success(() -> onPublicationComplete());
-                    return currentState;
+                    return batchExecutionContext.initialState();
                 }
             }
         );

+ 16 - 18
server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java

@@ -316,18 +316,15 @@ public class ShardStateAction {
         }
 
         @Override
-        public ClusterState execute(
-            ClusterState currentState,
-            List<ClusterStateTaskExecutor.TaskContext<FailedShardUpdateTask>> taskContexts
-        ) throws Exception {
+        public ClusterState execute(BatchExecutionContext<FailedShardUpdateTask> batchExecutionContext) throws Exception {
             List<ClusterStateTaskExecutor.TaskContext<FailedShardUpdateTask>> tasksToBeApplied = new ArrayList<>();
             List<FailedShard> failedShardsToBeApplied = new ArrayList<>();
             List<StaleShard> staleShardsToBeApplied = new ArrayList<>();
-
-            for (final var taskContext : taskContexts) {
+            final ClusterState initialState = batchExecutionContext.initialState();
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 final var task = taskContext.getTask();
                 FailedShardEntry entry = task.entry();
-                IndexMetadata indexMetadata = currentState.metadata().index(entry.getShardId().getIndex());
+                IndexMetadata indexMetadata = initialState.metadata().index(entry.getShardId().getIndex());
                 if (indexMetadata == null) {
                     // tasks that correspond to non-existent indices are marked as successful
                     logger.debug(
@@ -377,7 +374,7 @@ public class ShardStateAction {
                         }
                     }
 
-                    ShardRouting matched = currentState.getRoutingTable().getByAllocationId(entry.getShardId(), entry.getAllocationId());
+                    ShardRouting matched = initialState.getRoutingTable().getByAllocationId(entry.getShardId(), entry.getAllocationId());
                     if (matched == null) {
                         Set<String> inSyncAllocationIds = indexMetadata.inSyncAllocationIds(entry.getShardId().id());
                         // mark shard copies without routing entries that are in in-sync allocations set only as stale if the reason why
@@ -407,9 +404,9 @@ public class ShardStateAction {
             }
             assert tasksToBeApplied.size() == failedShardsToBeApplied.size() + staleShardsToBeApplied.size();
 
-            ClusterState maybeUpdatedState = currentState;
+            ClusterState maybeUpdatedState = initialState;
             try {
-                maybeUpdatedState = applyFailedShards(currentState, failedShardsToBeApplied, staleShardsToBeApplied);
+                maybeUpdatedState = applyFailedShards(initialState, failedShardsToBeApplied, staleShardsToBeApplied);
                 for (final var taskContext : tasksToBeApplied) {
                     taskContext.success(() -> taskContext.getTask().listener().onResponse(TransportResponse.Empty.INSTANCE));
                 }
@@ -625,15 +622,16 @@ public class ShardStateAction {
         }
 
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<StartedShardUpdateTask>> taskContexts) throws Exception {
+        public ClusterState execute(BatchExecutionContext<StartedShardUpdateTask> batchExecutionContext) throws Exception {
             List<TaskContext<StartedShardUpdateTask>> tasksToBeApplied = new ArrayList<>();
-            List<ShardRouting> shardRoutingsToBeApplied = new ArrayList<>(taskContexts.size());
+            List<ShardRouting> shardRoutingsToBeApplied = new ArrayList<>(batchExecutionContext.taskContexts().size());
             Set<ShardRouting> seenShardRoutings = new HashSet<>(); // to prevent duplicates
             final Map<Index, IndexLongFieldRange> updatedTimestampRanges = new HashMap<>();
-            for (var taskContext : taskContexts) {
+            final ClusterState initialState = batchExecutionContext.initialState();
+            for (var taskContext : batchExecutionContext.taskContexts()) {
                 final var task = taskContext.getTask();
                 StartedShardEntry entry = task.getEntry();
-                final ShardRouting matched = currentState.getRoutingTable().getByAllocationId(entry.shardId, entry.allocationId);
+                final ShardRouting matched = initialState.getRoutingTable().getByAllocationId(entry.shardId, entry.allocationId);
                 if (matched == null) {
                     // tasks that correspond to non-existent shards are marked as successful. The reason is that we resend shard started
                     // events on every cluster state publishing that does not contain the shard as started yet. This means that old stale
@@ -643,7 +641,7 @@ public class ShardStateAction {
                     taskContext.success(() -> task.listener().onResponse(TransportResponse.Empty.INSTANCE));
                 } else {
                     if (matched.primary() && entry.primaryTerm > 0) {
-                        final IndexMetadata indexMetadata = currentState.metadata().index(entry.shardId.getIndex());
+                        final IndexMetadata indexMetadata = initialState.metadata().index(entry.shardId.getIndex());
                         assert indexMetadata != null;
                         final long currentPrimaryTerm = indexMetadata.primaryTerm(entry.shardId.id());
                         if (currentPrimaryTerm != entry.primaryTerm) {
@@ -694,7 +692,7 @@ public class ShardStateAction {
                             // expand the timestamp range recorded in the index metadata if needed
                             final Index index = entry.shardId.getIndex();
                             IndexLongFieldRange currentTimestampMillisRange = updatedTimestampRanges.get(index);
-                            final IndexMetadata indexMetadata = currentState.metadata().index(index);
+                            final IndexMetadata indexMetadata = initialState.metadata().index(index);
                             if (currentTimestampMillisRange == null) {
                                 currentTimestampMillisRange = indexMetadata.getTimestampRange();
                             }
@@ -713,9 +711,9 @@ public class ShardStateAction {
             }
             assert tasksToBeApplied.size() >= shardRoutingsToBeApplied.size();
 
-            ClusterState maybeUpdatedState = currentState;
+            ClusterState maybeUpdatedState = initialState;
             try {
-                maybeUpdatedState = allocationService.applyStartedShards(currentState, shardRoutingsToBeApplied);
+                maybeUpdatedState = allocationService.applyStartedShards(initialState, shardRoutingsToBeApplied);
 
                 if (updatedTimestampRanges.isEmpty() == false) {
                     final Metadata.Builder metadataBuilder = Metadata.builder(maybeUpdatedState.metadata());

+ 21 - 20
server/src/main/java/org/elasticsearch/cluster/coordination/JoinTaskExecutor.java

@@ -49,45 +49,46 @@ public class JoinTaskExecutor implements ClusterStateTaskExecutor<JoinTask> {
     }
 
     @Override
-    public ClusterState execute(ClusterState currentState, List<TaskContext<JoinTask>> joinTaskContexts) throws Exception {
+    public ClusterState execute(BatchExecutionContext<JoinTask> batchExecutionContext) throws Exception {
         // The current state that MasterService uses might have been updated by a (different) master in a higher term already. If so, stop
         // processing the current cluster state update, there's no point in continuing to compute it as it will later be rejected by
         // Coordinator#publish anyhow.
-        assert joinTaskContexts.isEmpty() == false : "Expected to have non empty join tasks list";
+        assert batchExecutionContext.taskContexts().isEmpty() == false : "Expected to have non empty join tasks list";
 
-        var term = joinTaskContexts.stream().mapToLong(t -> t.getTask().term()).max().getAsLong();
+        var term = batchExecutionContext.taskContexts().stream().mapToLong(t -> t.getTask().term()).max().getAsLong();
 
-        var split = joinTaskContexts.stream().collect(Collectors.partitioningBy(t -> t.getTask().term() == term));
+        var split = batchExecutionContext.taskContexts().stream().collect(Collectors.partitioningBy(t -> t.getTask().term() == term));
         for (TaskContext<JoinTask> outdated : split.get(false)) {
             outdated.onFailure(
                 new NotMasterException("Higher term encountered (encountered: " + term + " > used: " + outdated.getTask().term() + ")")
             );
         }
 
-        joinTaskContexts = split.get(true);
+        final var joinTaskContexts = split.get(true);
+        final var initialState = batchExecutionContext.initialState();
 
-        if (currentState.term() > term) {
-            logger.trace("encountered higher term {} than current {}, there is a newer master", currentState.term(), term);
+        if (initialState.term() > term) {
+            logger.trace("encountered higher term {} than current {}, there is a newer master", initialState.term(), term);
             throw new NotMasterException(
-                "Higher term encountered (current: " + currentState.term() + " > used: " + term + "), there is a newer master"
+                "Higher term encountered (current: " + initialState.term() + " > used: " + term + "), there is a newer master"
             );
         }
 
         final boolean isBecomingMaster = joinTaskContexts.stream().anyMatch(t -> t.getTask().isBecomingMaster());
-        final DiscoveryNodes currentNodes = currentState.nodes();
+        final DiscoveryNodes currentNodes = initialState.nodes();
         boolean nodesChanged = false;
         ClusterState.Builder newState;
 
         if (currentNodes.getMasterNode() == null && isBecomingMaster) {
-            assert currentState.term() < term : "there should be at most one become master task per election (= by term)";
+            assert initialState.term() < term : "there should be at most one become master task per election (= by term)";
             // use these joins to try and become the master.
             // Note that we don't have to do any validation of the amount of joining nodes - the commit
             // during the cluster state publishing guarantees that we have enough
-            newState = becomeMasterAndTrimConflictingNodes(currentState, joinTaskContexts, term);
+            newState = becomeMasterAndTrimConflictingNodes(initialState, joinTaskContexts, term);
             nodesChanged = true;
         } else if (currentNodes.isLocalNodeElectedMaster()) {
-            assert currentState.term() == term : "term should be stable for the same master";
-            newState = ClusterState.builder(currentState);
+            assert initialState.term() == term : "term should be stable for the same master";
+            newState = ClusterState.builder(initialState);
         } else {
             logger.trace("processing node joins, but we are not the master. current master: {}", currentNodes.getMasterNode());
             throw new NotMasterException("Node [" + currentNodes.getLocalNode() + "] not master for join request");
@@ -100,7 +101,7 @@ public class JoinTaskExecutor implements ClusterStateTaskExecutor<JoinTask> {
         Version minClusterNodeVersion = newState.nodes().getMinNodeVersion();
         Version maxClusterNodeVersion = newState.nodes().getMaxNodeVersion();
         // if the cluster is not fully-formed then the min version is not meaningful
-        final boolean enforceVersionBarrier = currentState.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false;
+        final boolean enforceVersionBarrier = initialState.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false;
         // processing any joins
         Map<String, String> joinedNodeIdsByNodeName = new HashMap<>();
         for (final var joinTaskContext : joinTaskContexts) {
@@ -118,7 +119,7 @@ public class JoinTaskExecutor implements ClusterStateTaskExecutor<JoinTask> {
                         ensureNodesCompatibility(node.getVersion(), minClusterNodeVersion, maxClusterNodeVersion);
                         // we do this validation quite late to prevent race conditions between nodes joining and importing dangling indices
                         // we have to reject nodes that don't support all indices we have in this cluster
-                        ensureIndexCompatibility(node.getVersion(), currentState.getMetadata());
+                        ensureIndexCompatibility(node.getVersion(), initialState.getMetadata());
                         nodesBuilder.add(node);
                         nodesChanged = true;
                         minClusterNodeVersion = Version.min(minClusterNodeVersion, node.getVersion());
@@ -148,7 +149,7 @@ public class JoinTaskExecutor implements ClusterStateTaskExecutor<JoinTask> {
             );
 
             if (joinedNodeIdsByNodeName.isEmpty() == false) {
-                final var currentVotingConfigExclusions = currentState.getVotingConfigExclusions();
+                final var currentVotingConfigExclusions = initialState.getVotingConfigExclusions();
                 final var newVotingConfigExclusions = currentVotingConfigExclusions.stream().map(e -> {
                     // Update nodeId in VotingConfigExclusion when a new node with excluded node name joins
                     if (CoordinationMetadata.VotingConfigExclusion.MISSING_VALUE_MARKER.equals(e.getNodeId())
@@ -164,11 +165,11 @@ public class JoinTaskExecutor implements ClusterStateTaskExecutor<JoinTask> {
 
                 // if VotingConfigExclusions did get updated
                 if (newVotingConfigExclusions.equals(currentVotingConfigExclusions) == false) {
-                    final var coordMetadataBuilder = CoordinationMetadata.builder(currentState.coordinationMetadata())
+                    final var coordMetadataBuilder = CoordinationMetadata.builder(initialState.coordinationMetadata())
                         .term(term)
                         .clearVotingConfigExclusions();
                     newVotingConfigExclusions.forEach(coordMetadataBuilder::addVotingConfigExclusion);
-                    newState.metadata(Metadata.builder(currentState.metadata()).coordinationMetadata(coordMetadataBuilder.build()).build());
+                    newState.metadata(Metadata.builder(initialState.metadata()).coordinationMetadata(coordMetadataBuilder.build()).build());
                 }
             }
 
@@ -177,9 +178,9 @@ public class JoinTaskExecutor implements ClusterStateTaskExecutor<JoinTask> {
             );
             final ClusterState updatedState = allocationService.adaptAutoExpandReplicas(clusterStateWithNewNodesAndDesiredNodes);
             assert enforceVersionBarrier == false
-                || updatedState.nodes().getMinNodeVersion().onOrAfter(currentState.nodes().getMinNodeVersion())
+                || updatedState.nodes().getMinNodeVersion().onOrAfter(initialState.nodes().getMinNodeVersion())
                 : "min node version decreased from ["
-                    + currentState.nodes().getMinNodeVersion()
+                    + initialState.nodes().getMinNodeVersion()
                     + "] to ["
                     + updatedState.nodes().getMinNodeVersion()
                     + "]";

+ 9 - 10
server/src/main/java/org/elasticsearch/cluster/coordination/NodeRemovalClusterStateTaskExecutor.java

@@ -19,8 +19,6 @@ import org.elasticsearch.cluster.routing.allocation.AllocationService;
 import org.elasticsearch.cluster.service.MasterService;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 
-import java.util.List;
-
 public class NodeRemovalClusterStateTaskExecutor implements ClusterStateTaskExecutor<NodeRemovalClusterStateTaskExecutor.Task> {
 
     private static final Logger logger = LogManager.getLogger(NodeRemovalClusterStateTaskExecutor.class);
@@ -53,33 +51,34 @@ public class NodeRemovalClusterStateTaskExecutor implements ClusterStateTaskExec
     }
 
     @Override
-    public ClusterState execute(ClusterState currentState, List<TaskContext<Task>> taskContexts) throws Exception {
-        final DiscoveryNodes.Builder remainingNodesBuilder = DiscoveryNodes.builder(currentState.nodes());
+    public ClusterState execute(BatchExecutionContext<Task> batchExecutionContext) throws Exception {
+        final ClusterState initialState = batchExecutionContext.initialState();
+        final DiscoveryNodes.Builder remainingNodesBuilder = DiscoveryNodes.builder(initialState.nodes());
         boolean removed = false;
-        for (final var taskContext : taskContexts) {
+        for (final var taskContext : batchExecutionContext.taskContexts()) {
             final var task = taskContext.getTask();
-            if (currentState.nodes().nodeExists(task.node())) {
+            if (initialState.nodes().nodeExists(task.node())) {
                 remainingNodesBuilder.remove(task.node());
                 removed = true;
             } else {
                 logger.debug("node [{}] does not exist in cluster state, ignoring", task);
             }
-            taskContext.success(() -> task.onClusterStateProcessed.run());
+            taskContext.success(task.onClusterStateProcessed::run);
         }
 
         final ClusterState finalState;
 
         if (removed) {
-            final ClusterState remainingNodesClusterState = remainingNodesClusterState(currentState, remainingNodesBuilder);
+            final ClusterState remainingNodesClusterState = remainingNodesClusterState(initialState, remainingNodesBuilder);
             final ClusterState ptasksDisassociatedState = PersistentTasksCustomMetadata.disassociateDeadNodes(remainingNodesClusterState);
             finalState = allocationService.disassociateDeadNodes(
                 ptasksDisassociatedState,
                 true,
-                describeTasks(taskContexts.stream().map(TaskContext::getTask).toList())
+                describeTasks(batchExecutionContext.taskContexts().stream().map(TaskContext::getTask).toList())
             );
         } else {
             // no nodes to remove, keep the current cluster state
-            finalState = currentState;
+            finalState = initialState;
         }
 
         return finalState;

+ 18 - 18
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java

@@ -172,9 +172,9 @@ public class MetadataIndexStateService {
     private class AddBlocksToCloseExecutor implements ClusterStateTaskExecutor<AddBlocksToCloseTask> {
 
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<AddBlocksToCloseTask>> taskContexts) throws Exception {
-            ClusterState state = currentState;
-            for (final var taskContext : taskContexts) {
+        public ClusterState execute(BatchExecutionContext<AddBlocksToCloseTask> batchExecutionContext) throws Exception {
+            ClusterState state = batchExecutionContext.initialState();
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 final var task = taskContext.getTask();
                 try {
                     final Map<Index, ClusterBlock> blockedIndices = new HashMap<>(task.request.indices().length);
@@ -227,9 +227,9 @@ public class MetadataIndexStateService {
 
         @Override
         @SuppressForbidden(reason = "consuming published cluster state for legacy reasons")
-        public ClusterState execute(ClusterState currentState, List<TaskContext<CloseIndicesTask>> taskContexts) throws Exception {
-            ClusterState state = currentState;
-            for (final var taskContext : taskContexts) {
+        public ClusterState execute(BatchExecutionContext<CloseIndicesTask> batchExecutionContext) throws Exception {
+            ClusterState state = batchExecutionContext.initialState();
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 final var task = taskContext.getTask();
                 try {
                     final Tuple<ClusterState, List<IndexResult>> closingResult = closeRoutingTable(
@@ -489,10 +489,10 @@ public class MetadataIndexStateService {
     private class AddBlocksExecutor implements ClusterStateTaskExecutor<AddBlocksTask> {
 
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<AddBlocksTask>> taskContexts) throws Exception {
-            ClusterState state = currentState;
+        public ClusterState execute(BatchExecutionContext<AddBlocksTask> batchExecutionContext) throws Exception {
+            ClusterState state = batchExecutionContext.initialState();
 
-            for (final var taskContext : taskContexts) {
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 try {
                     final var task = taskContext.getTask();
                     final Tuple<ClusterState, Map<Index, ClusterBlock>> blockResult = addIndexBlock(
@@ -554,10 +554,10 @@ public class MetadataIndexStateService {
     private static class FinalizeBlocksExecutor implements ClusterStateTaskExecutor<FinalizeBlocksTask> {
 
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<FinalizeBlocksTask>> taskContexts) throws Exception {
-            ClusterState state = currentState;
+        public ClusterState execute(BatchExecutionContext<FinalizeBlocksTask> batchExecutionContext) throws Exception {
+            ClusterState state = batchExecutionContext.initialState();
 
-            for (final var taskContext : taskContexts) {
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 try {
                     final var task = taskContext.getTask();
                     final Tuple<ClusterState, List<AddBlockResult>> finalizeResult = finalizeBlock(
@@ -1100,13 +1100,13 @@ public class MetadataIndexStateService {
     private class OpenIndicesExecutor implements ClusterStateTaskExecutor<OpenIndicesTask> {
 
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<OpenIndicesTask>> taskContexts) {
-            ClusterState state = currentState;
+        public ClusterState execute(BatchExecutionContext<OpenIndicesTask> batchExecutionContext) {
+            ClusterState state = batchExecutionContext.initialState();
 
             try {
                 // build an in-order de-duplicated array of all the indices to open
-                final Set<Index> indicesToOpen = Sets.newLinkedHashSetWithExpectedSize(taskContexts.size());
-                for (final var taskContext : taskContexts) {
+                final Set<Index> indicesToOpen = Sets.newLinkedHashSetWithExpectedSize(batchExecutionContext.taskContexts().size());
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     Collections.addAll(indicesToOpen, taskContext.getTask().request.indices());
                 }
                 Index[] indices = indicesToOpen.toArray(Index.EMPTY_ARRAY);
@@ -1117,12 +1117,12 @@ public class MetadataIndexStateService {
                 // do a final reroute
                 state = allocationService.reroute(state, "indices opened");
 
-                for (final var taskContext : taskContexts) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     final var task = taskContext.getTask();
                     taskContext.success(task);
                 }
             } catch (Exception e) {
-                for (final var taskContext : taskContexts) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.onFailure(e);
                 }
             }

+ 3 - 3
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java

@@ -125,9 +125,9 @@ public class MetadataIndexTemplateService {
     /**
      * This is the cluster state task executor for all template-based actions.
      */
-    private static final ClusterStateTaskExecutor<TemplateClusterStateUpdateTask> TEMPLATE_TASK_EXECUTOR = (currentState, taskContexts) -> {
-        ClusterState state = currentState;
-        for (final var taskContext : taskContexts) {
+    private static final ClusterStateTaskExecutor<TemplateClusterStateUpdateTask> TEMPLATE_TASK_EXECUTOR = batchExecutionContext -> {
+        ClusterState state = batchExecutionContext.initialState();
+        for (final var taskContext : batchExecutionContext.taskContexts()) {
             try {
                 final var task = taskContext.getTask();
                 state = task.execute(state);

+ 3 - 3
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java

@@ -94,11 +94,11 @@ public class MetadataMappingService {
 
     class PutMappingExecutor implements ClusterStateTaskExecutor<PutMappingClusterStateUpdateTask> {
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<PutMappingClusterStateUpdateTask>> taskContexts)
-            throws Exception {
+        public ClusterState execute(BatchExecutionContext<PutMappingClusterStateUpdateTask> batchExecutionContext) throws Exception {
             Map<Index, MapperService> indexMapperServices = new HashMap<>();
             try {
-                for (final var taskContext : taskContexts) {
+                var currentState = batchExecutionContext.initialState();
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     final var task = taskContext.getTask();
                     final PutMappingClusterStateUpdateRequest request = task.request;
                     try {

+ 4 - 4
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataUpdateSettingsService.java

@@ -70,9 +70,9 @@ public class MetadataUpdateSettingsService {
         this.indexScopedSettings = indexScopedSettings;
         this.indicesService = indicesService;
         this.shardLimitValidator = shardLimitValidator;
-        this.executor = (currentState, taskContexts) -> {
-            ClusterState state = currentState;
-            for (final var taskContext : taskContexts) {
+        this.executor = batchExecutionContext -> {
+            ClusterState state = batchExecutionContext.initialState();
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 try {
                     final var task = taskContext.getTask();
                     state = task.execute(state);
@@ -81,7 +81,7 @@ public class MetadataUpdateSettingsService {
                     taskContext.onFailure(e);
                 }
             }
-            if (state != currentState) {
+            if (state != batchExecutionContext.initialState()) {
                 // reroute in case things change that require it (like number of replicas)
                 state = allocationService.reroute(state, "settings update");
             }

+ 10 - 6
server/src/main/java/org/elasticsearch/cluster/service/MasterService.java

@@ -524,12 +524,16 @@ public class MasterService extends AbstractLifecycleComponent {
     private static class UnbatchedExecutor implements ClusterStateTaskExecutor<ClusterStateUpdateTask> {
         @Override
         @SuppressForbidden(reason = "consuming published cluster state for legacy reasons")
-        public ClusterState execute(ClusterState currentState, List<TaskContext<ClusterStateUpdateTask>> taskContexts) throws Exception {
-            assert taskContexts.size() == 1 : "this only supports a single task but received " + taskContexts;
-            final var taskContext = taskContexts.get(0);
+        public ClusterState execute(BatchExecutionContext<ClusterStateUpdateTask> batchExecutionContext) throws Exception {
+            assert batchExecutionContext.taskContexts().size() == 1
+                : "this only supports a single task but received " + batchExecutionContext.taskContexts();
+            final var taskContext = batchExecutionContext.taskContexts().get(0);
             final var task = taskContext.getTask();
-            final var newState = task.execute(currentState);
-            final Consumer<ClusterState> publishListener = publishedState -> task.clusterStateProcessed(currentState, publishedState);
+            final var newState = task.execute(batchExecutionContext.initialState());
+            final Consumer<ClusterState> publishListener = publishedState -> task.clusterStateProcessed(
+                batchExecutionContext.initialState(),
+                publishedState
+            );
             if (task instanceof ClusterStateAckListener ackListener) {
                 taskContext.success(publishListener, ackListener);
             } else {
@@ -972,7 +976,7 @@ public class MasterService extends AbstractLifecycleComponent {
     ) {
         final var taskContexts = castTaskContexts(executionResults);
         try {
-            return executor.execute(previousClusterState, taskContexts);
+            return executor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(previousClusterState, taskContexts));
         } catch (Exception e) {
             logger.trace(
                 () -> format(

+ 3 - 4
server/src/main/java/org/elasticsearch/health/metadata/HealthMetadataService.java

@@ -170,10 +170,9 @@ public class HealthMetadataService {
         static class Executor implements ClusterStateTaskExecutor<UpsertHealthMetadataTask> {
 
             @Override
-            public ClusterState execute(ClusterState currentState, List<TaskContext<UpsertHealthMetadataTask>> taskContexts)
-                throws Exception {
-                ClusterState updatedState = currentState;
-                for (TaskContext<UpsertHealthMetadataTask> taskContext : taskContexts) {
+            public ClusterState execute(BatchExecutionContext<UpsertHealthMetadataTask> batchExecutionContext) throws Exception {
+                ClusterState updatedState = batchExecutionContext.initialState();
+                for (TaskContext<UpsertHealthMetadataTask> taskContext : batchExecutionContext.taskContexts()) {
                     updatedState = taskContext.getTask().execute(updatedState);
                     taskContext.success(() -> {});
                 }

+ 6 - 6
server/src/main/java/org/elasticsearch/ingest/IngestService.java

@@ -108,11 +108,11 @@ public class IngestService implements ClusterStateApplier, ReportingService<Inge
     /**
      * Cluster state task executor for ingest pipeline operations
      */
-    static final ClusterStateTaskExecutor<PipelineClusterStateUpdateTask> PIPELINE_TASK_EXECUTOR = (currentState, taskContexts) -> {
-        final var allIndexMetadata = currentState.metadata().indices().values();
-        final IngestMetadata initialIngestMetadata = currentState.metadata().custom(IngestMetadata.TYPE);
+    static final ClusterStateTaskExecutor<PipelineClusterStateUpdateTask> PIPELINE_TASK_EXECUTOR = batchExecutionContext -> {
+        final var allIndexMetadata = batchExecutionContext.initialState().metadata().indices().values();
+        final IngestMetadata initialIngestMetadata = batchExecutionContext.initialState().metadata().custom(IngestMetadata.TYPE);
         var currentIngestMetadata = initialIngestMetadata;
-        for (final var taskContext : taskContexts) {
+        for (final var taskContext : batchExecutionContext.taskContexts()) {
             try {
                 final var task = taskContext.getTask();
                 currentIngestMetadata = task.execute(currentIngestMetadata, allIndexMetadata);
@@ -123,8 +123,8 @@ public class IngestService implements ClusterStateApplier, ReportingService<Inge
         }
         final var finalIngestMetadata = currentIngestMetadata;
         return finalIngestMetadata == initialIngestMetadata
-            ? currentState
-            : currentState.copyAndUpdateMetadata(b -> b.putCustom(IngestMetadata.TYPE, finalIngestMetadata));
+            ? batchExecutionContext.initialState()
+            : batchExecutionContext.initialState().copyAndUpdateMetadata(b -> b.putCustom(IngestMetadata.TYPE, finalIngestMetadata));
     };
 
     /**

+ 5 - 6
server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateErrorTaskExecutor.java

@@ -14,8 +14,6 @@ import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 
-import java.util.List;
-
 /**
  * Reserved cluster error state task executor
  * <p>
@@ -25,13 +23,14 @@ record ReservedStateErrorTaskExecutor() implements ClusterStateTaskExecutor<Rese
     private static final Logger logger = LogManager.getLogger(ReservedStateErrorTaskExecutor.class);
 
     @Override
-    public ClusterState execute(ClusterState currentState, List<TaskContext<ReservedStateErrorTask>> taskContexts) {
-        for (final var taskContext : taskContexts) {
+    public ClusterState execute(BatchExecutionContext<ReservedStateErrorTask> batchExecutionContext) {
+        var updatedState = batchExecutionContext.initialState();
+        for (final var taskContext : batchExecutionContext.taskContexts()) {
             final var task = taskContext.getTask();
-            currentState = task.execute(currentState);
+            updatedState = task.execute(updatedState);
             taskContext.success(() -> task.listener().onResponse(ActionResponse.Empty.INSTANCE));
         }
-        return currentState;
+        return updatedState;
     }
 
     @Override

+ 5 - 6
server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateUpdateTaskExecutor.java

@@ -17,8 +17,6 @@ import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.routing.RerouteService;
 import org.elasticsearch.common.Priority;
 
-import java.util.List;
-
 /**
  * Reserved cluster state update task executor
  *
@@ -29,12 +27,13 @@ public record ReservedStateUpdateTaskExecutor(RerouteService rerouteService) imp
     private static final Logger logger = LogManager.getLogger(ReservedStateUpdateTaskExecutor.class);
 
     @Override
-    public ClusterState execute(ClusterState currentState, List<TaskContext<ReservedStateUpdateTask>> taskContexts) throws Exception {
-        for (final var taskContext : taskContexts) {
-            currentState = taskContext.getTask().execute(currentState);
+    public ClusterState execute(BatchExecutionContext<ReservedStateUpdateTask> batchExecutionContext) throws Exception {
+        var updatedState = batchExecutionContext.initialState();
+        for (final var taskContext : batchExecutionContext.taskContexts()) {
+            updatedState = taskContext.getTask().execute(updatedState);
             taskContext.success(() -> taskContext.getTask().listener().onResponse(ActionResponse.Empty.INSTANCE));
         }
-        return currentState;
+        return updatedState;
     }
 
     @Override

+ 20 - 24
server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java

@@ -2625,7 +2625,7 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
         /**
          * Computes an updated {@link SnapshotsInProgress} that takes into account an updated version of
          * {@link SnapshotDeletionsInProgress} that has a {@link SnapshotDeletionsInProgress.Entry} removed from it
-         * relative to the {@link SnapshotDeletionsInProgress} found in {@code currentState}.
+         * relative to the {@link SnapshotDeletionsInProgress} found in {@code initialState}.
          * The removal of a delete from the cluster state can trigger two possible actions on in-progress snapshots:
          * <ul>
          *     <li>Snapshots that had unfinished shard snapshots in state {@link ShardSnapshotStatus#UNASSIGNED_QUEUED} that
@@ -3037,9 +3037,8 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
      *
      * Package private to allow for tests.
      */
-    static final ClusterStateTaskExecutor<ShardSnapshotUpdate> SHARD_STATE_EXECUTOR = (
-        currentState,
-        taskContexts) -> new SnapshotShardsUpdateContext(currentState, taskContexts).computeUpdatedState();
+    static final ClusterStateTaskExecutor<ShardSnapshotUpdate> SHARD_STATE_EXECUTOR =
+        batchExecutionContext -> new SnapshotShardsUpdateContext(batchExecutionContext).computeUpdatedState();
 
     private static boolean isQueued(@Nullable ShardSnapshotStatus status) {
         return status != null && status.state() == ShardState.QUEUED;
@@ -3057,11 +3056,11 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
         // number of started tasks as a result of applying updates to the snapshot entries seen so far
         private int startedCount = 0;
 
-        // current cluster state
-        private final ClusterState currentState;
+        // batch execution context
+        private final ClusterStateTaskExecutor.BatchExecutionContext<ShardSnapshotUpdate> batchExecutionContext;
 
-        // task contexts to be completed on success
-        private final List<ClusterStateTaskExecutor.TaskContext<ShardSnapshotUpdate>> taskContexts;
+        // initial cluster state for update computation
+        private final ClusterState initialState;
 
         // updates outstanding to be applied to existing snapshot entries
         private final Map<String, List<ShardSnapshotUpdate>> updatesByRepo;
@@ -3069,21 +3068,18 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
         // updates that were used to update an existing in-progress shard snapshot
         private final Set<ShardSnapshotUpdate> executedUpdates = new HashSet<>();
 
-        SnapshotShardsUpdateContext(
-            ClusterState currentState,
-            List<ClusterStateTaskExecutor.TaskContext<ShardSnapshotUpdate>> taskContexts
-        ) {
-            this.currentState = currentState;
-            this.taskContexts = taskContexts;
-            updatesByRepo = new HashMap<>();
-            for (final var taskContext : taskContexts) {
+        SnapshotShardsUpdateContext(ClusterStateTaskExecutor.BatchExecutionContext<ShardSnapshotUpdate> batchExecutionContext) {
+            this.batchExecutionContext = batchExecutionContext;
+            this.initialState = batchExecutionContext.initialState();
+            this.updatesByRepo = new HashMap<>();
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 updatesByRepo.computeIfAbsent(taskContext.getTask().snapshot.getRepository(), r -> new ArrayList<>())
                     .add(taskContext.getTask());
             }
         }
 
         ClusterState computeUpdatedState() {
-            final SnapshotsInProgress existing = currentState.custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY);
+            final SnapshotsInProgress existing = initialState.custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY);
             SnapshotsInProgress updated = existing;
             for (Map.Entry<String, List<ShardSnapshotUpdate>> updates : updatesByRepo.entrySet()) {
                 final String repoName = updates.getKey();
@@ -3098,8 +3094,8 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
                 updated = updated.withUpdatedEntriesForRepo(repoName, newEntries);
             }
 
-            final var result = new ShardSnapshotUpdateResult(currentState.metadata(), updated);
-            for (final var taskContext : taskContexts) {
+            final var result = new ShardSnapshotUpdateResult(initialState.metadata(), updated);
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 taskContext.success(() -> taskContext.getTask().listener.onResponse(result));
             }
 
@@ -3109,10 +3105,10 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
                     changedCount,
                     startedCount
                 );
-                return ClusterState.builder(currentState).putCustom(SnapshotsInProgress.TYPE, updated).build();
+                return ClusterState.builder(initialState).putCustom(SnapshotsInProgress.TYPE, updated).build();
             }
             assert existing == updated;
-            return currentState;
+            return initialState;
         }
 
         private SnapshotsInProgress.Entry applyToEntry(SnapshotsInProgress.Entry entry, List<ShardSnapshotUpdate> updates) {
@@ -3260,7 +3256,7 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
                 if (entry.isClone() == false) {
                     tryStartSnapshotAfterCloneFinish(repoShardId, updatedState.generation());
                 } else if (isQueued(entry.shardsByRepoShardId().get(repoShardId))) {
-                    final String localNodeId = currentState.nodes().getLocalNodeId();
+                    final String localNodeId = initialState.nodes().getLocalNodeId();
                     assert updatedState.nodeId().equals(localNodeId)
                         : "Clone updated with node id [" + updatedState.nodeId() + "] but local node id is [" + localNodeId + "]";
                     startShardOperation(clonesBuilder(), localNodeId, updatedState.generation(), repoShardId);
@@ -3278,7 +3274,7 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
                             // shard snapshot was completed, we check if we can start a clone operation for the same repo shard
                             startShardOperation(
                                 clonesBuilder(),
-                                currentState.nodes().getLocalNodeId(),
+                                initialState.nodes().getLocalNodeId(),
                                 updatedState.generation(),
                                 repoShardId
                             );
@@ -3307,7 +3303,7 @@ public class SnapshotsService extends AbstractLifecycleComponent implements Clus
                         + "] because it's a normal snapshot but did not";
                 // work out the node to run the snapshot task on as it might have changed from the previous operation if it was a clone
                 // or there was a primary failover
-                final IndexRoutingTable indexRouting = currentState.routingTable().index(index);
+                final IndexRoutingTable indexRouting = initialState.routingTable().index(index);
                 final ShardRouting shardRouting;
                 if (indexRouting == null) {
                     shardRouting = null;

+ 1 - 3
server/src/test/java/org/elasticsearch/cluster/ClusterStateTaskExecutorTests.java

@@ -34,9 +34,7 @@ public class ClusterStateTaskExecutorTests extends ESTestCase {
     }
 
     public void testDescribeTasks() {
-        final ClusterStateTaskExecutor<TestTask> executor = (currentState, taskContexts) -> {
-            throw new AssertionError("should not be called");
-        };
+        final ClusterStateTaskExecutor<TestTask> executor = batchExecutionContext -> { throw new AssertionError("should not be called"); };
 
         assertThat("describes an empty list", executor.describeTasks(List.of()), equalTo(""));
         assertThat("describes a singleton list", executor.describeTasks(List.of(new TestTask("a task"))), equalTo("Task{a task}"));

+ 3 - 3
server/src/test/java/org/elasticsearch/cluster/metadata/MetadataIndexStateServiceBatchingTests.java

@@ -203,13 +203,13 @@ public class MetadataIndexStateServiceBatchingTests extends ESSingleNodeTestCase
             "block",
             new ExpectSuccessTask(),
             ClusterStateTaskConfig.build(Priority.URGENT),
-            (currentState, taskContexts) -> {
+            batchExecutionContext -> {
                 executionBarrier.await(10, TimeUnit.SECONDS); // notify test thread that the master service is blocked
                 executionBarrier.await(10, TimeUnit.SECONDS); // wait for test thread to release us
-                for (final var taskContext : taskContexts) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.success(() -> {});
                 }
-                return currentState;
+                return batchExecutionContext.initialState();
             }
         );
 

+ 47 - 41
server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java

@@ -222,11 +222,11 @@ public class MasterServiceTests extends ESTestCase {
                 ClusterStateTaskConfig.build(Priority.NORMAL),
                 new ClusterStateTaskExecutor<>() {
                     @Override
-                    public ClusterState execute(ClusterState currentState, List<TaskContext<ExpectSuccessTask>> taskContexts) {
-                        for (final var taskContext : taskContexts) {
+                    public ClusterState execute(BatchExecutionContext<ExpectSuccessTask> batchExecutionContext) {
+                        for (final var taskContext : batchExecutionContext.taskContexts()) {
                             taskContext.success(() -> {});
                         }
-                        return ClusterState.builder(currentState).build();
+                        return ClusterState.builder(batchExecutionContext.initialState()).build();
                     }
 
                     @Override
@@ -341,11 +341,11 @@ public class MasterServiceTests extends ESTestCase {
                 ClusterStateTaskConfig.build(Priority.NORMAL),
                 new ClusterStateTaskExecutor<>() {
                     @Override
-                    public ClusterState execute(ClusterState currentState, List<TaskContext<ExpectSuccessTask>> taskContexts) {
-                        for (final var taskContext : taskContexts) {
+                    public ClusterState execute(BatchExecutionContext<ExpectSuccessTask> batchExecutionContext) {
+                        for (final var taskContext : batchExecutionContext.taskContexts()) {
                             taskContext.success(() -> { throw new RuntimeException("testing exception handling"); });
                         }
-                        return ClusterState.builder(currentState).build();
+                        return ClusterState.builder(batchExecutionContext.initialState()).build();
                     }
 
                     @Override
@@ -542,14 +542,14 @@ public class MasterServiceTests extends ESTestCase {
             }
 
             @Override
-            public ClusterState execute(ClusterState currentState, List<TaskContext<ExpectSuccessTask>> taskContexts) {
+            public ClusterState execute(BatchExecutionContext<ExpectSuccessTask> batchExecutionContext) {
                 assertTrue("Should execute all tasks at once", executed.compareAndSet(false, true));
-                assertThat("Should execute all tasks at once", taskContexts.size(), equalTo(expectedTaskCount));
+                assertThat("Should execute all tasks at once", batchExecutionContext.taskContexts().size(), equalTo(expectedTaskCount));
                 executionCountDown.countDown();
-                for (final var taskContext : taskContexts) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.success(() -> {});
                 }
-                return currentState;
+                return batchExecutionContext.initialState();
             }
         }
 
@@ -566,13 +566,13 @@ public class MasterServiceTests extends ESTestCase {
                 "block",
                 new ExpectSuccessTask(),
                 ClusterStateTaskConfig.build(Priority.NORMAL),
-                (currentState, taskContexts) -> {
+                batchExecutionContext -> {
                     executionBarrier.await(10, TimeUnit.SECONDS); // notify test thread that the master service is blocked
                     executionBarrier.await(10, TimeUnit.SECONDS); // wait for test thread to release us
-                    for (final var taskContext : taskContexts) {
+                    for (final var taskContext : batchExecutionContext.taskContexts()) {
                         taskContext.success(() -> {});
                     }
-                    return currentState;
+                    return batchExecutionContext.initialState();
                 }
             );
 
@@ -707,19 +707,19 @@ public class MasterServiceTests extends ESTestCase {
             private final List<Task> assignments = new ArrayList<>();
 
             @Override
-            public ClusterState execute(ClusterState currentState, List<TaskContext<Task>> taskContexts) {
-                for (final var taskContext : taskContexts) {
+            public ClusterState execute(BatchExecutionContext<Task> batchExecutionContext) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     assertThat("All tasks should belong to this executor", assignments, hasItem(taskContext.getTask()));
                 }
 
-                for (final var taskContext : taskContexts) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.getTask().execute();
                 }
 
-                executed.addAndGet(taskContexts.size());
-                ClusterState maybeUpdatedClusterState = currentState;
+                executed.addAndGet(batchExecutionContext.taskContexts().size());
+                ClusterState maybeUpdatedClusterState = batchExecutionContext.initialState();
                 if (randomBoolean()) {
-                    maybeUpdatedClusterState = ClusterState.builder(currentState).build();
+                    maybeUpdatedClusterState = ClusterState.builder(batchExecutionContext.initialState()).build();
                     batches.incrementAndGet();
                     assertThat(
                         "All cluster state modifications should be executed on a single thread",
@@ -728,7 +728,7 @@ public class MasterServiceTests extends ESTestCase {
                     );
                 }
 
-                for (final var taskContext : taskContexts) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.success(() -> {
                         processedStates.incrementAndGet();
                         processedStatesLatch.get().countDown();
@@ -845,14 +845,14 @@ public class MasterServiceTests extends ESTestCase {
             }
         }
 
-        final ClusterStateTaskExecutor<Task> executor = (currentState, taskContexts) -> {
+        final ClusterStateTaskExecutor<Task> executor = batchExecutionContext -> {
             if (randomBoolean()) {
                 throw new RuntimeException("simulated");
             } else {
-                for (final var taskContext : taskContexts) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.onFailure(new RuntimeException("simulated"));
                 }
-                return currentState;
+                return batchExecutionContext.initialState();
             }
         };
 
@@ -924,11 +924,11 @@ public class MasterServiceTests extends ESTestCase {
         final var executor = new ClusterStateTaskExecutor<Task>() {
             @Override
             @SuppressForbidden(reason = "consuming published cluster state for legacy reasons")
-            public ClusterState execute(ClusterState currentState, List<TaskContext<Task>> taskContexts) {
-                for (final var taskContext : taskContexts) {
+            public ClusterState execute(BatchExecutionContext<Task> batchExecutionContext) {
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.success(taskContext.getTask().publishListener::onResponse);
                 }
-                return ClusterState.builder(currentState).build();
+                return ClusterState.builder(batchExecutionContext.initialState()).build();
             }
         };
 
@@ -1043,8 +1043,8 @@ public class MasterServiceTests extends ESTestCase {
                 "testBlockingCallInClusterStateTaskListenerFails",
                 new ExpectSuccessTask(),
                 ClusterStateTaskConfig.build(Priority.NORMAL),
-                (currentState, taskContexts) -> {
-                    for (final var taskContext : taskContexts) {
+                batchExecutionContext -> {
+                    for (final var taskContext : batchExecutionContext.taskContexts()) {
                         taskContext.success(() -> {
                             BaseFuture<Void> future = new BaseFuture<Void>() {
                             };
@@ -1062,7 +1062,7 @@ public class MasterServiceTests extends ESTestCase {
                             }
                         });
                     }
-                    return ClusterState.builder(currentState).build();
+                    return ClusterState.builder(batchExecutionContext.initialState()).build();
                 }
             );
 
@@ -1399,11 +1399,13 @@ public class MasterServiceTests extends ESTestCase {
                     "success-test",
                     new Task(),
                     ClusterStateTaskConfig.build(Priority.NORMAL),
-                    (currentState, taskContexts) -> {
-                        for (final var taskContext : taskContexts) {
+                    batchExecutionContext -> {
+                        for (final var taskContext : batchExecutionContext.taskContexts()) {
                             taskContext.success(latch::countDown, taskContext.getTask());
                         }
-                        return randomBoolean() ? currentState : ClusterState.builder(currentState).build();
+                        return randomBoolean()
+                            ? batchExecutionContext.initialState()
+                            : ClusterState.builder(batchExecutionContext.initialState()).build();
                     }
                 );
 
@@ -1439,11 +1441,13 @@ public class MasterServiceTests extends ESTestCase {
                     "success-test",
                     new Task(),
                     ClusterStateTaskConfig.build(Priority.NORMAL),
-                    (currentState, taskContexts) -> {
-                        for (final var taskContext : taskContexts) {
+                    batchExecutionContext -> {
+                        for (final var taskContext : batchExecutionContext.taskContexts()) {
                             taskContext.success(latch::countDown, new LatchAckListener(latch));
                         }
-                        return randomBoolean() ? currentState : ClusterState.builder(currentState).build();
+                        return randomBoolean()
+                            ? batchExecutionContext.initialState()
+                            : ClusterState.builder(batchExecutionContext.initialState()).build();
                     }
                 );
 
@@ -1479,11 +1483,13 @@ public class MasterServiceTests extends ESTestCase {
                     "success-test",
                     new Task(),
                     ClusterStateTaskConfig.build(Priority.NORMAL),
-                    (currentState, taskContexts) -> {
-                        for (final var taskContext : taskContexts) {
+                    batchExecutionContext -> {
+                        for (final var taskContext : batchExecutionContext.taskContexts()) {
                             taskContext.success(new LatchAckListener(latch));
                         }
-                        return randomBoolean() ? currentState : ClusterState.builder(currentState).build();
+                        return randomBoolean()
+                            ? batchExecutionContext.initialState()
+                            : ClusterState.builder(batchExecutionContext.initialState()).build();
                     }
                 );
 
@@ -1755,11 +1761,11 @@ public class MasterServiceTests extends ESTestCase {
                 final Semaphore semaphore = new Semaphore(0);
 
                 @Override
-                public ClusterState execute(ClusterState currentState, List<TaskContext<Task>> taskContexts) {
-                    for (final var taskContext : taskContexts) {
+                public ClusterState execute(BatchExecutionContext<Task> batchExecutionContext) {
+                    for (final var taskContext : batchExecutionContext.taskContexts()) {
                         taskContext.success(() -> semaphore.release());
                     }
-                    return currentState;
+                    return batchExecutionContext.initialState();
                 }
             }
 

+ 2 - 2
server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java

@@ -179,7 +179,7 @@ public class ReservedClusterStateServiceTests extends ESTestCase {
             public void onFailure(Exception failure) {}
         };
 
-        ClusterState newState = taskExecutor.execute(state, List.of(taskContext));
+        ClusterState newState = taskExecutor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(taskContext)));
         assertEquals(state, newState);
         assertTrue(successCalled.get());
         verify(task, times(1)).execute(any());
@@ -235,7 +235,7 @@ public class ReservedClusterStateServiceTests extends ESTestCase {
 
         ReservedStateErrorTaskExecutor executor = new ReservedStateErrorTaskExecutor();
 
-        ClusterState newState = executor.execute(state, List.of(taskContext));
+        ClusterState newState = executor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(taskContext)));
 
         verify(task, times(1)).execute(any());
 

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/cluster/service/ClusterStateTaskExecutorUtils.java

@@ -64,7 +64,7 @@ public class ClusterStateTaskExecutorUtils {
         final var taskContexts = StreamSupport.stream(tasks.spliterator(), false).<ClusterStateTaskExecutor.TaskContext<T>>map(
             TestTaskContext::new
         ).toList();
-        final var resultingState = executor.execute(originalState, taskContexts);
+        final var resultingState = executor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(originalState, taskContexts));
         assertNotNull(resultingState);
         for (final var taskContext : taskContexts) {
             final var testTaskContext = (TestTaskContext<T>) taskContext;

+ 9 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartBasicClusterTask.java

@@ -18,7 +18,6 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.core.XPackPlugin;
 
 import java.time.Clock;
-import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 
@@ -128,18 +127,19 @@ public class StartBasicClusterTask implements ClusterStateTaskListener {
 
     static class Executor implements ClusterStateTaskExecutor<StartBasicClusterTask> {
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<StartBasicClusterTask>> taskContexts) throws Exception {
-            XPackPlugin.checkReadyForXPackCustomMetadata(currentState);
-            final LicensesMetadata originalLicensesMetadata = currentState.metadata().custom(LicensesMetadata.TYPE);
+        public ClusterState execute(BatchExecutionContext<StartBasicClusterTask> batchExecutionContext) throws Exception {
+            final var initialState = batchExecutionContext.initialState();
+            XPackPlugin.checkReadyForXPackCustomMetadata(initialState);
+            final LicensesMetadata originalLicensesMetadata = initialState.metadata().custom(LicensesMetadata.TYPE);
             var currentLicensesMetadata = originalLicensesMetadata;
-            for (final var taskContext : taskContexts) {
-                currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, currentState.nodes(), taskContext);
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
+                currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, initialState.nodes(), taskContext);
             }
             if (currentLicensesMetadata == originalLicensesMetadata) {
-                return currentState;
+                return initialState;
             } else {
-                return ClusterState.builder(currentState)
-                    .metadata(Metadata.builder(currentState.metadata()).putCustom(LicensesMetadata.TYPE, currentLicensesMetadata))
+                return ClusterState.builder(initialState)
+                    .metadata(Metadata.builder(initialState.metadata()).putCustom(LicensesMetadata.TYPE, currentLicensesMetadata))
                     .build();
             }
         }

+ 9 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartTrialClusterTask.java

@@ -19,7 +19,6 @@ import org.elasticsearch.xpack.core.XPackPlugin;
 
 import java.time.Clock;
 import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 
@@ -115,18 +114,19 @@ public class StartTrialClusterTask implements ClusterStateTaskListener {
     static class Executor implements ClusterStateTaskExecutor<StartTrialClusterTask> {
 
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<StartTrialClusterTask>> taskContexts) throws Exception {
-            XPackPlugin.checkReadyForXPackCustomMetadata(currentState);
-            final LicensesMetadata originalLicensesMetadata = currentState.metadata().custom(LicensesMetadata.TYPE);
+        public ClusterState execute(BatchExecutionContext<StartTrialClusterTask> batchExecutionContext) throws Exception {
+            final var initialState = batchExecutionContext.initialState();
+            XPackPlugin.checkReadyForXPackCustomMetadata(initialState);
+            final LicensesMetadata originalLicensesMetadata = initialState.metadata().custom(LicensesMetadata.TYPE);
             var currentLicensesMetadata = originalLicensesMetadata;
-            for (final var taskContext : taskContexts) {
-                currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, currentState.nodes(), taskContext);
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
+                currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, initialState.nodes(), taskContext);
             }
             if (currentLicensesMetadata == originalLicensesMetadata) {
-                return currentState;
+                return initialState;
             } else {
-                return ClusterState.builder(currentState)
-                    .metadata(Metadata.builder(currentState.metadata()).putCustom(LicensesMetadata.TYPE, currentLicensesMetadata))
+                return ClusterState.builder(initialState)
+                    .metadata(Metadata.builder(initialState.metadata()).putCustom(LicensesMetadata.TYPE, currentLicensesMetadata))
                     .build();
             }
         }

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/license/LicenseServiceTests.java

@@ -211,7 +211,8 @@ public class LicenseServiceTests extends ESTestCase {
                 m -> m.putCustom(LicensesMetadata.TYPE, new LicensesMetadata(oldLicense, null))
             );
 
-            ClusterState updatedState = taskExecutorCaptor.getValue().execute(oldState, List.of(taskContext));
+            ClusterState updatedState = taskExecutorCaptor.getValue()
+                .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(oldState, List.of(taskContext)));
             // Pass updated state to listener to trigger onResponse call to wrapped `future`
             listenerCaptor.getValue().run();
             assertion.accept(future);

+ 6 - 5
x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/IndexLifecycleRunner.java

@@ -39,7 +39,6 @@ import org.elasticsearch.xpack.ilm.history.ILMHistoryStore;
 
 import java.util.Collections;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Locale;
 import java.util.Objects;
 import java.util.Set;
@@ -61,13 +60,15 @@ class IndexLifecycleRunner {
         new ClusterStateTaskExecutor<>() {
             @Override
             @SuppressForbidden(reason = "consuming published cluster state for legacy reasons")
-            public ClusterState execute(ClusterState currentState, List<TaskContext<IndexLifecycleClusterStateUpdateTask>> taskContexts) {
-                ClusterState state = currentState;
-                for (final var taskContext : taskContexts) {
+            public ClusterState execute(BatchExecutionContext<IndexLifecycleClusterStateUpdateTask> batchExecutionContext) {
+                ClusterState state = batchExecutionContext.initialState();
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
                     try {
                         final var task = taskContext.getTask();
                         state = task.execute(state);
-                        taskContext.success(new ClusterStateTaskExecutor.LegacyClusterTaskResultActionListener(task, currentState));
+                        taskContext.success(
+                            new ClusterStateTaskExecutor.LegacyClusterTaskResultActionListener(task, batchExecutionContext.initialState())
+                        );
                     } catch (Exception e) {
                         taskContext.onFailure(e);
                     }

+ 1 - 1
x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java

@@ -268,7 +268,7 @@ public class ReservedLifecycleStateServiceTests extends ESTestCase {
                 }
             };
 
-            task.execute(state, List.of(context));
+            task.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(context)));
 
             return null;
         }).when(clusterService).submitStateUpdateTask(anyString(), any(), any(), any());

+ 3 - 5
x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/v2/TransportRollupAction.java

@@ -91,11 +91,9 @@ public class TransportRollupAction extends AcknowledgedTransportMasterNodeAction
     /**
      * This is the cluster state task executor for cluster state update actions.
      */
-    private static final ClusterStateTaskExecutor<RollupClusterStateUpdateTask> STATE_UPDATE_TASK_EXECUTOR = (
-        currentState,
-        taskContexts) -> {
-        ClusterState state = currentState;
-        for (final var taskContext : taskContexts) {
+    private static final ClusterStateTaskExecutor<RollupClusterStateUpdateTask> STATE_UPDATE_TASK_EXECUTOR = batchExecutionContext -> {
+        ClusterState state = batchExecutionContext.initialState();
+        for (final var taskContext : batchExecutionContext.taskContexts()) {
             try {
                 final var task = taskContext.getTask();
                 state = task.execute(state);

+ 6 - 7
x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java

@@ -34,7 +34,6 @@ import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.shutdown.DeleteShutdownNodeAction.Request;
 
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.cluster.metadata.NodesShutdownMetadata.getShutdownsOrEmpty;
@@ -80,10 +79,10 @@ public class TransportDeleteShutdownNodeAction extends AcknowledgedTransportMast
     // package private for tests
     class DeleteShutdownNodeExecutor implements ClusterStateTaskExecutor<DeleteShutdownNodeTask> {
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<DeleteShutdownNodeTask>> taskContexts) throws Exception {
-            var shutdownMetadata = new HashMap<>(getShutdownsOrEmpty(currentState).getAllNodeMetadataMap());
+        public ClusterState execute(BatchExecutionContext<DeleteShutdownNodeTask> batchExecutionContext) throws Exception {
+            var shutdownMetadata = new HashMap<>(getShutdownsOrEmpty(batchExecutionContext.initialState()).getAllNodeMetadataMap());
             boolean changed = false;
-            for (final var taskContext : taskContexts) {
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 var request = taskContext.getTask().request();
                 try {
                     changed |= deleteShutdownNodeState(shutdownMetadata, request);
@@ -95,11 +94,11 @@ public class TransportDeleteShutdownNodeAction extends AcknowledgedTransportMast
                 taskContext.success(() -> ackAndReroute(request, taskContext.getTask().listener(), reroute));
             }
             if (changed == false) {
-                return currentState;
+                return batchExecutionContext.initialState();
             }
-            return ClusterState.builder(currentState)
+            return ClusterState.builder(batchExecutionContext.initialState())
                 .metadata(
-                    Metadata.builder(currentState.metadata())
+                    Metadata.builder(batchExecutionContext.initialState().metadata())
                         .putCustom(NodesShutdownMetadata.TYPE, new NodesShutdownMetadata(shutdownMetadata))
                 )
                 .build();

+ 8 - 8
x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java

@@ -33,7 +33,6 @@ import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.shutdown.PutShutdownNodeAction.Request;
 
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.function.Predicate;
@@ -117,11 +116,12 @@ public class TransportPutShutdownNodeAction extends AcknowledgedTransportMasterN
     // package private for tests
     class PutShutdownNodeExecutor implements ClusterStateTaskExecutor<PutShutdownNodeTask> {
         @Override
-        public ClusterState execute(ClusterState currentState, List<TaskContext<PutShutdownNodeTask>> taskContexts) throws Exception {
-            var shutdownMetadata = new HashMap<>(getShutdownsOrEmpty(currentState).getAllNodeMetadataMap());
-            Predicate<String> nodeExistsPredicate = currentState.getNodes()::nodeExists;
+        public ClusterState execute(BatchExecutionContext<PutShutdownNodeTask> batchExecutionContext) throws Exception {
+            final var initialState = batchExecutionContext.initialState();
+            var shutdownMetadata = new HashMap<>(getShutdownsOrEmpty(initialState).getAllNodeMetadataMap());
+            Predicate<String> nodeExistsPredicate = batchExecutionContext.initialState().getNodes()::nodeExists;
             boolean changed = false;
-            for (final var taskContext : taskContexts) {
+            for (final var taskContext : batchExecutionContext.taskContexts()) {
                 var request = taskContext.getTask().request();
                 try {
                     changed |= putShutdownNodeState(shutdownMetadata, nodeExistsPredicate, request);
@@ -133,11 +133,11 @@ public class TransportPutShutdownNodeAction extends AcknowledgedTransportMasterN
                 taskContext.success(() -> ackAndMaybeReroute(request, taskContext.getTask().listener(), reroute));
             }
             if (changed == false) {
-                return currentState;
+                return batchExecutionContext.initialState();
             }
-            return ClusterState.builder(currentState)
+            return ClusterState.builder(batchExecutionContext.initialState())
                 .metadata(
-                    Metadata.builder(currentState.metadata())
+                    Metadata.builder(batchExecutionContext.initialState().metadata())
                         .putCustom(NodesShutdownMetadata.TYPE, new NodesShutdownMetadata(shutdownMetadata))
                 )
                 .build();

+ 3 - 1
x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor.TaskContext;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
@@ -76,7 +77,8 @@ public class TransportDeleteShutdownNodeActionTests extends ESTestCase {
         var taskExecutor = ArgumentCaptor.forClass(DeleteShutdownNodeExecutor.class);
         verify(clusterService).submitStateUpdateTask(any(), updateTask.capture(), taskConfig.capture(), taskExecutor.capture());
         when(taskContext.getTask()).thenReturn(updateTask.getValue());
-        ClusterState gotState = taskExecutor.getValue().execute(ClusterState.EMPTY_STATE, List.of(taskContext));
+        ClusterState gotState = taskExecutor.getValue()
+            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(ClusterState.EMPTY_STATE, List.of(taskContext)));
         assertThat(gotState, sameInstance(ClusterState.EMPTY_STATE));
     }
 }

+ 5 - 2
x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor.TaskContext;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata.Type;
@@ -74,7 +75,8 @@ public class TransportPutShutdownNodeActionTests extends ESTestCase {
         var taskExecutor = ArgumentCaptor.forClass(PutShutdownNodeExecutor.class);
         verify(clusterService).submitStateUpdateTask(any(), updateTask.capture(), taskConfig.capture(), taskExecutor.capture());
         when(taskContext.getTask()).thenReturn(updateTask.getValue());
-        ClusterState stableState = taskExecutor.getValue().execute(ClusterState.EMPTY_STATE, List.of(taskContext));
+        ClusterState stableState = taskExecutor.getValue()
+            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(ClusterState.EMPTY_STATE, List.of(taskContext)));
 
         // run the request again, there should be no call to submit an update task
         clearInvocations(clusterService);
@@ -85,7 +87,8 @@ public class TransportPutShutdownNodeActionTests extends ESTestCase {
         action.masterOperation(null, request, ClusterState.EMPTY_STATE, ActionListener.noop());
         verify(clusterService).submitStateUpdateTask(any(), updateTask.capture(), taskConfig.capture(), taskExecutor.capture());
         when(taskContext.getTask()).thenReturn(updateTask.getValue());
-        ClusterState gotState = taskExecutor.getValue().execute(stableState, List.of(taskContext));
+        ClusterState gotState = taskExecutor.getValue()
+            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(stableState, List.of(taskContext)));
         assertThat(gotState, sameInstance(stableState));
     }
 }