Browse Source

Parameterize the `Security` plugin with the request-wise thread context supplier (#95040)

When `Security` authenticates a request (AuthenticationService#authenticate) it mostly
looks in the thread context for all the request details that it needs.
Consequently, there's code that populates the thread context before invoking authentication.
This PR brings all the code that's responsible for populating the said thread context,
as a parameter to the `Security` plugin.
This is required in the follow-up that is going to involve setting up the thread context
(given the aforementioned parameter), and running the authentication under that
context, before the request is dispatched (rather than "while" it's being dispatched, currently).
Albert Zaharovits 2 years ago
parent
commit
f662a198cf
19 changed files with 349 additions and 237 deletions
  1. 11 1
      modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Plugin.java
  2. 37 10
      server/src/main/java/org/elasticsearch/action/ActionModule.java
  3. 5 0
      server/src/main/java/org/elasticsearch/common/network/NetworkModule.java
  4. 17 0
      server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java
  5. 1 0
      server/src/main/java/org/elasticsearch/node/Node.java
  6. 3 0
      server/src/main/java/org/elasticsearch/plugins/NetworkPlugin.java
  7. 0 30
      server/src/main/java/org/elasticsearch/rest/RestController.java
  8. 6 0
      server/src/test/java/org/elasticsearch/common/network/NetworkModuleTests.java
  9. 210 0
      server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java
  10. 23 131
      server/src/test/java/org/elasticsearch/rest/RestControllerTests.java
  11. 1 9
      server/src/test/java/org/elasticsearch/rest/RestHttpResponseHeadersTests.java
  12. 1 10
      server/src/test/java/org/elasticsearch/rest/action/admin/indices/RestValidateQueryActionTests.java
  13. 1 9
      test/framework/src/main/java/org/elasticsearch/test/rest/RestActionTestCase.java
  14. 3 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java
  15. 0 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/termsenum/action/RestTermsEnumActionTests.java
  16. 23 11
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
  17. 1 12
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java
  18. 4 8
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java
  19. 2 4
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterWarningHeadersTests.java

+ 11 - 1
modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Plugin.java

@@ -18,11 +18,14 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.plugins.NetworkPlugin;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
 import org.elasticsearch.transport.Transport;
@@ -32,6 +35,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 public class Netty4Plugin extends Plugin implements NetworkPlugin {
@@ -99,6 +103,7 @@ public class Netty4Plugin extends Plugin implements NetworkPlugin {
         NamedXContentRegistry xContentRegistry,
         NetworkService networkService,
         HttpServerTransport.Dispatcher dispatcher,
+        BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
         ClusterSettings clusterSettings,
         Tracer tracer
     ) {
@@ -116,7 +121,12 @@ public class Netty4Plugin extends Plugin implements NetworkPlugin {
                 TLSConfig.noTLS(),
                 null,
                 null
-            )
+            ) {
+                @Override
+                protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
+                    perRequestThreadContext.accept(restRequest.getHttpRequest(), threadContext);
+                }
+            }
         );
     }
 

+ 37 - 10
server/src/main/java/org/elasticsearch/action/ActionModule.java

@@ -267,6 +267,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.IndexScopedSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsFilter;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.gateway.TransportNodesListGatewayStartedShards;
 import org.elasticsearch.health.GetHealthAction;
 import org.elasticsearch.health.RestGetHealthAction;
@@ -274,6 +275,7 @@ import org.elasticsearch.health.node.FetchHealthInfoCacheAction;
 import org.elasticsearch.health.node.UpdateHealthInfoCacheAction;
 import org.elasticsearch.health.stats.HealthApiStatsAction;
 import org.elasticsearch.health.stats.HealthApiStatsTransportAction;
+import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.index.seqno.GlobalCheckpointSyncAction;
 import org.elasticsearch.index.seqno.RetentionLeaseActions;
 import org.elasticsearch.indices.SystemIndices;
@@ -291,6 +293,7 @@ import org.elasticsearch.reservedstate.service.ReservedClusterStateService;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.rest.RestHeaderDefinition;
+import org.elasticsearch.rest.RestUtils;
 import org.elasticsearch.rest.action.RestFieldCapabilitiesAction;
 import org.elasticsearch.rest.action.RestMainAction;
 import org.elasticsearch.rest.action.admin.cluster.RestAddVotingConfigExclusionAction;
@@ -435,6 +438,7 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
@@ -468,6 +472,8 @@ public class ActionModule extends AbstractModule {
     private final AutoCreateIndex autoCreateIndex;
     private final DestructiveOperations destructiveOperations;
     private final RestController restController;
+    /** Rest headers that are copied to internal requests made during a rest request. */
+    private final Set<RestHeaderDefinition> headersToCopy;
     private final RequestValidators<PutMappingRequest> mappingRequestValidators;
     private final RequestValidators<IndicesAliasesRequest> indicesAliasesRequestRequestValidators;
     private final ThreadPool threadPool;
@@ -539,19 +545,40 @@ public class ActionModule extends AbstractModule {
         indicesAliasesRequestRequestValidators = new RequestValidators<>(
             actionPlugins.stream().flatMap(p -> p.indicesAliasesRequestValidators().stream()).toList()
         );
-
-        restController = new RestController(
-            headers,
-            restInterceptor,
-            nodeClient,
-            circuitBreakerService,
-            usageService,
-            tracer,
-            serverlessEnabled
-        );
+        headersToCopy = headers;
+        restController = new RestController(restInterceptor, nodeClient, circuitBreakerService, usageService, tracer, serverlessEnabled);
         reservedClusterStateService = new ReservedClusterStateService(clusterService, reservedStateHandlers);
     }
 
+    /**
+     * Certain request header values need to be copied in the thread context under which request handlers are to be dispatched.
+     * Careful that this method modifies the thread context. The thread context must be reinstated after the request handler
+     * finishes and returns.
+     */
+    public void copyRequestHeadersToThreadContext(HttpPreRequest request, ThreadContext threadContext) {
+        for (final RestHeaderDefinition restHeader : headersToCopy) {
+            final String name = restHeader.getName();
+            final List<String> headerValues = request.getHeaders().get(name);
+            if (headerValues != null && headerValues.isEmpty() == false) {
+                final List<String> distinctHeaderValues = headerValues.stream().distinct().toList();
+                if (restHeader.isMultiValueAllowed() == false && distinctHeaderValues.size() > 1) {
+                    throw new IllegalArgumentException("multiple values for single-valued header [" + name + "].");
+                } else if (name.equals(Task.TRACE_PARENT_HTTP_HEADER)) {
+                    String traceparent = distinctHeaderValues.get(0);
+                    Optional<String> traceId = RestUtils.extractTraceId(traceparent);
+                    if (traceId.isPresent()) {
+                        threadContext.putHeader(Task.TRACE_ID, traceId.get());
+                        threadContext.putTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER, traceparent);
+                    }
+                } else if (name.equals(Task.TRACE_STATE)) {
+                    threadContext.putTransient("parent_" + Task.TRACE_STATE, distinctHeaderValues.get(0));
+                } else {
+                    threadContext.putHeader(name, String.join(",", distinctHeaderValues));
+                }
+            }
+        }
+    }
+
     public Map<String, ActionHandler<?, ?>> getActions() {
         return actions;
     }

+ 5 - 0
server/src/main/java/org/elasticsearch/common/network/NetworkModule.java

@@ -23,7 +23,9 @@ import org.elasticsearch.common.settings.Setting.Property;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.CheckedFunction;
+import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.index.shard.PrimaryReplicaSyncer.ResyncTask;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -47,6 +49,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 /**
@@ -122,6 +125,7 @@ public final class NetworkModule {
         NamedXContentRegistry xContentRegistry,
         NetworkService networkService,
         HttpServerTransport.Dispatcher dispatcher,
+        BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
         ClusterSettings clusterSettings,
         Tracer tracer
     ) {
@@ -136,6 +140,7 @@ public final class NetworkModule {
                 xContentRegistry,
                 networkService,
                 dispatcher,
+                perRequestThreadContext,
                 clusterSettings,
                 tracer
             );

+ 17 - 0
server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java

@@ -34,6 +34,7 @@ import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.rest.RestChannel;
 import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.RestResponse;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
@@ -376,11 +377,27 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
             if (badRequestCause != null) {
                 dispatcher.dispatchBadRequest(channel, threadContext, badRequestCause);
             } else {
+                populatePerRequestThreadContext0(restRequest, channel, threadContext);
                 dispatcher.dispatchRequest(restRequest, channel, threadContext);
             }
         }
     }
 
+    private void populatePerRequestThreadContext0(RestRequest restRequest, RestChannel channel, ThreadContext threadContext) {
+        try {
+            populatePerRequestThreadContext(restRequest, threadContext);
+        } catch (Exception e) {
+            try {
+                channel.sendResponse(new RestResponse(channel, e));
+            } catch (Exception inner) {
+                inner.addSuppressed(e);
+                logger.error(() -> "failed to send failure response for uri [" + restRequest.uri() + "]", inner);
+            }
+        }
+    }
+
+    protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {}
+
     private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
         if (exception == null) {
             HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest);

+ 1 - 0
server/src/main/java/org/elasticsearch/node/Node.java

@@ -790,6 +790,7 @@ public class Node implements Closeable {
                 xContentRegistry,
                 networkService,
                 restController,
+                actionModule::copyRequestHeadersToThreadContext,
                 clusterService.getClusterSettings(),
                 tracer
             );

+ 3 - 0
server/src/main/java/org/elasticsearch/plugins/NetworkPlugin.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -25,6 +26,7 @@ import org.elasticsearch.xcontent.NamedXContentRegistry;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 /**
@@ -76,6 +78,7 @@ public interface NetworkPlugin {
         NamedXContentRegistry xContentRegistry,
         NetworkService networkService,
         HttpServerTransport.Dispatcher dispatcher,
+        BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
         ClusterSettings clusterSettings,
         Tracer tracer
     ) {

+ 0 - 30
server/src/main/java/org/elasticsearch/rest/RestController.java

@@ -44,7 +44,6 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;
@@ -93,15 +92,12 @@ public class RestController implements HttpServerTransport.Dispatcher {
 
     private final CircuitBreakerService circuitBreakerService;
 
-    /** Rest headers that are copied to internal requests made during a rest request. */
-    private final Set<RestHeaderDefinition> headersToCopy;
     private final UsageService usageService;
     private final Tracer tracer;
     // If true, the ServerlessScope annotations will be enforced
     private final boolean serverlessEnabled;
 
     public RestController(
-        Set<RestHeaderDefinition> headersToCopy,
         UnaryOperator<RestHandler> handlerWrapper,
         NodeClient client,
         CircuitBreakerService circuitBreakerService,
@@ -109,7 +105,6 @@ public class RestController implements HttpServerTransport.Dispatcher {
         Tracer tracer,
         boolean serverlessEnabled
     ) {
-        this.headersToCopy = headersToCopy;
         this.usageService = usageService;
         this.tracer = tracer;
         if (handlerWrapper == null) {
@@ -510,7 +505,6 @@ public class RestController implements HttpServerTransport.Dispatcher {
 
     private void tryAllHandlers(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) throws Exception {
         try {
-            copyRestHeaders(request, threadContext);
             validateErrorTrace(request, channel);
         } catch (IllegalArgumentException e) {
             startTrace(threadContext, channel);
@@ -565,30 +559,6 @@ public class RestController implements HttpServerTransport.Dispatcher {
         }
     }
 
-    private void copyRestHeaders(RestRequest request, ThreadContext threadContext) {
-        for (final RestHeaderDefinition restHeader : headersToCopy) {
-            final String name = restHeader.getName();
-            final List<String> headerValues = request.getAllHeaderValues(name);
-            if (headerValues != null && headerValues.isEmpty() == false) {
-                final List<String> distinctHeaderValues = headerValues.stream().distinct().toList();
-                if (restHeader.isMultiValueAllowed() == false && distinctHeaderValues.size() > 1) {
-                    throw new IllegalArgumentException("multiple values for single-valued header [" + name + "].");
-                } else if (name.equals(Task.TRACE_PARENT_HTTP_HEADER)) {
-                    String traceparent = distinctHeaderValues.get(0);
-                    Optional<String> traceId = RestUtils.extractTraceId(traceparent);
-                    if (traceId.isPresent()) {
-                        threadContext.putHeader(Task.TRACE_ID, traceId.get());
-                        threadContext.putTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER, traceparent);
-                    }
-                } else if (name.equals(Task.TRACE_STATE)) {
-                    threadContext.putTransient("parent_" + Task.TRACE_STATE, distinctHeaderValues.get(0));
-                } else {
-                    threadContext.putHeader(name, String.join(",", distinctHeaderValues));
-                }
-            }
-        }
-    }
-
     Iterator<MethodHandlers> getAllHandlers(@Nullable Map<String, String> requestParamsRef, String rawPath) {
         final Supplier<Map<String, String>> paramsSupplier;
         if (requestParamsRef == null) {

+ 6 - 0
server/src/test/java/org/elasticsearch/common/network/NetworkModuleTests.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.http.HttpInfo;
+import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.HttpStats;
 import org.elasticsearch.http.NullDispatcher;
@@ -39,6 +40,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 public class NetworkModuleTests extends ESTestCase {
@@ -120,6 +122,7 @@ public class NetworkModuleTests extends ESTestCase {
                 NamedXContentRegistry xContentRegistry,
                 NetworkService networkService,
                 HttpServerTransport.Dispatcher requestDispatcher,
+                BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
                 ClusterSettings clusterSettings,
                 Tracer tracer
             ) {
@@ -166,6 +169,7 @@ public class NetworkModuleTests extends ESTestCase {
                 NamedXContentRegistry xContentRegistry,
                 NetworkService networkService,
                 HttpServerTransport.Dispatcher requestDispatcher,
+                BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
                 ClusterSettings clusterSettings,
                 Tracer tracer
             ) {
@@ -210,6 +214,7 @@ public class NetworkModuleTests extends ESTestCase {
                 NamedXContentRegistry xContentRegistry,
                 NetworkService networkService,
                 HttpServerTransport.Dispatcher requestDispatcher,
+                BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
                 ClusterSettings clusterSettings,
                 Tracer tracer
             ) {
@@ -293,6 +298,7 @@ public class NetworkModuleTests extends ESTestCase {
             xContentRegistry(),
             null,
             new NullDispatcher(),
+            (preRequest, threadContext) -> {},
             new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
             Tracer.NOOP
         );

+ 210 - 0
server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java

@@ -12,6 +12,8 @@ import org.apache.logging.log4j.Level;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.action.ActionModule;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.logging.Loggers;
@@ -21,11 +23,16 @@ import org.elasticsearch.common.network.NetworkUtils;
 import org.elasticsearch.common.recycler.Recycler;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.settings.SettingsModule;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.indices.TestIndexNameExpressionResolver;
+import org.elasticsearch.plugins.ActionPlugin;
 import org.elasticsearch.rest.RestChannel;
+import org.elasticsearch.rest.RestControllerTests;
+import org.elasticsearch.rest.RestHeaderDefinition;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.RestResponse;
 import org.elasticsearch.rest.RestStatus;
@@ -39,6 +46,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
 import org.elasticsearch.transport.BytesRefRecycler;
 import org.elasticsearch.transport.TransportSettings;
+import org.elasticsearch.usage.UsageService;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.junit.After;
 import org.junit.Assert;
@@ -47,8 +55,11 @@ import org.junit.Before;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -61,6 +72,8 @@ import static org.elasticsearch.http.AbstractHttpServerTransport.resolvePublishP
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Mockito.mock;
 
 public class AbstractHttpServerTransportTests extends ESTestCase {
 
@@ -198,6 +211,177 @@ public class AbstractHttpServerTransportTests extends ESTestCase {
         }
     }
 
+    public void testRequestHeadersPopulateThreadContext() {
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                // specified request headers value are copied into the thread context
+                assertEquals("true", threadContext.getHeader("header.1"));
+                assertEquals("true", threadContext.getHeader("header.2"));
+                // but unknown headers are not copied at all
+                assertNull(threadContext.getHeader("header.3"));
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
+                // no request headers are copied in to the context of malformed requests
+                assertNull(threadContext.getHeader("header.1"));
+                assertNull(threadContext.getHeader("header.2"));
+                assertNull(threadContext.getHeader("header.3"));
+            }
+
+        };
+        // the set of headers to copy
+        final Set<RestHeaderDefinition> headers = new HashSet<>(
+            Arrays.asList(new RestHeaderDefinition("header.1", true), new RestHeaderDefinition("header.2", true))
+        );
+        // sample request headers to test with
+        final Map<String, List<String>> restHeaders = new HashMap<>();
+        restHeaders.put("header.1", Collections.singletonList("true"));
+        restHeaders.put("header.2", Collections.singletonList("true"));
+        restHeaders.put("header.3", Collections.singletonList("true"));
+        final RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build();
+        final RestControllerTests.AssertingChannel channel = new RestControllerTests.AssertingChannel(
+            fakeRequest,
+            false,
+            RestStatus.BAD_REQUEST
+        );
+
+        try (
+            AbstractHttpServerTransport transport = new AbstractHttpServerTransport(
+                Settings.EMPTY,
+                networkService,
+                recycler,
+                threadPool,
+                xContentRegistry(),
+                dispatcher,
+                new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
+                Tracer.NOOP
+            ) {
+
+                @Override
+                protected HttpServerChannel bind(InetSocketAddress hostAddress) {
+                    return null;
+                }
+
+                @Override
+                protected void doStart() {
+
+                }
+
+                @Override
+                protected void stopInternal() {
+
+                }
+
+                @Override
+                public HttpStats stats() {
+                    return null;
+                }
+
+                @Override
+                protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
+                    getFakeActionModule(headers).copyRequestHeadersToThreadContext(restRequest.getHttpRequest(), threadContext);
+                }
+            }
+        ) {
+            transport.dispatchRequest(fakeRequest, channel, null);
+            // headers are "null" here, aka not present, because the thread context changes containing them is to be confined to the request
+            assertNull(threadPool.getThreadContext().getHeader("header.1"));
+            assertNull(threadPool.getThreadContext().getHeader("header.2"));
+            assertNull(threadPool.getThreadContext().getHeader("header.3"));
+            transport.dispatchRequest(null, null, new Exception());
+            // headers are "null" here, aka not present, because the thread context changes containing them is to be confined to the request
+            assertNull(threadPool.getThreadContext().getHeader("header.1"));
+            assertNull(threadPool.getThreadContext().getHeader("header.2"));
+            assertNull(threadPool.getThreadContext().getHeader("header.3"));
+        }
+    }
+
+    /**
+     * Check that the REST controller picks up and propagates W3C trace context headers via the {@link ThreadContext}.
+     * @see <a href="https://www.w3.org/TR/trace-context/">Trace Context - W3C Recommendation</a>
+     */
+    public void testTraceParentAndTraceId() {
+        final String traceParentValue = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                assertThat(threadContext.getHeader(Task.TRACE_ID), equalTo("0af7651916cd43dd8448eb211c80319c"));
+                assertThat(threadContext.getHeader(Task.TRACE_PARENT_HTTP_HEADER), nullValue());
+                assertThat(threadContext.getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER), equalTo(traceParentValue));
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
+                // but they're not copied in for bad requests
+                assertThat(threadContext.getHeader(Task.TRACE_ID), nullValue());
+                assertThat(threadContext.getHeader(Task.TRACE_PARENT_HTTP_HEADER), nullValue());
+                assertThat(threadContext.getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER), nullValue());
+            }
+
+        };
+        // the set of headers to copy
+        Set<RestHeaderDefinition> headers = Set.of(new RestHeaderDefinition(Task.TRACE_PARENT_HTTP_HEADER, false));
+        // sample request headers to test with
+        Map<String, List<String>> restHeaders = new HashMap<>();
+        restHeaders.put(Task.TRACE_PARENT_HTTP_HEADER, Collections.singletonList(traceParentValue));
+        RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build();
+        RestControllerTests.AssertingChannel channel = new RestControllerTests.AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST);
+
+        try (
+            AbstractHttpServerTransport transport = new AbstractHttpServerTransport(
+                Settings.EMPTY,
+                networkService,
+                recycler,
+                threadPool,
+                xContentRegistry(),
+                dispatcher,
+                new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
+                Tracer.NOOP
+            ) {
+
+                @Override
+                protected HttpServerChannel bind(InetSocketAddress hostAddress) {
+                    return null;
+                }
+
+                @Override
+                protected void doStart() {
+
+                }
+
+                @Override
+                protected void stopInternal() {
+
+                }
+
+                @Override
+                public HttpStats stats() {
+                    return null;
+                }
+
+                @Override
+                protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
+                    getFakeActionModule(headers).copyRequestHeadersToThreadContext(restRequest.getHttpRequest(), threadContext);
+                }
+            }
+        ) {
+            transport.dispatchRequest(fakeRequest, channel, null);
+            // headers are "null" here, aka not present, because the thread context changes containing them is to be confined to the request
+            assertThat(threadPool.getThreadContext().getHeader(Task.TRACE_ID), nullValue());
+            assertThat(threadPool.getThreadContext().getHeader(Task.TRACE_PARENT_HTTP_HEADER), nullValue());
+            assertThat(threadPool.getThreadContext().getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER), nullValue());
+            transport.dispatchRequest(null, null, new Exception());
+            // headers are "null" here, aka not present, because the thread context changes containing them is to be confined to the request
+            assertThat(threadPool.getThreadContext().getHeader(Task.TRACE_ID), nullValue());
+            assertThat(threadPool.getThreadContext().getHeader(Task.TRACE_PARENT_HTTP_HEADER), nullValue());
+            assertThat(threadPool.getThreadContext().getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER), nullValue());
+        }
+    }
+
     public void testHandlingCompatibleVersionParsingErrors() {
         // a compatible version exception (v7 on accept and v8 on content-type) should be handled gracefully
         final ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
@@ -711,4 +895,30 @@ public class AbstractHttpServerTransportTests extends ESTestCase {
         }
         return addresses;
     }
+
+    private ActionModule getFakeActionModule(Set<RestHeaderDefinition> headersToCopy) {
+        SettingsModule settings = new SettingsModule(Settings.EMPTY);
+        ActionPlugin copyHeadersPlugin = new ActionPlugin() {
+            @Override
+            public Collection<RestHeaderDefinition> getRestHeaders() {
+                return headersToCopy;
+            }
+        };
+        return new ActionModule(
+            settings.getSettings(),
+            TestIndexNameExpressionResolver.newInstance(threadPool.getThreadContext()),
+            settings.getIndexScopedSettings(),
+            settings.getClusterSettings(),
+            settings.getSettingsFilter(),
+            threadPool,
+            List.of(copyHeadersPlugin),
+            null,
+            null,
+            new UsageService(),
+            null,
+            null,
+            mock(ClusterService.class),
+            List.of()
+        );
+    }
 }

+ 23 - 131
server/src/test/java/org/elasticsearch/rest/RestControllerTests.java

@@ -32,7 +32,6 @@ import org.elasticsearch.http.HttpStats;
 import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService;
 import org.elasticsearch.rest.RestHandler.Route;
 import org.elasticsearch.rest.action.RestToXContentListener;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.client.NoOpNodeClient;
 import org.elasticsearch.test.rest.FakeRestRequest;
@@ -60,13 +59,14 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.rest.RestController.ELASTIC_INTERNAL_ORIGIN_HTTP_HEADER;
+import static org.elasticsearch.rest.RestController.ELASTIC_PRODUCT_HTTP_HEADER;
+import static org.elasticsearch.rest.RestController.ELASTIC_PRODUCT_HTTP_HEADER_VALUE;
 import static org.elasticsearch.rest.RestRequest.Method.GET;
 import static org.elasticsearch.rest.RestRequest.Method.OPTIONS;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyMap;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -105,7 +105,7 @@ public class RestControllerTests extends ESTestCase {
         HttpServerTransport httpServerTransport = new TestHttpServerTransport();
         client = new NoOpNodeClient(this.getTestName());
         tracer = mock(Tracer.class);
-        restController = new RestController(Collections.emptySet(), null, client, circuitBreakerService, usageService, tracer, false);
+        restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
         restController.registerHandler(
             new Route(GET, "/"),
             (request, channel, client) -> channel.sendResponse(
@@ -124,30 +124,17 @@ public class RestControllerTests extends ESTestCase {
         IOUtils.close(client);
     }
 
-    public void testApplyRelevantHeaders() throws Exception {
+    public void testApplyProductSpecificResponseHeaders() {
         final ThreadContext threadContext = client.threadPool().getThreadContext();
-        Set<RestHeaderDefinition> headers = new HashSet<>(
-            Arrays.asList(new RestHeaderDefinition("header.1", true), new RestHeaderDefinition("header.2", true))
-        );
-        final RestController restController = new RestController(headers, null, null, circuitBreakerService, usageService, tracer, false);
-        Map<String, List<String>> restHeaders = new HashMap<>();
-        restHeaders.put("header.1", Collections.singletonList("true"));
-        restHeaders.put("header.2", Collections.singletonList("true"));
-        restHeaders.put("header.3", Collections.singletonList("false"));
-        RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build();
+        final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
+        RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).build();
         AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST);
         restController.dispatchRequest(fakeRequest, channel, threadContext);
         // the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the
         // context in this test
-        assertEquals("true", threadContext.getHeader("header.1"));
-        assertEquals("true", threadContext.getHeader("header.2"));
-        assertNull(threadContext.getHeader("header.3"));
         List<String> expectedProductResponseHeader = new ArrayList<>();
-        expectedProductResponseHeader.add(RestController.ELASTIC_PRODUCT_HTTP_HEADER_VALUE);
-        assertEquals(
-            expectedProductResponseHeader,
-            threadContext.getResponseHeaders().getOrDefault(RestController.ELASTIC_PRODUCT_HTTP_HEADER, null)
-        );
+        expectedProductResponseHeader.add(ELASTIC_PRODUCT_HTTP_HEADER_VALUE);
+        assertEquals(expectedProductResponseHeader, threadContext.getResponseHeaders().getOrDefault(ELASTIC_PRODUCT_HTTP_HEADER, null));
     }
 
     public void testRequestWithDisallowedMultiValuedHeader() {
@@ -155,7 +142,7 @@ public class RestControllerTests extends ESTestCase {
         Set<RestHeaderDefinition> headers = new HashSet<>(
             Arrays.asList(new RestHeaderDefinition("header.1", true), new RestHeaderDefinition("header.2", false))
         );
-        final RestController restController = new RestController(headers, null, null, circuitBreakerService, usageService, tracer, false);
+        final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
         Map<String, List<String>> restHeaders = new HashMap<>();
         restHeaders.put("header.1", Collections.singletonList("boo"));
         restHeaders.put("header.2", List.of("foo", "bar"));
@@ -170,7 +157,7 @@ public class RestControllerTests extends ESTestCase {
      */
     public void testDispatchStartsTrace() {
         final ThreadContext threadContext = client.threadPool().getThreadContext();
-        final RestController restController = new RestController(Set.of(), null, null, circuitBreakerService, usageService, tracer, false);
+        final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
         RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).build();
         final RestController spyRestController = spy(restController);
         when(spyRestController.getAllHandlers(null, fakeRequest.rawPath())).thenReturn(new Iterator<>() {
@@ -194,35 +181,12 @@ public class RestControllerTests extends ESTestCase {
         );
     }
 
-    /**
-     * Check that the REST controller picks up and propagates W3C trace context headers via the {@link ThreadContext}.
-     * @see <a href="https://www.w3.org/TR/trace-context/">Trace Context - W3C Recommendation</a>
-     */
-    public void testTraceParentAndTraceId() {
-        final ThreadContext threadContext = client.threadPool().getThreadContext();
-        Set<RestHeaderDefinition> headers = Set.of(new RestHeaderDefinition(Task.TRACE_PARENT_HTTP_HEADER, false));
-        final RestController restController = new RestController(headers, null, null, circuitBreakerService, usageService, tracer, false);
-        Map<String, List<String>> restHeaders = new HashMap<>();
-        final String traceParentValue = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
-        restHeaders.put(Task.TRACE_PARENT_HTTP_HEADER, Collections.singletonList(traceParentValue));
-        RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build();
-        AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST);
-
-        restController.dispatchRequest(fakeRequest, channel, threadContext);
-
-        // the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the
-        // context in this test
-        assertThat(threadContext.getHeader(Task.TRACE_ID), equalTo("0af7651916cd43dd8448eb211c80319c"));
-        assertThat(threadContext.getHeader(Task.TRACE_PARENT_HTTP_HEADER), nullValue());
-        assertThat(threadContext.getTransient("parent_" + Task.TRACE_PARENT_HTTP_HEADER), equalTo(traceParentValue));
-    }
-
     public void testRequestWithDisallowedMultiValuedHeaderButSameValues() {
         final ThreadContext threadContext = client.threadPool().getThreadContext();
         Set<RestHeaderDefinition> headers = new HashSet<>(
             Arrays.asList(new RestHeaderDefinition("header.1", true), new RestHeaderDefinition("header.2", false))
         );
-        final RestController restController = new RestController(headers, null, client, circuitBreakerService, usageService, tracer, false);
+        final RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
         Map<String, List<String>> restHeaders = new HashMap<>();
         restHeaders.put("header.1", Collections.singletonList("boo"));
         restHeaders.put("header.2", List.of("foo", "foo"));
@@ -293,7 +257,7 @@ public class RestControllerTests extends ESTestCase {
     }
 
     public void testRegisterSecondMethodWithDifferentNamedWildcard() {
-        final RestController restController = new RestController(null, null, null, circuitBreakerService, usageService, tracer, false);
+        final RestController restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
 
         RestRequest.Method firstMethod = randomFrom(RestRequest.Method.values());
         RestRequest.Method secondMethod = randomFrom(Arrays.stream(RestRequest.Method.values()).filter(m -> m != firstMethod).toList());
@@ -317,7 +281,7 @@ public class RestControllerTests extends ESTestCase {
         AtomicBoolean wrapperCalled = new AtomicBoolean(false);
         final RestHandler handler = (RestRequest request, RestChannel channel, NodeClient client) -> handlerCalled.set(true);
         final HttpServerTransport httpServerTransport = new TestHttpServerTransport();
-        final RestController restController = new RestController(Collections.emptySet(), h -> {
+        final RestController restController = new RestController(h -> {
             assertSame(handler, h);
             return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true);
         }, client, circuitBreakerService, usageService, tracer, false);
@@ -407,7 +371,7 @@ public class RestControllerTests extends ESTestCase {
         String content = randomAlphaOfLength((int) Math.round(BREAKER_LIMIT.getBytes() / inFlightRequestsBreaker.getOverhead()));
         RestRequest request = testRestRequest("/", content, null);
         AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE);
-        restController = new RestController(Collections.emptySet(), null, null, circuitBreakerService, usageService, tracer, false);
+        restController = new RestController(null, null, circuitBreakerService, usageService, tracer, false);
         restController.registerHandler(
             new Route(GET, "/"),
             (r, c, client) -> c.sendResponse(new RestResponse(RestStatus.OK, RestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY))
@@ -774,15 +738,7 @@ public class RestControllerTests extends ESTestCase {
 
     public void testDispatchCompatibleHandler() {
 
-        RestController restController = new RestController(
-            Collections.emptySet(),
-            null,
-            client,
-            circuitBreakerService,
-            usageService,
-            tracer,
-            false
-        );
+        RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
 
         final RestApiVersion version = RestApiVersion.minimumSupported();
 
@@ -806,15 +762,7 @@ public class RestControllerTests extends ESTestCase {
 
     public void testDispatchCompatibleRequestToNewlyAddedHandler() {
 
-        RestController restController = new RestController(
-            Collections.emptySet(),
-            null,
-            client,
-            circuitBreakerService,
-            usageService,
-            tracer,
-            false
-        );
+        RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
 
         final RestApiVersion version = RestApiVersion.minimumSupported();
 
@@ -849,15 +797,7 @@ public class RestControllerTests extends ESTestCase {
     }
 
     public void testCurrentVersionVNDMediaTypeIsNotUsingCompatibility() {
-        RestController restController = new RestController(
-            Collections.emptySet(),
-            null,
-            client,
-            circuitBreakerService,
-            usageService,
-            tracer,
-            false
-        );
+        RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
 
         final RestApiVersion version = RestApiVersion.current();
 
@@ -882,15 +822,7 @@ public class RestControllerTests extends ESTestCase {
     }
 
     public void testCustomMediaTypeValidation() {
-        RestController restController = new RestController(
-            Collections.emptySet(),
-            null,
-            client,
-            circuitBreakerService,
-            usageService,
-            tracer,
-            false
-        );
+        RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
 
         final String mediaType = "application/x-protobuf";
         FakeRestRequest fakeRestRequest = requestWithContent(mediaType);
@@ -916,15 +848,7 @@ public class RestControllerTests extends ESTestCase {
     }
 
     public void testBrowserSafelistedContentTypesAreRejected() {
-        RestController restController = new RestController(
-            Collections.emptySet(),
-            null,
-            client,
-            circuitBreakerService,
-            usageService,
-            tracer,
-            false
-        );
+        RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
 
         final String mediaType = randomFrom(RestController.SAFELISTED_MEDIA_TYPES);
         FakeRestRequest fakeRestRequest = requestWithContent(mediaType);
@@ -945,15 +869,7 @@ public class RestControllerTests extends ESTestCase {
     }
 
     public void testRegisterWithReservedPath() {
-        final RestController restController = new RestController(
-            Set.of(),
-            null,
-            client,
-            circuitBreakerService,
-            usageService,
-            tracer,
-            false
-        );
+        final RestController restController = new RestController(null, client, circuitBreakerService, usageService, tracer, false);
         for (String path : RestController.RESERVED_PATHS) {
             IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> {
                 restController.registerHandler(
@@ -971,15 +887,7 @@ public class RestControllerTests extends ESTestCase {
      * Test that when serverless is disabled, all endpoints are available regardless of ServerlessScope annotations.
      */
     public void testApiProtectionWithServerlessDisabled() {
-        final RestController restController = new RestController(
-            Set.of(),
-            null,
-            client,
-            circuitBreakerService,
-            new UsageService(),
-            tracer,
-            false
-        );
+        final RestController restController = new RestController(null, client, circuitBreakerService, new UsageService(), tracer, false);
         restController.registerHandler(new PublicRestHandler());
         restController.registerHandler(new InternalRestHandler());
         restController.registerHandler(new HiddenRestHandler());
@@ -996,15 +904,7 @@ public class RestControllerTests extends ESTestCase {
      * annotated with a PUBLIC scope.
      */
     public void testApiProtectionWithServerlessEnabledAsEndUser() {
-        final RestController restController = new RestController(
-            Set.of(),
-            null,
-            client,
-            circuitBreakerService,
-            new UsageService(),
-            tracer,
-            true
-        );
+        final RestController restController = new RestController(null, client, circuitBreakerService, new UsageService(), tracer, true);
         restController.registerHandler(new PublicRestHandler());
         restController.registerHandler(new InternalRestHandler());
         restController.registerHandler(new HiddenRestHandler());
@@ -1027,15 +927,7 @@ public class RestControllerTests extends ESTestCase {
      * annotated with a PUBLIC or INTERNAL scope.
      */
     public void testApiProtectionWithServerlessEnabledAsInternalUser() {
-        final RestController restController = new RestController(
-            Set.of(),
-            null,
-            client,
-            circuitBreakerService,
-            new UsageService(),
-            tracer,
-            true
-        );
+        final RestController restController = new RestController(null, client, circuitBreakerService, new UsageService(), tracer, true);
         restController.registerHandler(new PublicRestHandler());
         restController.registerHandler(new InternalRestHandler());
         restController.registerHandler(new HiddenRestHandler());

+ 1 - 9
server/src/test/java/org/elasticsearch/rest/RestHttpResponseHeadersTests.java

@@ -78,15 +78,7 @@ public class RestHttpResponseHeadersTests extends ESTestCase {
         );
 
         UsageService usageService = new UsageService();
-        RestController restController = new RestController(
-            Collections.emptySet(),
-            null,
-            null,
-            circuitBreakerService,
-            usageService,
-            Tracer.NOOP,
-            false
-        );
+        RestController restController = new RestController(null, null, circuitBreakerService, usageService, Tracer.NOOP, false);
 
         // A basic RestHandler handles requests to the endpoint
         RestHandler restHandler = (request, channel, client) -> channel.sendResponse(new RestResponse(RestStatus.OK, ""));

+ 1 - 10
server/src/test/java/org/elasticsearch/rest/action/admin/indices/RestValidateQueryActionTests.java

@@ -44,7 +44,6 @@ import java.util.List;
 import java.util.Map;
 
 import static java.util.Collections.emptyMap;
-import static java.util.Collections.emptySet;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Mockito.mock;
@@ -55,15 +54,7 @@ public class RestValidateQueryActionTests extends AbstractSearchTestCase {
     private NodeClient client = new NodeClient(Settings.EMPTY, threadPool);
 
     private UsageService usageService = new UsageService();
-    private RestController controller = new RestController(
-        emptySet(),
-        null,
-        client,
-        new NoneCircuitBreakerService(),
-        usageService,
-        Tracer.NOOP,
-        false
-    );
+    private RestController controller = new RestController(null, client, new NoneCircuitBreakerService(), usageService, Tracer.NOOP, false);
     private RestValidateQueryAction action = new RestValidateQueryAction();
 
     /**

+ 1 - 9
test/framework/src/main/java/org/elasticsearch/test/rest/RestActionTestCase.java

@@ -40,15 +40,7 @@ public abstract class RestActionTestCase extends ESTestCase {
     @Before
     public void setUpController() {
         verifyingClient = new VerifyingClient(this.getTestName());
-        controller = new RestController(
-            Collections.emptySet(),
-            null,
-            verifyingClient,
-            new NoneCircuitBreakerService(),
-            new UsageService(),
-            Tracer.NOOP,
-            false
-        );
+        controller = new RestController(null, verifyingClient, new NoneCircuitBreakerService(), new UsageService(), Tracer.NOOP, false);
     }
 
     @After

+ 3 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/LocalStateCompositeXPackPlugin.java

@@ -46,6 +46,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
+import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.index.IndexSettingProvider;
@@ -412,6 +413,7 @@ public class LocalStateCompositeXPackPlugin extends XPackPlugin
         NamedXContentRegistry xContentRegistry,
         NetworkService networkService,
         HttpServerTransport.Dispatcher dispatcher,
+        BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
         ClusterSettings clusterSettings,
         Tracer tracer
     ) {
@@ -428,6 +430,7 @@ public class LocalStateCompositeXPackPlugin extends XPackPlugin
                         xContentRegistry,
                         networkService,
                         dispatcher,
+                        perRequestThreadContext,
                         clusterSettings,
                         tracer
                     )

+ 0 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/termsenum/action/RestTermsEnumActionTests.java

@@ -42,7 +42,6 @@ import java.util.List;
 import java.util.Map;
 
 import static java.util.Collections.emptyMap;
-import static java.util.Collections.emptySet;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Mockito.mock;
@@ -54,7 +53,6 @@ public class RestTermsEnumActionTests extends ESTestCase {
 
     private static UsageService usageService = new UsageService();
     private static RestController controller = new RestController(
-        emptySet(),
         null,
         client,
         new NoneCircuitBreakerService(),

+ 23 - 11
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java

@@ -49,6 +49,8 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.env.NodeMetadata;
+import org.elasticsearch.http.HttpChannel;
+import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.netty4.Netty4HttpHeaderValidator;
 import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
@@ -76,6 +78,7 @@ import org.elasticsearch.reservedstate.ReservedClusterStateHandler;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.rest.RestHeaderDefinition;
+import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.threadpool.ExecutorBuilder;
@@ -284,6 +287,7 @@ import org.elasticsearch.xpack.security.operator.OperatorOnlyRegistry;
 import org.elasticsearch.xpack.security.operator.OperatorPrivileges;
 import org.elasticsearch.xpack.security.operator.OperatorPrivileges.OperatorPrivilegesService;
 import org.elasticsearch.xpack.security.profile.ProfileService;
+import org.elasticsearch.xpack.security.rest.RemoteHostHeader;
 import org.elasticsearch.xpack.security.rest.SecurityRestFilter;
 import org.elasticsearch.xpack.security.rest.action.RestAuthenticateAction;
 import org.elasticsearch.xpack.security.rest.action.RestDelegatePkiAuthenticationAction;
@@ -379,6 +383,7 @@ import static org.elasticsearch.xpack.core.XPackSettings.API_KEY_SERVICE_ENABLED
 import static org.elasticsearch.xpack.core.XPackSettings.HTTP_SSL_ENABLED;
 import static org.elasticsearch.xpack.core.security.SecurityField.FIELD_LEVEL_SECURITY_FEATURE;
 import static org.elasticsearch.xpack.security.operator.OperatorPrivileges.OPERATOR_PRIVILEGES_ENABLED;
+import static org.elasticsearch.xpack.security.transport.SSLEngineUtils.extractClientCertificates;
 
 public class Security extends Plugin
     implements
@@ -1592,6 +1597,7 @@ public class Security extends Plugin
         NamedXContentRegistry xContentRegistry,
         NetworkService networkService,
         HttpServerTransport.Dispatcher dispatcher,
+        BiConsumer<HttpPreRequest, ThreadContext> perRequestThreadContext,
         ClusterSettings clusterSettings,
         Tracer tracer
     ) {
@@ -1615,8 +1621,9 @@ public class Security extends Plugin
         Map<String, Supplier<HttpServerTransport>> httpTransports = new HashMap<>();
         httpTransports.put(SecurityField.NAME4, () -> {
             final boolean ssl = HTTP_SSL_ENABLED.get(settings);
-            SSLService sslService = getSslService();
+            final SSLService sslService = getSslService();
             final SslConfiguration sslConfiguration;
+            final BiConsumer<HttpChannel, ThreadContext> populateClientCertificate;
             if (ssl) {
                 sslConfiguration = sslService.getHttpTransportSSLConfiguration();
                 if (SSLService.isConfigurationValidForServerUsage(sslConfiguration) == false) {
@@ -1625,8 +1632,14 @@ public class Security extends Plugin
                             + "[xpack.security.http.ssl.key] or [xpack.security.http.ssl.keystore.path] setting"
                     );
                 }
+                if (SSLService.isSSLClientAuthEnabled(sslConfiguration)) {
+                    populateClientCertificate = (channel, threadContext) -> extractClientCertificates(logger, threadContext, channel);
+                } else {
+                    populateClientCertificate = (channel, threadContext) -> {};
+                }
             } else {
                 sslConfiguration = null;
+                populateClientCertificate = (channel, threadContext) -> {};
             }
             return new Netty4HttpServerTransport(
                 settings,
@@ -1640,28 +1653,27 @@ public class Security extends Plugin
                 new TLSConfig(sslConfiguration, sslService::createSSLEngine),
                 acceptPredicate,
                 Netty4HttpHeaderValidator.NOOP_VALIDATOR
-            );
+            ) {
+                @Override
+                protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
+                    perRequestThreadContext.accept(restRequest.getHttpRequest(), threadContext);
+                    populateClientCertificate.accept(restRequest.getHttpChannel(), threadContext);
+                    RemoteHostHeader.process(restRequest, threadContext);
+                }
+            };
         });
         return httpTransports;
     }
 
     @Override
     public UnaryOperator<RestHandler> getRestHandlerInterceptor(ThreadContext threadContext) {
-        final boolean extractClientCertificate;
-        if (enabled && HTTP_SSL_ENABLED.get(settings)) {
-            final SslConfiguration httpSSLConfig = getSslService().getHttpTransportSSLConfiguration();
-            extractClientCertificate = SSLService.isSSLClientAuthEnabled(httpSSLConfig);
-        } else {
-            extractClientCertificate = false;
-        }
         return handler -> new SecurityRestFilter(
             enabled,
             threadContext,
             authcService.get(),
             secondayAuthc.get(),
             auditTrailService.get(),
-            handler,
-            extractClientCertificate
+            handler
         );
     }
 

+ 1 - 12
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java

@@ -14,7 +14,6 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.http.HttpChannel;
 import org.elasticsearch.rest.RestChannel;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.rest.RestRequest;
@@ -25,7 +24,6 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.security.audit.AuditTrailService;
 import org.elasticsearch.xpack.security.authc.AuthenticationService;
 import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator;
-import org.elasticsearch.xpack.security.transport.SSLEngineUtils;
 
 import java.util.List;
 import java.util.Map;
@@ -42,7 +40,6 @@ public class SecurityRestFilter implements RestHandler {
     private final AuditTrailService auditTrailService;
     private final boolean enabled;
     private final ThreadContext threadContext;
-    private final boolean extractClientCertificate;
 
     public enum ActionType {
         Authentication("Authentication"),
@@ -67,8 +64,7 @@ public class SecurityRestFilter implements RestHandler {
         AuthenticationService authenticationService,
         SecondaryAuthenticator secondaryAuthenticator,
         AuditTrailService auditTrailService,
-        RestHandler restHandler,
-        boolean extractClientCertificate
+        RestHandler restHandler
     ) {
         this.enabled = enabled;
         this.threadContext = threadContext;
@@ -76,7 +72,6 @@ public class SecurityRestFilter implements RestHandler {
         this.secondaryAuthenticator = secondaryAuthenticator;
         this.auditTrailService = auditTrailService;
         this.restHandler = restHandler;
-        this.extractClientCertificate = extractClientCertificate;
     }
 
     @Override
@@ -101,13 +96,7 @@ public class SecurityRestFilter implements RestHandler {
             return;
         }
 
-        if (extractClientCertificate) {
-            HttpChannel httpChannel = request.getHttpChannel();
-            SSLEngineUtils.extractClientCertificates(logger, threadContext, httpChannel);
-        }
-
         final RestRequest wrappedRequest = maybeWrapRestRequest(request);
-        RemoteHostHeader.process(request, threadContext);
         authenticationService.authenticate(wrappedRequest.getHttpRequest(), ActionListener.wrap(authentication -> {
             if (authentication == null) {
                 logger.trace("No authentication available for REST request [{}]", request.uri());

+ 4 - 8
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java

@@ -91,8 +91,7 @@ public class SecurityRestFilterTests extends ESTestCase {
             authcService,
             secondaryAuthenticator,
             new AuditTrailService(null, null),
-            restHandler,
-            false
+            restHandler
         );
     }
 
@@ -165,8 +164,7 @@ public class SecurityRestFilterTests extends ESTestCase {
             authcService,
             secondaryAuthenticator,
             mock(AuditTrailService.class),
-            restHandler,
-            false
+            restHandler
         );
         RestRequest request = mock(RestRequest.class);
         filter.handleRequest(request, channel, null);
@@ -181,8 +179,7 @@ public class SecurityRestFilterTests extends ESTestCase {
             authcService,
             secondaryAuthenticator,
             mock(AuditTrailService.class),
-            restHandler,
-            false
+            restHandler
         );
         testProcessAuthenticationFailed(
             randomBoolean()
@@ -312,8 +309,7 @@ public class SecurityRestFilterTests extends ESTestCase {
             authcService,
             secondaryAuthenticator,
             new AuditTrailService(auditTrail, licenseState),
-            restHandler,
-            false
+            restHandler
         );
 
         filter.handleRequest(restRequest, channel, null);

+ 2 - 4
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterWarningHeadersTests.java

@@ -96,8 +96,7 @@ public class SecurityRestFilterWarningHeadersTests extends ESTestCase {
             authcService,
             secondaryAuthenticator,
             new AuditTrailService(null, null),
-            restHandler,
-            false
+            restHandler
         );
         RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
         Authentication primaryAuthentication = AuthenticationTestHelper.builder().build();
@@ -136,8 +135,7 @@ public class SecurityRestFilterWarningHeadersTests extends ESTestCase {
             authcService,
             secondaryAuthenticator,
             mock(AuditTrailService.class),
-            restHandler,
-            false
+            restHandler
         );
         RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
         doAnswer((i) -> {