Browse Source

Capture deprecation warnings in batched master tasks (#85525)

It's possible for a cluster state update task to emit deprecation
warnings, but if the task is executed in a batch then these warnings
will be exposed to the listener for every item in the batch. With this
commit we introduce a mechanism for tasks to capture just the warnings
relevant to them, along with assertions that warnings are not
inadvertently leaked back to the master service.

Closes #85506
David Turner 3 years ago
parent
commit
745947e854
32 changed files with 483 additions and 82 deletions
  1. 6 0
      docs/changelog/85525.yaml
  2. 4 1
      modules/data-streams/src/main/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeService.java
  3. 111 0
      server/src/internalClusterTest/java/org/elasticsearch/action/support/AutoCreateIndexIT.java
  4. 1 1
      server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java
  5. 1 1
      server/src/main/java/org/elasticsearch/action/admin/indices/create/AutoCreateAction.java
  6. 4 2
      server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
  7. 25 3
      server/src/main/java/org/elasticsearch/cluster/ClusterStateTaskExecutor.java
  8. 5 3
      server/src/main/java/org/elasticsearch/cluster/LocalMasterServiceTask.java
  9. 2 1
      server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java
  10. 10 10
      server/src/main/java/org/elasticsearch/cluster/coordination/NodeRemovalClusterStateTaskExecutor.java
  11. 8 2
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java
  12. 3 1
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java
  13. 1 1
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java
  14. 6 2
      server/src/main/java/org/elasticsearch/cluster/metadata/MetadataUpdateSettingsService.java
  15. 120 31
      server/src/main/java/org/elasticsearch/cluster/service/MasterService.java
  16. 3 1
      server/src/main/java/org/elasticsearch/health/metadata/HealthMetadataService.java
  17. 3 1
      server/src/main/java/org/elasticsearch/ingest/IngestService.java
  18. 3 1
      server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateErrorTaskExecutor.java
  19. 3 1
      server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateUpdateTaskExecutor.java
  20. 113 5
      server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java
  21. 17 2
      server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java
  22. 9 1
      test/framework/src/main/java/org/elasticsearch/cluster/service/ClusterStateTaskExecutorUtils.java
  23. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartBasicClusterTask.java
  24. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartTrialClusterTask.java
  25. 1 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/license/LicenseServiceTests.java
  26. 3 1
      x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/IndexLifecycleRunner.java
  27. 7 1
      x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java
  28. 3 1
      x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/v2/TransportRollupAction.java
  29. 1 1
      x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java
  30. 1 1
      x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java
  31. 1 1
      x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java
  32. 2 2
      x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java

+ 6 - 0
docs/changelog/85525.yaml

@@ -0,0 +1,6 @@
+pr: 85525
+summary: Capture deprecation warnings in batched master tasks
+area: Cluster Coordination
+type: bug
+issues:
+ - 85506

+ 4 - 1
modules/data-streams/src/main/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeService.java

@@ -200,7 +200,10 @@ public class UpdateTimeSeriesRangeService extends AbstractLifecycleComponent imp
     private class UpdateTimeSeriesExecutor implements ClusterStateTaskExecutor<UpdateTimeSeriesTask> {
         @Override
         public ClusterState execute(BatchExecutionContext<UpdateTimeSeriesTask> batchExecutionContext) throws Exception {
-            var result = updateTimeSeriesTemporalRange(batchExecutionContext.initialState(), Instant.now());
+            final ClusterState result;
+            try (var ignored = batchExecutionContext.dropHeadersContext()) {
+                result = updateTimeSeriesTemporalRange(batchExecutionContext.initialState(), Instant.now());
+            }
             for (final var taskContext : batchExecutionContext.taskContexts()) {
                 taskContext.success(() -> taskContext.getTask().listener().accept(null));
             }

+ 111 - 0
server/src/internalClusterTest/java/org/elasticsearch/action/support/AutoCreateIndexIT.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.cluster.ClusterStateTaskConfig;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Priority;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.TimeUnit;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasItems;
+import static org.hamcrest.Matchers.not;
+
+public class AutoCreateIndexIT extends ESIntegTestCase {
+    public void testBatchingWithDeprecationWarnings() throws Exception {
+        final var masterNodeClusterService = internalCluster().getCurrentMasterNodeInstance(ClusterService.class);
+        final var barrier = new CyclicBarrier(2);
+        masterNodeClusterService.submitStateUpdateTask(
+            "block",
+            e -> { assert false : e; },
+            ClusterStateTaskConfig.build(Priority.NORMAL),
+            batchExecutionContext -> {
+                barrier.await(10, TimeUnit.SECONDS);
+                barrier.await(10, TimeUnit.SECONDS);
+                batchExecutionContext.taskContexts().forEach(c -> c.success(() -> {}));
+                return batchExecutionContext.initialState();
+            }
+        );
+
+        barrier.await(10, TimeUnit.SECONDS);
+
+        final var countDownLatch = new CountDownLatch(2);
+
+        final var client = client();
+        client.prepareIndex("no-dot").setSource("{}", XContentType.JSON).execute(new ActionListener<>() {
+            @Override
+            public void onResponse(IndexResponse indexResponse) {
+                try {
+                    final var warningHeaders = client.threadPool().getThreadContext().getResponseHeaders().get("Warning");
+                    if (warningHeaders != null) {
+                        assertThat(
+                            warningHeaders,
+                            not(
+                                hasItems(
+                                    containsString("index names starting with a dot are reserved for hidden indices and system indices")
+                                )
+                            )
+                        );
+                    }
+                } finally {
+                    countDownLatch.countDown();
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                countDownLatch.countDown();
+                assert false : e;
+            }
+        });
+
+        client.prepareIndex(".has-dot").setSource("{}", XContentType.JSON).execute(new ActionListener<>() {
+            @Override
+            public void onResponse(IndexResponse indexResponse) {
+                try {
+                    final var warningHeaders = client.threadPool().getThreadContext().getResponseHeaders().get("Warning");
+                    assertNotNull(warningHeaders);
+                    assertThat(
+                        warningHeaders,
+                        hasItems(containsString("index names starting with a dot are reserved for hidden indices and system indices"))
+                    );
+                } finally {
+                    countDownLatch.countDown();
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                countDownLatch.countDown();
+                assert false : e;
+            }
+        });
+
+        assertBusy(
+            () -> assertThat(
+                masterNodeClusterService.getMasterService()
+                    .pendingTasks()
+                    .stream()
+                    .map(pendingClusterTask -> pendingClusterTask.getSource().string())
+                    .toList(),
+                hasItems("auto create [no-dot]", "auto create [.has-dot]")
+            )
+        );
+
+        barrier.await(10, TimeUnit.SECONDS);
+        assertTrue(countDownLatch.await(10, TimeUnit.SECONDS));
+    }
+}

+ 1 - 1
server/src/main/java/org/elasticsearch/action/admin/cluster/desirednodes/TransportUpdateDesiredNodesAction.java

@@ -192,7 +192,7 @@ public class TransportUpdateDesiredNodesAction extends TransportMasterNodeAction
                     continue;
                 }
                 final var previousDesiredNodes = desiredNodes;
-                try {
+                try (var ignored = taskContext.captureResponseHeaders()) {
                     desiredNodes = updateDesiredNodes(desiredNodes, request);
                 } catch (Exception e) {
                     taskContext.onFailure(e);

+ 1 - 1
server/src/main/java/org/elasticsearch/action/admin/indices/create/AutoCreateAction.java

@@ -113,7 +113,7 @@ public final class AutoCreateAction extends ActionType<CreateIndexResponse> {
                 ClusterState state = batchExecutionContext.initialState();
                 for (final var taskContext : taskContexts) {
                     final var task = taskContext.getTask();
-                    try {
+                    try (var ignored = taskContext.captureResponseHeaders()) {
                         state = task.execute(state, successfulRequests, taskContext);
                         assert successfulRequests.containsKey(task.request);
                     } catch (Exception e) {

+ 4 - 2
server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java

@@ -263,7 +263,7 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
             final var results = new ArrayList<MetadataRolloverService.RolloverResult>(batchExecutionContext.taskContexts().size());
             var state = batchExecutionContext.initialState();
             for (final var taskContext : batchExecutionContext.taskContexts()) {
-                try {
+                try (var ignored = taskContext.captureResponseHeaders()) {
                     state = executeTask(state, results, taskContext);
                 } catch (Exception e) {
                     taskContext.onFailure(e);
@@ -280,7 +280,9 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
                     1024,
                     reason
                 );
-                state = allocationService.reroute(state, reason.toString());
+                try (var ignored = batchExecutionContext.dropHeadersContext()) {
+                    state = allocationService.reroute(state, reason.toString());
+                }
             }
             return state;
         }

+ 25 - 3
server/src/main/java/org/elasticsearch/cluster/ClusterStateTaskExecutor.java

@@ -8,9 +8,11 @@
 package org.elasticsearch.cluster;
 
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Releasable;
 
 import java.util.List;
 import java.util.function.Consumer;
+import java.util.function.Supplier;
 
 /**
  * An executor for batches of cluster state update tasks.
@@ -52,8 +54,8 @@ public interface ClusterStateTaskExecutor<T extends ClusterStateTaskListener> {
     /**
      * Builds a concise description of a list of tasks (to be used in logging etc.).
      *
-     * Note that the tasks given are not necessarily the same as those that will be passed to {@link #execute(BatchExecutionContext)}.
-     * but are guaranteed to be a subset of them. This method can be called multiple times with different lists before execution.
+     * Note that the tasks given are not necessarily the same as those that will be passed to {@link #execute} but are guaranteed to be a
+     * subset of them. This method can be called multiple times with different lists before execution.
      *
      * @param tasks the tasks to describe.
      * @return A string which describes the batch of tasks.
@@ -200,6 +202,11 @@ public interface ClusterStateTaskExecutor<T extends ClusterStateTaskListener> {
          * @param failure The exception with which the task failed.
          */
         void onFailure(Exception failure);
+
+        /**
+         * Creates a context which captures any response headers (e.g. deprecation warnings) to be fed to the task's listener on completion.
+         */
+        Releasable captureResponseHeaders();
     }
 
     /**
@@ -207,6 +214,21 @@ public interface ClusterStateTaskExecutor<T extends ClusterStateTaskListener> {
      *
      * @param initialState The initial cluster state on which the tasks should be executed.
      * @param taskContexts A {@link TaskContext} for each task in the batch. Implementations must complete every context in the list.
+     * @param dropHeadersContextSupplier Supplies a context (a resource for use in a try-with-resources block) which captures and drops any
+     *                                   emitted response headers, for cases where things like deprecation warnings may be emitted but
+     *                                   cannot be associated with any specific task.
      */
-    record BatchExecutionContext<T extends ClusterStateTaskListener> (ClusterState initialState, List<TaskContext<T>> taskContexts) {}
+    record BatchExecutionContext<T extends ClusterStateTaskListener> (
+        ClusterState initialState,
+        List<TaskContext<T>> taskContexts,
+        Supplier<Releasable> dropHeadersContextSupplier
+    ) {
+        /**
+         * Creates a context (a resource for use in a try-with-resources block) which captures and drops any emitted response headers, for
+         * cases where things like deprecation warnings may be emitted but cannot be associated with any specific task.
+         */
+        public Releasable dropHeadersContext() {
+            return dropHeadersContextSupplier.get();
+        }
+    }
 }

+ 5 - 3
server/src/main/java/org/elasticsearch/cluster/LocalMasterServiceTask.java

@@ -23,7 +23,7 @@ public abstract class LocalMasterServiceTask implements ClusterStateTaskListener
         this.priority = priority;
     }
 
-    protected void execute(ClusterState currentState) throws Exception {}
+    protected void execute(ClusterState currentState) {}
 
     @Override
     public final void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
@@ -52,12 +52,14 @@ public abstract class LocalMasterServiceTask implements ClusterStateTaskListener
                 }
 
                 @Override
-                public ClusterState execute(BatchExecutionContext<LocalMasterServiceTask> batchExecutionContext) throws Exception {
+                public ClusterState execute(BatchExecutionContext<LocalMasterServiceTask> batchExecutionContext) {
                     final var thisTask = LocalMasterServiceTask.this;
                     final var taskContexts = batchExecutionContext.taskContexts();
                     assert taskContexts.size() == 1 && taskContexts.get(0).getTask() == thisTask
                         : "expected one-element task list containing current object but was " + taskContexts;
-                    thisTask.execute(batchExecutionContext.initialState());
+                    try (var ignored = taskContexts.get(0).captureResponseHeaders()) {
+                        thisTask.execute(batchExecutionContext.initialState());
+                    }
                     taskContexts.get(0).success(() -> onPublicationComplete());
                     return batchExecutionContext.initialState();
                 }

+ 2 - 1
server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java

@@ -405,7 +405,8 @@ public class ShardStateAction {
             assert tasksToBeApplied.size() == failedShardsToBeApplied.size() + staleShardsToBeApplied.size();
 
             ClusterState maybeUpdatedState = initialState;
-            try {
+            try (var ignored = batchExecutionContext.dropHeadersContext()) {
+                // drop deprecation warnings arising from the computation (reroute etc).
                 maybeUpdatedState = applyFailedShards(initialState, failedShardsToBeApplied, staleShardsToBeApplied);
                 for (final var taskContext : tasksToBeApplied) {
                     taskContext.success(() -> taskContext.getTask().listener().onResponse(TransportResponse.Empty.INSTANCE));

+ 10 - 10
server/src/main/java/org/elasticsearch/cluster/coordination/NodeRemovalClusterStateTaskExecutor.java

@@ -66,22 +66,22 @@ public class NodeRemovalClusterStateTaskExecutor implements ClusterStateTaskExec
             taskContext.success(task.onClusterStateProcessed::run);
         }
 
-        final ClusterState finalState;
+        if (removed == false) {
+            // no nodes to remove, keep the current cluster state
+            return initialState;
+        }
+
+        try (var ignored = batchExecutionContext.dropHeadersContext()) {
+            // suppress deprecation warnings e.g. from reroute()
 
-        if (removed) {
-            final ClusterState remainingNodesClusterState = remainingNodesClusterState(initialState, remainingNodesBuilder);
-            final ClusterState ptasksDisassociatedState = PersistentTasksCustomMetadata.disassociateDeadNodes(remainingNodesClusterState);
-            finalState = allocationService.disassociateDeadNodes(
+            final var remainingNodesClusterState = remainingNodesClusterState(initialState, remainingNodesBuilder);
+            final var ptasksDisassociatedState = PersistentTasksCustomMetadata.disassociateDeadNodes(remainingNodesClusterState);
+            return allocationService.disassociateDeadNodes(
                 ptasksDisassociatedState,
                 true,
                 describeTasks(batchExecutionContext.taskContexts().stream().map(TaskContext::getTask).toList())
             );
-        } else {
-            // no nodes to remove, keep the current cluster state
-            finalState = initialState;
         }
-
-        return finalState;
     }
 
     // visible for testing

+ 8 - 2
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexStateService.java

@@ -279,7 +279,10 @@ public class MetadataIndexStateService {
                 }
             }
 
-            return allocationService.reroute(state, "indices closed");
+            try (var ignored = batchExecutionContext.dropHeadersContext()) {
+                // reroute may encounter deprecated features but the resulting warnings are not associated with any particular task
+                return allocationService.reroute(state, "indices closed");
+            }
         }
     }
 
@@ -1103,7 +1106,10 @@ public class MetadataIndexStateService {
         public ClusterState execute(BatchExecutionContext<OpenIndicesTask> batchExecutionContext) {
             ClusterState state = batchExecutionContext.initialState();
 
-            try {
+            try (var ignored = batchExecutionContext.dropHeadersContext()) {
+                // we may encounter deprecated settings but they are not directly related to opening the indices, nor are they really
+                // associated with any particular tasks, so we drop them
+
                 // build an in-order de-duplicated array of all the indices to open
                 final Set<Index> indicesToOpen = Sets.newLinkedHashSetWithExpectedSize(batchExecutionContext.taskContexts().size());
                 for (final var taskContext : batchExecutionContext.taskContexts()) {

+ 3 - 1
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataIndexTemplateService.java

@@ -130,7 +130,9 @@ public class MetadataIndexTemplateService {
         for (final var taskContext : batchExecutionContext.taskContexts()) {
             try {
                 final var task = taskContext.getTask();
-                state = task.execute(state);
+                try (var ignored = taskContext.captureResponseHeaders()) {
+                    state = task.execute(state);
+                }
                 taskContext.success(() -> task.listener.onResponse(AcknowledgedResponse.TRUE));
             } catch (Exception e) {
                 taskContext.onFailure(e);

+ 1 - 1
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java

@@ -101,7 +101,7 @@ public class MetadataMappingService {
                 for (final var taskContext : batchExecutionContext.taskContexts()) {
                     final var task = taskContext.getTask();
                     final PutMappingClusterStateUpdateRequest request = task.request;
-                    try {
+                    try (var ignored = taskContext.captureResponseHeaders()) {
                         for (Index index : request.indices()) {
                             final IndexMetadata indexMetadata = currentState.metadata().getIndexSafe(index);
                             if (indexMapperServices.containsKey(indexMetadata.getIndex()) == false) {

+ 6 - 2
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataUpdateSettingsService.java

@@ -75,7 +75,9 @@ public class MetadataUpdateSettingsService {
             for (final var taskContext : batchExecutionContext.taskContexts()) {
                 try {
                     final var task = taskContext.getTask();
-                    state = task.execute(state);
+                    try (var ignored = taskContext.captureResponseHeaders()) {
+                        state = task.execute(state);
+                    }
                     taskContext.success(task.getAckListener());
                 } catch (Exception e) {
                     taskContext.onFailure(e);
@@ -83,7 +85,9 @@ public class MetadataUpdateSettingsService {
             }
             if (state != batchExecutionContext.initialState()) {
                 // reroute in case things change that require it (like number of replicas)
-                state = allocationService.reroute(state, "settings update");
+                try (var ignored = batchExecutionContext.dropHeadersContext()) {
+                    state = allocationService.reroute(state, "settings update");
+                }
             }
             return state;
         };

+ 120 - 31
server/src/main/java/org/elasticsearch/cluster/service/MasterService.java

@@ -32,6 +32,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.text.Text;
+import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
@@ -39,6 +40,8 @@ import org.elasticsearch.common.util.concurrent.FutureUtils;
 import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.node.Node;
@@ -50,7 +53,9 @@ import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
@@ -157,7 +162,7 @@ public class MasterService extends AbstractLifecycleComponent {
         @Override
         protected void onTimeout(BatchedTask task, TimeValue timeout) {
             threadPool.generic()
-                .execute(() -> ((UpdateTask) task).onFailure(new ProcessClusterEventTimeoutException(timeout, task.source)));
+                .execute(() -> ((UpdateTask) task).onFailure(new ProcessClusterEventTimeoutException(timeout, task.source), () -> {}));
         }
 
         @Override
@@ -188,8 +193,9 @@ public class MasterService extends AbstractLifecycleComponent {
                 );
             }
 
-            public void onFailure(Exception e) {
+            public void onFailure(Exception e, Runnable restoreResponseHeaders) {
                 try (ThreadContext.StoredContext ignore = threadContextSupplier.get()) {
+                    restoreResponseHeaders.run();
                     listener.onFailure(e);
                 } catch (Exception inner) {
                     inner.addSuppressed(e);
@@ -198,10 +204,21 @@ public class MasterService extends AbstractLifecycleComponent {
             }
 
             @Nullable
-            public ContextPreservingAckListener wrapInTaskContext(@Nullable ClusterStateAckListener clusterStateAckListener) {
+            public ContextPreservingAckListener wrapInTaskContext(
+                @Nullable ClusterStateAckListener clusterStateAckListener,
+                Runnable restoreResponseHeaders
+            ) {
                 return clusterStateAckListener == null
                     ? null
-                    : new ContextPreservingAckListener(Objects.requireNonNull(clusterStateAckListener), threadContextSupplier);
+                    : new ContextPreservingAckListener(
+                        Objects.requireNonNull(clusterStateAckListener),
+                        threadContextSupplier,
+                        restoreResponseHeaders
+                    );
+            }
+
+            ThreadContext getThreadContext() {
+                return threadPool.getThreadContext();
             }
         }
     }
@@ -250,7 +267,7 @@ public class MasterService extends AbstractLifecycleComponent {
 
         if (previousClusterState.nodes().isLocalNodeElectedMaster() == false && executor.runOnlyOnMaster()) {
             logger.debug("failing [{}]: local node is no longer master", summary);
-            updateTasks.forEach(t -> t.onFailure(new NotMasterException("no longer master, failing [" + t.source() + "]")));
+            updateTasks.forEach(t -> t.onFailure(new NotMasterException("no longer master, failing [" + t.source() + "]"), () -> {}));
             return;
         }
 
@@ -258,12 +275,12 @@ public class MasterService extends AbstractLifecycleComponent {
         final var executionResults = updateTasks.stream().map(ExecutionResult::new).toList();
         final var newClusterState = patchVersions(
             previousClusterState,
-            executeTasks(previousClusterState, executionResults, executor, summary)
+            executeTasks(previousClusterState, executionResults, executor, summary, threadPool.getThreadContext())
         );
         // fail all tasks that have failed
         for (final var executionResult : executionResults) {
             if (executionResult.failure != null) {
-                executionResult.updateTask.onFailure(executionResult.failure);
+                executionResult.updateTask.onFailure(executionResult.failure, executionResult::restoreResponseHeaders);
             }
         }
         final TimeValue computationTime = getTimeSince(computationStartTime);
@@ -529,7 +546,10 @@ public class MasterService extends AbstractLifecycleComponent {
                 : "this only supports a single task but received " + batchExecutionContext.taskContexts();
             final var taskContext = batchExecutionContext.taskContexts().get(0);
             final var task = taskContext.getTask();
-            final var newState = task.execute(batchExecutionContext.initialState());
+            final ClusterState newState;
+            try (var ignored = taskContext.captureResponseHeaders()) {
+                newState = task.execute(batchExecutionContext.initialState());
+            }
             final Consumer<ClusterState> publishListener = publishedState -> task.clusterStateProcessed(
                 batchExecutionContext.initialState(),
                 publishedState
@@ -644,7 +664,11 @@ public class MasterService extends AbstractLifecycleComponent {
      * callbacks, and also logs and swallows any exceptions thrown. One of these is created for each task in the batch that passes a
      * {@link ClusterStateAckListener} to {@link ClusterStateTaskExecutor.TaskContext#success}.
      */
-    private record ContextPreservingAckListener(ClusterStateAckListener listener, Supplier<ThreadContext.StoredContext> context) {
+    private record ContextPreservingAckListener(
+        ClusterStateAckListener listener,
+        Supplier<ThreadContext.StoredContext> context,
+        Runnable restoreResponseHeaders
+    ) {
 
         public boolean mustAck(DiscoveryNode discoveryNode) {
             return listener.mustAck(discoveryNode);
@@ -652,6 +676,7 @@ public class MasterService extends AbstractLifecycleComponent {
 
         public void onAckSuccess() {
             try (ThreadContext.StoredContext ignore = context.get()) {
+                restoreResponseHeaders.run();
                 listener.onAllNodesAcked();
             } catch (Exception inner) {
                 logger.error("exception thrown by listener while notifying on all nodes acked", inner);
@@ -660,6 +685,7 @@ public class MasterService extends AbstractLifecycleComponent {
 
         public void onAckFailure(@Nullable Exception e) {
             try (ThreadContext.StoredContext ignore = context.get()) {
+                restoreResponseHeaders.run();
                 listener.onAckFailure(e);
             } catch (Exception inner) {
                 inner.addSuppressed(e);
@@ -669,6 +695,7 @@ public class MasterService extends AbstractLifecycleComponent {
 
         public void onAckTimeout() {
             try (ThreadContext.StoredContext ignore = context.get()) {
+                restoreResponseHeaders.run();
                 listener.onAckTimeout();
             } catch (Exception e) {
                 logger.error("exception thrown by listener while notifying on ack timeout", e);
@@ -807,6 +834,9 @@ public class MasterService extends AbstractLifecycleComponent {
         @Nullable // if the task is incomplete or succeeded
         Exception failure;
 
+        @Nullable
+        Map<String, List<String>> responseHeaders;
+
         ExecutionResult(Batcher.UpdateTask updateTask) {
             this.updateTask = updateTask;
         }
@@ -877,6 +907,40 @@ public class MasterService extends AbstractLifecycleComponent {
             this.failure = Objects.requireNonNull(failure);
         }
 
+        @Override
+        public Releasable captureResponseHeaders() {
+            final var threadContext = updateTask.getThreadContext();
+            final var storedContext = threadContext.newStoredContext(false);
+            return Releasables.wrap(() -> {
+                final var newResponseHeaders = threadContext.getResponseHeaders();
+                if (newResponseHeaders.isEmpty()) {
+                    return;
+                }
+                if (responseHeaders == null) {
+                    responseHeaders = new HashMap<>(newResponseHeaders);
+                } else {
+                    for (final var newResponseHeader : newResponseHeaders.entrySet()) {
+                        responseHeaders.compute(newResponseHeader.getKey(), (ignored, oldValue) -> {
+                            if (oldValue == null) {
+                                return newResponseHeader.getValue();
+                            }
+                            return CollectionUtils.concatLists(oldValue, newResponseHeader.getValue());
+                        });
+                    }
+                }
+            }, storedContext);
+        }
+
+        private void restoreResponseHeaders() {
+            if (responseHeaders != null) {
+                for (final var responseHeader : responseHeaders.entrySet()) {
+                    for (final var value : responseHeader.getValue()) {
+                        updateTask.getThreadContext().addResponseHeader(responseHeader.getKey(), value);
+                    }
+                }
+            }
+        }
+
         void onBatchFailure(Exception failure) {
             // if the whole batch resulted in an exception then this overrides any task-level results whether successful or not
             this.failure = Objects.requireNonNull(failure);
@@ -890,6 +954,7 @@ public class MasterService extends AbstractLifecycleComponent {
                 return;
             }
             try (ThreadContext.StoredContext ignored = updateTask.threadContextSupplier.get()) {
+                restoreResponseHeaders();
                 if (onPublicationSuccess == null) {
                     publishedStateConsumer.accept(newClusterState);
                 } else {
@@ -906,6 +971,7 @@ public class MasterService extends AbstractLifecycleComponent {
                 return;
             }
             try (ThreadContext.StoredContext ignored = updateTask.threadContextSupplier.get()) {
+                restoreResponseHeaders();
                 if (onPublicationSuccess == null) {
                     publishedStateConsumer.accept(clusterState);
                 } else {
@@ -922,6 +988,7 @@ public class MasterService extends AbstractLifecycleComponent {
                 return;
             }
             try (ThreadContext.StoredContext ignored = updateTask.threadContextSupplier.get()) {
+                restoreResponseHeaders();
                 getTask().onFailure(e);
             } catch (Exception inner) {
                 inner.addSuppressed(e);
@@ -931,7 +998,7 @@ public class MasterService extends AbstractLifecycleComponent {
 
         ContextPreservingAckListener getContextPreservingAckListener() {
             assert incomplete() == false;
-            return updateTask.wrapInTaskContext(clusterStateAckListener);
+            return updateTask.wrapInTaskContext(clusterStateAckListener, this::restoreResponseHeaders);
         }
 
         @Override
@@ -944,9 +1011,10 @@ public class MasterService extends AbstractLifecycleComponent {
         ClusterState previousClusterState,
         List<ExecutionResult<ClusterStateTaskListener>> executionResults,
         ClusterStateTaskExecutor<ClusterStateTaskListener> executor,
-        BatchSummary summary
+        BatchSummary summary,
+        ThreadContext threadContext
     ) {
-        final var resultingState = innerExecuteTasks(previousClusterState, executionResults, executor, summary);
+        final var resultingState = innerExecuteTasks(previousClusterState, executionResults, executor, summary, threadContext);
         if (previousClusterState != resultingState
             && previousClusterState.nodes().isLocalNodeElectedMaster()
             && (resultingState.nodes().isLocalNodeElectedMaster() == false)) {
@@ -972,28 +1040,49 @@ public class MasterService extends AbstractLifecycleComponent {
         ClusterState previousClusterState,
         List<ExecutionResult<ClusterStateTaskListener>> executionResults,
         ClusterStateTaskExecutor<ClusterStateTaskListener> executor,
-        BatchSummary summary
+        BatchSummary summary,
+        ThreadContext threadContext
     ) {
         final var taskContexts = castTaskContexts(executionResults);
-        try {
-            return executor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(previousClusterState, taskContexts));
-        } catch (Exception e) {
-            logger.trace(
-                () -> format(
-                    "failed to execute cluster state update (on version: [%s], uuid: [%s]) for [%s]\n%s%s%s",
-                    previousClusterState.version(),
-                    previousClusterState.stateUUID(),
-                    summary,
-                    previousClusterState.nodes(),
-                    previousClusterState.routingTable(),
-                    previousClusterState.getRoutingNodes()
-                ),
-                e
-            );
-            for (final var executionResult : executionResults) {
-                executionResult.onBatchFailure(e);
+        try (var ignored = threadContext.newStoredContext(false)) {
+            // if the executor leaks a response header then this will cause a test failure, but we also store the context here to be sure
+            // to avoid leaking headers in production that were missed by tests
+
+            try {
+                return executor.execute(
+                    new ClusterStateTaskExecutor.BatchExecutionContext<>(
+                        previousClusterState,
+                        taskContexts,
+                        () -> threadContext.newStoredContext(false)
+                    )
+                );
+            } catch (Exception e) {
+                logger.trace(
+                    () -> format(
+                        "failed to execute cluster state update (on version: [%s], uuid: [%s]) for [%s]\n%s%s%s",
+                        previousClusterState.version(),
+                        previousClusterState.stateUUID(),
+                        summary,
+                        previousClusterState.nodes(),
+                        previousClusterState.routingTable(),
+                        previousClusterState.getRoutingNodes()
+                    ),
+                    e
+                );
+                for (final var executionResult : executionResults) {
+                    executionResult.onBatchFailure(e);
+                }
+                return previousClusterState;
+            } finally {
+                assert threadContext.getResponseHeaders().isEmpty()
+                    : """
+                        Batched task executors must marshal response headers to the appropriate task context (e.g. using \
+                        TaskContext#captureResponseHeaders) or suppress them (e.g. using BatchExecutionContext#dropHeadersContext) and \
+                        must not leak them to the master service, but executor ["""
+                        + executor
+                        + "] leaked the following headers: "
+                        + threadContext.getResponseHeaders();
             }
-            return previousClusterState;
         }
     }
 

+ 3 - 1
server/src/main/java/org/elasticsearch/health/metadata/HealthMetadataService.java

@@ -173,7 +173,9 @@ public class HealthMetadataService {
             public ClusterState execute(BatchExecutionContext<UpsertHealthMetadataTask> batchExecutionContext) throws Exception {
                 ClusterState updatedState = batchExecutionContext.initialState();
                 for (TaskContext<UpsertHealthMetadataTask> taskContext : batchExecutionContext.taskContexts()) {
-                    updatedState = taskContext.getTask().execute(updatedState);
+                    try (var ignored = taskContext.captureResponseHeaders()) {
+                        updatedState = taskContext.getTask().execute(updatedState);
+                    }
                     taskContext.success(() -> {});
                 }
                 return updatedState;

+ 3 - 1
server/src/main/java/org/elasticsearch/ingest/IngestService.java

@@ -115,7 +115,9 @@ public class IngestService implements ClusterStateApplier, ReportingService<Inge
         for (final var taskContext : batchExecutionContext.taskContexts()) {
             try {
                 final var task = taskContext.getTask();
-                currentIngestMetadata = task.execute(currentIngestMetadata, allIndexMetadata);
+                try (var ignored = taskContext.captureResponseHeaders()) {
+                    currentIngestMetadata = task.execute(currentIngestMetadata, allIndexMetadata);
+                }
                 taskContext.success(() -> task.listener.onResponse(AcknowledgedResponse.TRUE));
             } catch (Exception e) {
                 taskContext.onFailure(e);

+ 3 - 1
server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateErrorTaskExecutor.java

@@ -27,7 +27,9 @@ record ReservedStateErrorTaskExecutor() implements ClusterStateTaskExecutor<Rese
         var updatedState = batchExecutionContext.initialState();
         for (final var taskContext : batchExecutionContext.taskContexts()) {
             final var task = taskContext.getTask();
-            updatedState = task.execute(updatedState);
+            try (var ignored = taskContext.captureResponseHeaders()) {
+                updatedState = task.execute(updatedState);
+            }
             taskContext.success(() -> task.listener().onResponse(ActionResponse.Empty.INSTANCE));
         }
         return updatedState;

+ 3 - 1
server/src/main/java/org/elasticsearch/reservedstate/service/ReservedStateUpdateTaskExecutor.java

@@ -30,7 +30,9 @@ public record ReservedStateUpdateTaskExecutor(RerouteService rerouteService) imp
     public ClusterState execute(BatchExecutionContext<ReservedStateUpdateTask> batchExecutionContext) throws Exception {
         var updatedState = batchExecutionContext.initialState();
         for (final var taskContext : batchExecutionContext.taskContexts()) {
-            updatedState = taskContext.getTask().execute(updatedState);
+            try (var ignored = taskContext.captureResponseHeaders()) {
+                updatedState = taskContext.getTask().execute(updatedState);
+            }
             taskContext.success(() -> taskContext.getTask().listener().onResponse(ActionResponse.Empty.INSTANCE));
         }
         return updatedState;

+ 113 - 5
server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java

@@ -643,6 +643,7 @@ public class MasterServiceTests extends ESTestCase {
         AtomicInteger submittedTasks = new AtomicInteger();
         AtomicInteger processedStates = new AtomicInteger();
         SetOnce<CountDownLatch> processedStatesLatch = new SetOnce<>();
+        final String responseHeaderName = randomAlphaOfLength(10);
 
         class Task implements ClusterStateTaskListener {
             private final AtomicBoolean executed = new AtomicBoolean();
@@ -653,6 +654,7 @@ public class MasterServiceTests extends ESTestCase {
             }
 
             public void execute() {
+                threadPool.getThreadContext().addResponseHeader(responseHeaderName, toString());
                 if (executed.compareAndSet(false, true) == false) {
                     throw new AssertionError("Task [" + id + "] should only be executed once");
                 } else {
@@ -713,7 +715,19 @@ public class MasterServiceTests extends ESTestCase {
                 }
 
                 for (final var taskContext : batchExecutionContext.taskContexts()) {
-                    taskContext.getTask().execute();
+                    if (randomBoolean()) {
+                        try (var ignored = taskContext.captureResponseHeaders()) {
+                            threadPool.getThreadContext().addResponseHeader(responseHeaderName, randomAlphaOfLength(10));
+                        }
+                    }
+                    try (var ignored = taskContext.captureResponseHeaders()) {
+                        taskContext.getTask().execute();
+                    }
+                    if (randomBoolean()) {
+                        try (var ignored = taskContext.captureResponseHeaders()) {
+                            threadPool.getThreadContext().addResponseHeader(responseHeaderName, randomAlphaOfLength(10));
+                        }
+                    }
                 }
 
                 executed.addAndGet(batchExecutionContext.taskContexts().size());
@@ -730,6 +744,10 @@ public class MasterServiceTests extends ESTestCase {
 
                 for (final var taskContext : batchExecutionContext.taskContexts()) {
                     taskContext.success(() -> {
+                        assertThat(
+                            threadPool.getThreadContext().getResponseHeaders().get(responseHeaderName),
+                            hasItem(taskContext.getTask().toString())
+                        );
                         processedStates.incrementAndGet();
                         processedStatesLatch.get().countDown();
                     });
@@ -902,8 +920,10 @@ public class MasterServiceTests extends ESTestCase {
         class Task implements ClusterStateTaskListener {
 
             final ActionListener<ClusterState> publishListener;
+            final String responseHeaderValue;
 
-            Task(ActionListener<ClusterState> publishListener) {
+            Task(String responseHeaderValue, ActionListener<ClusterState> publishListener) {
+                this.responseHeaderValue = responseHeaderValue;
                 this.publishListener = publishListener;
             }
 
@@ -921,11 +941,16 @@ public class MasterServiceTests extends ESTestCase {
         final String testContextHeaderName = "test-context-header";
         final ThreadContext threadContext = threadPool.getThreadContext();
 
+        final var testResponseHeaderName = "test-response-header";
+
         final var executor = new ClusterStateTaskExecutor<Task>() {
             @Override
             @SuppressForbidden(reason = "consuming published cluster state for legacy reasons")
             public ClusterState execute(BatchExecutionContext<Task> batchExecutionContext) {
                 for (final var taskContext : batchExecutionContext.taskContexts()) {
+                    try (var ignored = taskContext.captureResponseHeaders()) {
+                        threadPool.getThreadContext().addResponseHeader(testResponseHeaderName, taskContext.getTask().responseHeaderValue);
+                    }
                     taskContext.success(taskContext.getTask().publishListener::onResponse);
                 }
                 return ClusterState.builder(batchExecutionContext.initialState()).build();
@@ -967,11 +992,13 @@ public class MasterServiceTests extends ESTestCase {
             for (int i = 0; i < toSubmit; i++) {
                 try (ThreadContext.StoredContext ignored = threadContext.newStoredContext(false)) {
                     final var testContextHeaderValue = randomAlphaOfLength(10);
+                    final var testResponseHeaderValue = randomAlphaOfLength(10);
                     threadContext.putHeader(testContextHeaderName, testContextHeaderValue);
-                    final var task = new Task(new ActionListener<>() {
+                    final var task = new Task(testResponseHeaderValue, new ActionListener<>() {
                         @Override
                         public void onResponse(ClusterState clusterState) {
                             assertEquals(testContextHeaderValue, threadContext.getHeader(testContextHeaderName));
+                            assertEquals(List.of(testResponseHeaderValue), threadContext.getResponseHeaders().get(testResponseHeaderName));
                             assertSame(publishedState.get(), clusterState);
                             publishSuccessCountdown.countDown();
                         }
@@ -1007,8 +1034,9 @@ public class MasterServiceTests extends ESTestCase {
             for (int i = 0; i < toSubmit; i++) {
                 try (ThreadContext.StoredContext ignored = threadContext.newStoredContext(false)) {
                     final String testContextHeaderValue = randomAlphaOfLength(10);
+                    final String testResponseHeaderValue = randomAlphaOfLength(10);
                     threadContext.putHeader(testContextHeaderName, testContextHeaderValue);
-                    final var task = new Task(new ActionListener<>() {
+                    final var task = new Task(testResponseHeaderValue, new ActionListener<>() {
                         @Override
                         public void onResponse(ClusterState clusterState) {
                             throw new AssertionError("should not succeed");
@@ -1017,6 +1045,7 @@ public class MasterServiceTests extends ESTestCase {
                         @Override
                         public void onFailure(Exception e) {
                             assertEquals(testContextHeaderValue, threadContext.getHeader(testContextHeaderName));
+                            assertEquals(List.of(testResponseHeaderValue), threadContext.getResponseHeaders().get(testResponseHeaderName));
                             assertThat(e, instanceOf(FailedToCommitClusterStateException.class));
                             assertThat(e.getMessage(), equalTo(exceptionMessage));
                             publishFailureCountdown.countDown();
@@ -1322,6 +1351,8 @@ public class MasterServiceTests extends ESTestCase {
             )
         ) {
 
+            final var responseHeaderName = "test-response-header";
+
             final ClusterState initialClusterState = ClusterState.builder(new ClusterName(MasterServiceTests.class.getSimpleName()))
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3).localNodeId(node1.getId()).masterNodeId(node1.getId()))
                 .blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK)
@@ -1401,7 +1432,17 @@ public class MasterServiceTests extends ESTestCase {
                     ClusterStateTaskConfig.build(Priority.NORMAL),
                     batchExecutionContext -> {
                         for (final var taskContext : batchExecutionContext.taskContexts()) {
-                            taskContext.success(latch::countDown, taskContext.getTask());
+                            final var responseHeaderValue = randomAlphaOfLength(10);
+                            try (var ignored = taskContext.captureResponseHeaders()) {
+                                threadPool.getThreadContext().addResponseHeader(responseHeaderName, responseHeaderValue);
+                            }
+                            taskContext.success(() -> {
+                                assertThat(
+                                    threadPool.getThreadContext().getResponseHeaders().get(responseHeaderName),
+                                    equalTo(List.of(responseHeaderValue))
+                                );
+                                latch.countDown();
+                            }, taskContext.getTask());
                         }
                         return randomBoolean()
                             ? batchExecutionContext.initialState()
@@ -1496,6 +1537,66 @@ public class MasterServiceTests extends ESTestCase {
                 assertTrue(latch.await(10, TimeUnit.SECONDS));
             }
 
+            // check that exception from acking is passed to listener
+            {
+                final CountDownLatch latch = new CountDownLatch(1);
+
+                publisherRef.set((clusterChangedEvent, publishListener, ackListener) -> {
+                    publishListener.onResponse(null);
+                    ackListener.onCommit(TimeValue.ZERO);
+                    ackListener.onNodeAck(node1, null);
+                    ackListener.onNodeAck(node2, new ElasticsearchException("simulated"));
+                    ackListener.onNodeAck(node3, null);
+                });
+
+                class Task implements ClusterStateTaskListener {
+
+                    @Override
+                    public void onFailure(Exception e) {
+                        throw new AssertionError(e);
+                    }
+
+                    @Override
+                    public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
+                        fail();
+                    }
+                }
+
+                masterService.submitStateUpdateTask(
+                    "node-ack-fail-test",
+                    new Task(),
+                    ClusterStateTaskConfig.build(Priority.NORMAL),
+                    batchExecutionContext -> {
+                        for (final var taskContext : batchExecutionContext.taskContexts()) {
+                            final var responseHeaderValue = randomAlphaOfLength(10);
+                            try (var ignored = taskContext.captureResponseHeaders()) {
+                                threadPool.getThreadContext().addResponseHeader(responseHeaderName, responseHeaderValue);
+                            }
+                            taskContext.success(new LatchAckListener(latch) {
+                                @Override
+                                public void onAllNodesAcked() {
+                                    fail();
+                                }
+
+                                @Override
+                                public void onAckFailure(Exception e) {
+                                    assertThat(
+                                        threadPool.getThreadContext().getResponseHeaders().get(responseHeaderName),
+                                        equalTo(List.of(responseHeaderValue))
+                                    );
+                                    assertThat(e, instanceOf(ElasticsearchException.class));
+                                    assertThat(e.getMessage(), equalTo("simulated"));
+                                    latch.countDown();
+                                }
+                            });
+                        }
+                        return ClusterState.builder(batchExecutionContext.initialState()).build();
+                    }
+                );
+
+                assertTrue(latch.await(10, TimeUnit.SECONDS));
+            }
+
             // check that we don't time out before even committing the cluster state
             {
                 final CountDownLatch latch = new CountDownLatch(1);
@@ -1554,11 +1655,14 @@ public class MasterServiceTests extends ESTestCase {
                     ackListener.onNodeAck(node3, null);
                 });
 
+                final var responseHeaderValue = randomAlphaOfLength(10);
+
                 masterService.submitUnbatchedStateUpdateTask(
                     "test2",
                     new AckedClusterStateUpdateTask(ackedRequest(ackTimeout, null), null) {
                         @Override
                         public ClusterState execute(ClusterState currentState) {
+                            threadPool.getThreadContext().addResponseHeader(responseHeaderName, responseHeaderValue);
                             return ClusterState.builder(currentState).build();
                         }
 
@@ -1580,6 +1684,10 @@ public class MasterServiceTests extends ESTestCase {
 
                         @Override
                         public void onAckTimeout() {
+                            assertThat(
+                                threadPool.getThreadContext().getResponseHeaders().get(responseHeaderName),
+                                equalTo(List.of(responseHeaderValue))
+                            );
                             latch.countDown();
                         }
                     }

+ 17 - 2
server/src/test/java/org/elasticsearch/reservedstate/service/ReservedClusterStateServiceTests.java

@@ -23,6 +23,7 @@ import org.elasticsearch.cluster.routing.RerouteService;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.reservedstate.ReservedClusterStateHandler;
 import org.elasticsearch.reservedstate.TransformState;
 import org.elasticsearch.reservedstate.action.ReservedClusterSettingsAction;
@@ -177,9 +178,16 @@ public class ReservedClusterStateServiceTests extends ESTestCase {
 
             @Override
             public void onFailure(Exception failure) {}
+
+            @Override
+            public Releasable captureResponseHeaders() {
+                return null;
+            }
         };
 
-        ClusterState newState = taskExecutor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(taskContext)));
+        ClusterState newState = taskExecutor.execute(
+            new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(taskContext), () -> null)
+        );
         assertEquals(state, newState);
         assertTrue(successCalled.get());
         verify(task, times(1)).execute(any());
@@ -231,11 +239,18 @@ public class ReservedClusterStateServiceTests extends ESTestCase {
 
                 @Override
                 public void onFailure(Exception failure) {}
+
+                @Override
+                public Releasable captureResponseHeaders() {
+                    return null;
+                }
             };
 
         ReservedStateErrorTaskExecutor executor = new ReservedStateErrorTaskExecutor();
 
-        ClusterState newState = executor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(taskContext)));
+        ClusterState newState = executor.execute(
+            new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(taskContext), () -> null)
+        );
 
         verify(task, times(1)).execute(any());
 

+ 9 - 1
test/framework/src/main/java/org/elasticsearch/cluster/service/ClusterStateTaskExecutorUtils.java

@@ -14,6 +14,7 @@ import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterStateTaskListener;
 import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.core.Releasable;
 
 import java.util.function.Consumer;
 import java.util.stream.StreamSupport;
@@ -64,7 +65,9 @@ public class ClusterStateTaskExecutorUtils {
         final var taskContexts = StreamSupport.stream(tasks.spliterator(), false).<ClusterStateTaskExecutor.TaskContext<T>>map(
             TestTaskContext::new
         ).toList();
-        final var resultingState = executor.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(originalState, taskContexts));
+        final var resultingState = executor.execute(
+            new ClusterStateTaskExecutor.BatchExecutionContext<>(originalState, taskContexts, () -> null)
+        );
         assertNotNull(resultingState);
         for (final var taskContext : taskContexts) {
             final var testTaskContext = (TestTaskContext<T>) taskContext;
@@ -146,6 +149,11 @@ public class ClusterStateTaskExecutorUtils {
             this.succeeded = true;
         }
 
+        @Override
+        public Releasable captureResponseHeaders() {
+            return () -> {};
+        }
+
         @Override
         public String toString() {
             return "TestTaskContext[" + task + "]";

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartBasicClusterTask.java

@@ -133,7 +133,9 @@ public class StartBasicClusterTask implements ClusterStateTaskListener {
             final LicensesMetadata originalLicensesMetadata = initialState.metadata().custom(LicensesMetadata.TYPE);
             var currentLicensesMetadata = originalLicensesMetadata;
             for (final var taskContext : batchExecutionContext.taskContexts()) {
-                currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, initialState.nodes(), taskContext);
+                try (var ignored = taskContext.captureResponseHeaders()) {
+                    currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, initialState.nodes(), taskContext);
+                }
             }
             if (currentLicensesMetadata == originalLicensesMetadata) {
                 return initialState;

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartTrialClusterTask.java

@@ -120,7 +120,9 @@ public class StartTrialClusterTask implements ClusterStateTaskListener {
             final LicensesMetadata originalLicensesMetadata = initialState.metadata().custom(LicensesMetadata.TYPE);
             var currentLicensesMetadata = originalLicensesMetadata;
             for (final var taskContext : batchExecutionContext.taskContexts()) {
-                currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, initialState.nodes(), taskContext);
+                try (var ignored = taskContext.captureResponseHeaders()) {
+                    currentLicensesMetadata = taskContext.getTask().execute(currentLicensesMetadata, initialState.nodes(), taskContext);
+                }
             }
             if (currentLicensesMetadata == originalLicensesMetadata) {
                 return initialState;

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/license/LicenseServiceTests.java

@@ -212,7 +212,7 @@ public class LicenseServiceTests extends ESTestCase {
             );
 
             ClusterState updatedState = taskExecutorCaptor.getValue()
-                .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(oldState, List.of(taskContext)));
+                .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(oldState, List.of(taskContext), () -> null));
             // Pass updated state to listener to trigger onResponse call to wrapped `future`
             listenerCaptor.getValue().run();
             assertion.accept(future);

+ 3 - 1
x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/IndexLifecycleRunner.java

@@ -65,7 +65,9 @@ class IndexLifecycleRunner {
                 for (final var taskContext : batchExecutionContext.taskContexts()) {
                     try {
                         final var task = taskContext.getTask();
-                        state = task.execute(state);
+                        try (var ignored = taskContext.captureResponseHeaders()) {
+                            state = task.execute(state);
+                        }
                         taskContext.success(
                             new ClusterStateTaskExecutor.LegacyClusterTaskResultActionListener(task, batchExecutionContext.initialState())
                         );

+ 7 - 1
x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/ReservedLifecycleStateServiceTests.java

@@ -17,6 +17,7 @@ import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.reservedstate.TransformState;
@@ -266,9 +267,14 @@ public class ReservedLifecycleStateServiceTests extends ESTestCase {
                 public void onFailure(Exception failure) {
                     fail("Shouldn't fail here");
                 }
+
+                @Override
+                public Releasable captureResponseHeaders() {
+                    return null;
+                }
             };
 
-            task.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(context)));
+            task.execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(state, List.of(context), () -> null));
 
             return null;
         }).when(clusterService).submitStateUpdateTask(anyString(), any(), any(), any());

+ 3 - 1
x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/v2/TransportRollupAction.java

@@ -96,7 +96,9 @@ public class TransportRollupAction extends AcknowledgedTransportMasterNodeAction
         for (final var taskContext : batchExecutionContext.taskContexts()) {
             try {
                 final var task = taskContext.getTask();
-                state = task.execute(state);
+                try (var ignored = taskContext.captureResponseHeaders()) {
+                    state = task.execute(state);
+                }
                 taskContext.success(() -> task.listener.onResponse(AcknowledgedResponse.TRUE));
             } catch (Exception e) {
                 taskContext.onFailure(e);

+ 1 - 1
x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeAction.java

@@ -84,7 +84,7 @@ public class TransportDeleteShutdownNodeAction extends AcknowledgedTransportMast
             boolean changed = false;
             for (final var taskContext : batchExecutionContext.taskContexts()) {
                 var request = taskContext.getTask().request();
-                try {
+                try (var ignored = taskContext.captureResponseHeaders()) {
                     changed |= deleteShutdownNodeState(shutdownMetadata, request);
                 } catch (Exception e) {
                     taskContext.onFailure(e);

+ 1 - 1
x-pack/plugin/shutdown/src/main/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeAction.java

@@ -123,7 +123,7 @@ public class TransportPutShutdownNodeAction extends AcknowledgedTransportMasterN
             boolean changed = false;
             for (final var taskContext : batchExecutionContext.taskContexts()) {
                 var request = taskContext.getTask().request();
-                try {
+                try (var ignored = taskContext.captureResponseHeaders()) {
                     changed |= putShutdownNodeState(shutdownMetadata, nodeExistsPredicate, request);
                 } catch (Exception e) {
                     taskContext.onFailure(e);

+ 1 - 1
x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportDeleteShutdownNodeActionTests.java

@@ -78,7 +78,7 @@ public class TransportDeleteShutdownNodeActionTests extends ESTestCase {
         verify(clusterService).submitStateUpdateTask(any(), updateTask.capture(), taskConfig.capture(), taskExecutor.capture());
         when(taskContext.getTask()).thenReturn(updateTask.getValue());
         ClusterState gotState = taskExecutor.getValue()
-            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(ClusterState.EMPTY_STATE, List.of(taskContext)));
+            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(ClusterState.EMPTY_STATE, List.of(taskContext), () -> null));
         assertThat(gotState, sameInstance(ClusterState.EMPTY_STATE));
     }
 }

+ 2 - 2
x-pack/plugin/shutdown/src/test/java/org/elasticsearch/xpack/shutdown/TransportPutShutdownNodeActionTests.java

@@ -76,7 +76,7 @@ public class TransportPutShutdownNodeActionTests extends ESTestCase {
         verify(clusterService).submitStateUpdateTask(any(), updateTask.capture(), taskConfig.capture(), taskExecutor.capture());
         when(taskContext.getTask()).thenReturn(updateTask.getValue());
         ClusterState stableState = taskExecutor.getValue()
-            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(ClusterState.EMPTY_STATE, List.of(taskContext)));
+            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(ClusterState.EMPTY_STATE, List.of(taskContext), () -> null));
 
         // run the request again, there should be no call to submit an update task
         clearInvocations(clusterService);
@@ -88,7 +88,7 @@ public class TransportPutShutdownNodeActionTests extends ESTestCase {
         verify(clusterService).submitStateUpdateTask(any(), updateTask.capture(), taskConfig.capture(), taskExecutor.capture());
         when(taskContext.getTask()).thenReturn(updateTask.getValue());
         ClusterState gotState = taskExecutor.getValue()
-            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(stableState, List.of(taskContext)));
+            .execute(new ClusterStateTaskExecutor.BatchExecutionContext<>(stableState, List.of(taskContext), () -> null));
         assertThat(gotState, sameInstance(stableState));
     }
 }