浏览代码

Security authn via netty channel validator (#95112)

Hooks "REST" authN, as a "validator", into the
new netty channel interceptor for http headers.
Albert Zaharovits 2 年之前
父节点
当前提交
bedaf3c9db
共有 33 个文件被更改,包括 1207 次插入524 次删除
  1. 5 0
      docs/changelog/95112.yaml
  2. 1 0
      modules/transport-netty4/src/main/java/module-info.java
  3. 4 10
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java
  4. 54 78
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java
  5. 43 28
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java
  6. 119 0
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/internal/HttpHeadersAuthenticatorUtils.java
  7. 62 0
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/internal/HttpHeadersWithAuthenticationContext.java
  8. 25 0
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/internal/HttpValidator.java
  9. 85 0
      modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/HttpHeadersAuthenticatorUtilsTests.java
  10. 1 1
      modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java
  11. 3 9
      modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderThreadContextTests.java
  12. 2 4
      modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java
  13. 1 1
      modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java
  14. 314 10
      modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java
  15. 7 1
      server/src/main/java/org/elasticsearch/ElasticsearchException.java
  16. 27 0
      server/src/main/java/org/elasticsearch/http/HttpHeadersValidationException.java
  17. 8 1
      server/src/main/java/org/elasticsearch/rest/RestController.java
  18. 2 0
      server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java
  19. 42 0
      server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java
  20. 22 0
      server/src/test/java/org/elasticsearch/rest/RestControllerTests.java
  21. 45 0
      server/src/test/java/org/elasticsearch/rest/RestResponseTests.java
  22. 1 0
      x-pack/plugin/security/src/main/java/module-info.java
  23. 69 21
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
  24. 4 0
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/AuthenticationService.java
  25. 1 0
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/AuthenticatorChain.java
  26. 8 11
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java
  27. 5 16
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java
  28. 7 14
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SSLEngineUtils.java
  29. 5 1
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailFilterTests.java
  30. 33 31
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java
  31. 2 125
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java
  32. 0 154
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterWarningHeadersTests.java
  33. 200 8
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTests.java

+ 5 - 0
docs/changelog/95112.yaml

@@ -0,0 +1,5 @@
+pr: 95112
+summary: Header validator with Security
+area: Authentication
+type: enhancement
+issues: []

+ 1 - 0
modules/transport-netty4/src/main/java/module-info.java

@@ -23,4 +23,5 @@ module org.elasticsearch.transport.netty4 {
 
     exports org.elasticsearch.http.netty4;
     exports org.elasticsearch.transport.netty4;
+    exports org.elasticsearch.http.netty4.internal to org.elasticsearch.security;
 }

+ 4 - 10
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java

@@ -9,7 +9,6 @@
 package org.elasticsearch.http.netty4;
 
 import io.netty.buffer.Unpooled;
-import io.netty.channel.Channel;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.handler.codec.DecoderResult;
@@ -21,8 +20,8 @@ import io.netty.util.ReferenceCountUtil;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.transport.Transports;
 
 import java.util.ArrayDeque;
@@ -35,17 +34,12 @@ import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.WAIT
 
 public class Netty4HttpHeaderValidator extends ChannelInboundHandlerAdapter {
 
-    public static final TriConsumer<HttpRequest, Channel, ActionListener<Void>> NOOP_VALIDATOR = ((
-        httpRequest,
-        channel,
-        listener) -> listener.onResponse(null));
-
-    private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator;
+    private final HttpValidator validator;
     private final ThreadContext threadContext;
     private ArrayDeque<HttpObject> pending = new ArrayDeque<>(4);
     private State state = WAITING_TO_START;
 
-    public Netty4HttpHeaderValidator(TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator, ThreadContext threadContext) {
+    public Netty4HttpHeaderValidator(HttpValidator validator, ThreadContext threadContext) {
         this.validator = validator;
         this.threadContext = threadContext;
     }
@@ -129,7 +123,7 @@ public class Netty4HttpHeaderValidator extends ChannelInboundHandlerAdapter {
             );
             // this prevents thread-context changes to propagate beyond the validation, as netty worker threads are reused
             try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
-                validator.apply(httpRequest, ctx.channel(), contextPreservingActionListener);
+                validator.validate(httpRequest, ctx.channel(), contextPreservingActionListener);
             }
         }
     }

+ 54 - 78
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java

@@ -11,7 +11,6 @@ package org.elasticsearch.http.netty4;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.handler.codec.http.DefaultFullHttpRequest;
-import io.netty.handler.codec.http.DefaultHttpHeaders;
 import io.netty.handler.codec.http.FullHttpRequest;
 import io.netty.handler.codec.http.HttpHeaderNames;
 import io.netty.handler.codec.http.HttpHeaders;
@@ -41,51 +40,27 @@ public class Netty4HttpRequest implements HttpRequest {
 
     private final FullHttpRequest request;
     private final BytesReference content;
-    private final HttpHeadersMap headers;
+    private final Map<String, List<String>> headers;
     private final AtomicBoolean released;
     private final Exception inboundException;
     private final boolean pooled;
-
     private final int sequence;
 
     Netty4HttpRequest(int sequence, FullHttpRequest request) {
-        this(
-            sequence,
-            request,
-            new HttpHeadersMap(request.headers()),
-            new AtomicBoolean(false),
-            true,
-            Netty4Utils.toBytesReference(request.content())
-        );
+        this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.toBytesReference(request.content()));
     }
 
     Netty4HttpRequest(int sequence, FullHttpRequest request, Exception inboundException) {
-        this(
-            sequence,
-            request,
-            new HttpHeadersMap(request.headers()),
-            new AtomicBoolean(false),
-            true,
-            Netty4Utils.toBytesReference(request.content()),
-            inboundException
-        );
+        this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.toBytesReference(request.content()), inboundException);
     }
 
-    private Netty4HttpRequest(
-        int sequence,
-        FullHttpRequest request,
-        HttpHeadersMap headers,
-        AtomicBoolean released,
-        boolean pooled,
-        BytesReference content
-    ) {
-        this(sequence, request, headers, released, pooled, content, null);
+    private Netty4HttpRequest(int sequence, FullHttpRequest request, AtomicBoolean released, boolean pooled, BytesReference content) {
+        this(sequence, request, released, pooled, content, null);
     }
 
     private Netty4HttpRequest(
         int sequence,
         FullHttpRequest request,
-        HttpHeadersMap headers,
         AtomicBoolean released,
         boolean pooled,
         BytesReference content,
@@ -93,7 +68,7 @@ public class Netty4HttpRequest implements HttpRequest {
     ) {
         this.sequence = sequence;
         this.request = request;
-        this.headers = headers;
+        this.headers = getHttpHeadersAsMap(request.headers());
         this.content = content;
         this.pooled = pooled;
         this.released = released;
@@ -102,36 +77,7 @@ public class Netty4HttpRequest implements HttpRequest {
 
     @Override
     public RestRequest.Method method() {
-        HttpMethod httpMethod = request.method();
-        if (httpMethod == HttpMethod.GET) return RestRequest.Method.GET;
-
-        if (httpMethod == HttpMethod.POST) return RestRequest.Method.POST;
-
-        if (httpMethod == HttpMethod.PUT) return RestRequest.Method.PUT;
-
-        if (httpMethod == HttpMethod.DELETE) return RestRequest.Method.DELETE;
-
-        if (httpMethod == HttpMethod.HEAD) {
-            return RestRequest.Method.HEAD;
-        }
-
-        if (httpMethod == HttpMethod.OPTIONS) {
-            return RestRequest.Method.OPTIONS;
-        }
-
-        if (httpMethod == HttpMethod.PATCH) {
-            return RestRequest.Method.PATCH;
-        }
-
-        if (httpMethod == HttpMethod.TRACE) {
-            return RestRequest.Method.TRACE;
-        }
-
-        if (httpMethod == HttpMethod.CONNECT) {
-            return RestRequest.Method.CONNECT;
-        }
-
-        throw new IllegalArgumentException("Unexpected http method: " + httpMethod);
+        return translateRequestMethod(request.method());
     }
 
     @Override
@@ -170,7 +116,6 @@ public class Netty4HttpRequest implements HttpRequest {
                     request.headers(),
                     request.trailingHeaders()
                 ),
-                headers,
                 new AtomicBoolean(false),
                 false,
                 Netty4Utils.toBytesReference(copiedContent)
@@ -210,28 +155,19 @@ public class Netty4HttpRequest implements HttpRequest {
 
     @Override
     public HttpRequest removeHeader(String header) {
-        HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
-        headersWithoutContentTypeHeader.add(request.headers());
-        headersWithoutContentTypeHeader.remove(header);
-        HttpHeaders trailingHeaders = new DefaultHttpHeaders();
-        trailingHeaders.add(request.trailingHeaders());
-        trailingHeaders.remove(header);
+        HttpHeaders copiedHeadersWithout = request.headers().copy();
+        copiedHeadersWithout.remove(header);
+        HttpHeaders copiedTrailingHeadersWithout = request.trailingHeaders().copy();
+        copiedTrailingHeadersWithout.remove(header);
         FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(
             request.protocolVersion(),
             request.method(),
             request.uri(),
             request.content(),
-            headersWithoutContentTypeHeader,
-            trailingHeaders
-        );
-        return new Netty4HttpRequest(
-            sequence,
-            requestWithoutHeader,
-            new HttpHeadersMap(requestWithoutHeader.headers()),
-            released,
-            pooled,
-            content
+            copiedHeadersWithout,
+            copiedTrailingHeadersWithout
         );
+        return new Netty4HttpRequest(sequence, requestWithoutHeader, released, pooled, content);
     }
 
     @Override
@@ -249,6 +185,46 @@ public class Netty4HttpRequest implements HttpRequest {
         return inboundException;
     }
 
+    public io.netty.handler.codec.http.HttpRequest getNettyRequest() {
+        return request;
+    }
+
+    public static RestRequest.Method translateRequestMethod(HttpMethod httpMethod) {
+        if (httpMethod == HttpMethod.GET) return RestRequest.Method.GET;
+
+        if (httpMethod == HttpMethod.POST) return RestRequest.Method.POST;
+
+        if (httpMethod == HttpMethod.PUT) return RestRequest.Method.PUT;
+
+        if (httpMethod == HttpMethod.DELETE) return RestRequest.Method.DELETE;
+
+        if (httpMethod == HttpMethod.HEAD) {
+            return RestRequest.Method.HEAD;
+        }
+
+        if (httpMethod == HttpMethod.OPTIONS) {
+            return RestRequest.Method.OPTIONS;
+        }
+
+        if (httpMethod == HttpMethod.PATCH) {
+            return RestRequest.Method.PATCH;
+        }
+
+        if (httpMethod == HttpMethod.TRACE) {
+            return RestRequest.Method.TRACE;
+        }
+
+        if (httpMethod == HttpMethod.CONNECT) {
+            return RestRequest.Method.CONNECT;
+        }
+
+        throw new IllegalArgumentException("Unexpected http method: " + httpMethod);
+    }
+
+    public static Map<String, List<String>> getHttpHeadersAsMap(HttpHeaders httpHeaders) {
+        return new HttpHeadersMap(httpHeaders);
+    }
+
     /**
      * A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications
      * and due to the underlying implementation, it performs case insensitive lookups of key to values.

+ 43 - 28
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java

@@ -22,8 +22,8 @@ import io.netty.channel.socket.nio.NioChannelOption;
 import io.netty.handler.codec.ByteToMessageDecoder;
 import io.netty.handler.codec.http.HttpContentCompressor;
 import io.netty.handler.codec.http.HttpContentDecompressor;
+import io.netty.handler.codec.http.HttpMessage;
 import io.netty.handler.codec.http.HttpObjectAggregator;
-import io.netty.handler.codec.http.HttpRequest;
 import io.netty.handler.codec.http.HttpRequestDecoder;
 import io.netty.handler.codec.http.HttpResponse;
 import io.netty.handler.codec.http.HttpResponseEncoder;
@@ -36,8 +36,6 @@ import io.netty.util.AttributeKey;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ExceptionsHelper;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.network.CloseableChannel;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.settings.ClusterSettings;
@@ -47,7 +45,6 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.http.AbstractHttpServerTransport;
@@ -56,6 +53,8 @@ import org.elasticsearch.http.HttpHandlingSettings;
 import org.elasticsearch.http.HttpReadTimeoutException;
 import org.elasticsearch.http.HttpServerChannel;
 import org.elasticsearch.http.HttpServerTransport;
+import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
+import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
 import org.elasticsearch.transport.netty4.AcceptChannelHandler;
@@ -147,7 +146,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
     private final RecvByteBufAllocator recvByteBufAllocator;
     private final TLSConfig tlsConfig;
     private final AcceptChannelHandler.AcceptPredicate acceptChannelPredicate;
-    private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;
+    private final HttpValidator httpValidator;
     private final int readTimeoutMillis;
 
     private final int maxCompositeBufferComponents;
@@ -166,7 +165,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
         Tracer tracer,
         TLSConfig tlsConfig,
         @Nullable AcceptChannelHandler.AcceptPredicate acceptChannelPredicate,
-        @Nullable TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
+        @Nullable HttpValidator httpValidator
     ) {
         super(
             settings,
@@ -183,7 +182,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
         this.sharedGroupFactory = sharedGroupFactory;
         this.tlsConfig = tlsConfig;
         this.acceptChannelPredicate = acceptChannelPredicate;
-        this.headerValidator = headerValidator;
+        this.httpValidator = httpValidator;
 
         this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);
 
@@ -329,14 +328,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
     }
 
     public ChannelHandler configureServerChannelHandler() {
-        return new HttpChannelHandler(
-            this,
-            handlingSettings,
-            tlsConfig,
-            threadPool.getThreadContext(),
-            acceptChannelPredicate,
-            headerValidator
-        );
+        return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate, httpValidator);
     }
 
     static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
@@ -347,24 +339,21 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
         private final Netty4HttpServerTransport transport;
         private final HttpHandlingSettings handlingSettings;
         private final TLSConfig tlsConfig;
-        private final ThreadContext threadContext;
         private final BiPredicate<String, InetSocketAddress> acceptChannelPredicate;
-        private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;
+        private final HttpValidator httpValidator;
 
         protected HttpChannelHandler(
             final Netty4HttpServerTransport transport,
             final HttpHandlingSettings handlingSettings,
             final TLSConfig tlsConfig,
-            final ThreadContext threadContext,
             @Nullable final BiPredicate<String, InetSocketAddress> acceptChannelPredicate,
-            @Nullable final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
+            @Nullable final HttpValidator httpValidator
         ) {
             this.transport = transport;
             this.handlingSettings = handlingSettings;
             this.tlsConfig = tlsConfig;
-            this.threadContext = threadContext;
             this.acceptChannelPredicate = acceptChannelPredicate;
-            this.headerValidator = headerValidator;
+            this.httpValidator = httpValidator;
         }
 
         @Override
@@ -387,17 +376,43 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
             if (transport.readTimeoutMillis > 0) {
                 ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS));
             }
-            final HttpRequestDecoder decoder = new HttpRequestDecoder(
-                handlingSettings.maxInitialLineLength(),
-                handlingSettings.maxHeaderSize(),
-                handlingSettings.maxChunkSize()
-            );
+            final HttpRequestDecoder decoder;
+            if (httpValidator != null) {
+                decoder = new HttpRequestDecoder(
+                    handlingSettings.maxInitialLineLength(),
+                    handlingSettings.maxHeaderSize(),
+                    handlingSettings.maxChunkSize()
+                ) {
+                    @Override
+                    protected HttpMessage createMessage(String[] initialLine) throws Exception {
+                        return HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(super.createMessage(initialLine));
+                    }
+
+                    @Override
+                    protected HttpMessage createInvalidMessage() {
+                        return HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(super.createInvalidMessage());
+                    }
+                };
+            } else {
+                decoder = new HttpRequestDecoder(
+                    handlingSettings.maxInitialLineLength(),
+                    handlingSettings.maxHeaderSize(),
+                    handlingSettings.maxChunkSize()
+                );
+            }
             decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
             ch.pipeline().addLast("decoder", decoder); // parses the HTTP bytes request into HTTP message pieces
-            if (headerValidator != null) {
+            if (httpValidator != null) {
                 // runs a validation function on the first HTTP message piece which contains all the headers
                 // if validation passes, the pieces of that particular request are forwarded, otherwise they are discarded
-                ch.pipeline().addLast("header_validator", new Netty4HttpHeaderValidator(headerValidator, threadContext));
+                ch.pipeline()
+                    .addLast(
+                        "header_validator",
+                        HttpHeadersAuthenticatorUtils.getValidatorInboundHandler(
+                            httpValidator,
+                            transport.getThreadPool().getThreadContext()
+                        )
+                    );
             }
             // combines the HTTP message pieces into a single full HTTP request (with headers and body)
             final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.maxContentLength());

+ 119 - 0
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/internal/HttpHeadersAuthenticatorUtils.java

@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.http.netty4.internal;
+
+import io.netty.handler.codec.http.DefaultHttpRequest;
+import io.netty.handler.codec.http.HttpMessage;
+import io.netty.handler.codec.http.HttpRequest;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.HttpHeadersValidationException;
+import org.elasticsearch.http.HttpPreRequest;
+import org.elasticsearch.http.netty4.Netty4HttpHeaderValidator;
+import org.elasticsearch.http.netty4.Netty4HttpRequest;
+import org.elasticsearch.rest.RestRequest;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.http.netty4.Netty4HttpRequest.getHttpHeadersAsMap;
+import static org.elasticsearch.http.netty4.Netty4HttpRequest.translateRequestMethod;
+
+/**
+ * Provides utilities for hooking into the netty pipeline and authenticate each HTTP request's headers.
+ * See also {@link Netty4HttpHeaderValidator}.
+ */
+public final class HttpHeadersAuthenticatorUtils {
+
+    // utility class
+    private HttpHeadersAuthenticatorUtils() {}
+
+    /**
+     * Supplies a netty {@code ChannelInboundHandler} that runs the provided {@param validator} on the HTTP request headers.
+     * The HTTP headers of the to-be-authenticated {@link HttpRequest} must be wrapped by the special
+     * {@link HttpHeadersWithAuthenticationContext}, see {@link #wrapAsMessageWithAuthenticationContext(HttpMessage)}.
+     */
+    public static Netty4HttpHeaderValidator getValidatorInboundHandler(HttpValidator validator, ThreadContext threadContext) {
+        return new Netty4HttpHeaderValidator((httpRequest, channel, listener) -> {
+            // make sure authentication only runs on properly wrapped "authenticable" headers implementation
+            if (httpRequest.headers() instanceof HttpHeadersWithAuthenticationContext httpHeadersWithAuthenticationContext) {
+                validator.validate(httpRequest, channel, ActionListener.wrap(aVoid -> {
+                    httpHeadersWithAuthenticationContext.setAuthenticationContext(threadContext.newStoredContext());
+                    // a successful authentication needs to signal to the {@link Netty4HttpHeaderValidator} to resume
+                    // forwarding the request beyond the headers part
+                    listener.onResponse(null);
+                }, e -> listener.onFailure(new HttpHeadersValidationException(e))));
+            } else {
+                // cannot authenticate the request because it's not wrapped correctly, see {@link #wrapAsMessageWithAuthenticationContext}
+                listener.onFailure(new IllegalStateException("Cannot authenticate unwrapped requests"));
+            }
+        }, threadContext);
+    }
+
+    /**
+     * Given a {@link DefaultHttpRequest} argument, this returns a new {@link DefaultHttpRequest} instance that's identical to the
+     * passed-in one, but the headers of the latter can be authenticated, in the sense that the channel handlers returned by
+     * {@link #getValidatorInboundHandler(HttpValidator, ThreadContext)} can use this to convey the authentication result context.
+     */
+    public static HttpMessage wrapAsMessageWithAuthenticationContext(HttpMessage newlyDecodedMessage) {
+        assert newlyDecodedMessage instanceof HttpRequest;
+        DefaultHttpRequest httpRequest = (DefaultHttpRequest) newlyDecodedMessage;
+        HttpHeadersWithAuthenticationContext httpHeadersWithAuthenticationContext = new HttpHeadersWithAuthenticationContext(
+            newlyDecodedMessage.headers()
+        );
+        return new DefaultHttpRequest(
+            httpRequest.protocolVersion(),
+            httpRequest.method(),
+            httpRequest.uri(),
+            httpHeadersWithAuthenticationContext
+        );
+    }
+
+    /**
+     * Returns the authentication thread context for the {@param request}.
+     */
+    public static ThreadContext.StoredContext extractAuthenticationContext(org.elasticsearch.http.HttpRequest request) {
+        HttpHeadersWithAuthenticationContext authenticatedHeaders = unwrapAuthenticatedHeaders(request);
+        return authenticatedHeaders != null ? authenticatedHeaders.authenticationContextSetOnce.get() : null;
+    }
+
+    /**
+     * Translates the netty request internal type to a {@link HttpPreRequest} instance that code outside the network plugin has access to.
+     */
+    public static HttpPreRequest asHttpPreRequest(HttpRequest request) {
+        return new HttpPreRequest() {
+
+            @Override
+            public RestRequest.Method method() {
+                return translateRequestMethod(request.method());
+            }
+
+            @Override
+            public String uri() {
+                return request.uri();
+            }
+
+            @Override
+            public Map<String, List<String>> getHeaders() {
+                return getHttpHeadersAsMap(request.headers());
+            }
+        };
+    }
+
+    private static HttpHeadersWithAuthenticationContext unwrapAuthenticatedHeaders(org.elasticsearch.http.HttpRequest request) {
+        if (request instanceof Netty4HttpRequest == false) {
+            return null;
+        }
+        if (((Netty4HttpRequest) request).getNettyRequest().headers() instanceof HttpHeadersWithAuthenticationContext == false) {
+            return null;
+        }
+        return (HttpHeadersWithAuthenticationContext) (((Netty4HttpRequest) request).getNettyRequest().headers());
+    }
+}

+ 62 - 0
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/internal/HttpHeadersWithAuthenticationContext.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.http.netty4.internal;
+
+import io.netty.handler.codec.http.DefaultHttpHeaders;
+import io.netty.handler.codec.http.HttpHeaders;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+
+import java.util.Objects;
+
+/**
+ * {@link HttpHeaders} implementation that carries along the {@link ThreadContext.StoredContext} iff
+ * the HTTP headers have been authenticated successfully.
+ */
+public final class HttpHeadersWithAuthenticationContext extends DefaultHttpHeaders {
+
+    public final SetOnce<ThreadContext.StoredContext> authenticationContextSetOnce;
+
+    public HttpHeadersWithAuthenticationContext(HttpHeaders httpHeaders) {
+        this(httpHeaders, new SetOnce<>());
+    }
+
+    private HttpHeadersWithAuthenticationContext(
+        HttpHeaders httpHeaders,
+        SetOnce<ThreadContext.StoredContext> authenticationContextSetOnce
+    ) {
+        // the constructor implements the same logic as HttpHeaders#copy
+        super();
+        set(httpHeaders);
+        this.authenticationContextSetOnce = authenticationContextSetOnce;
+    }
+
+    private HttpHeadersWithAuthenticationContext(HttpHeaders httpHeaders, ThreadContext.StoredContext authenticationContext) {
+        this(httpHeaders);
+        if (authenticationContext != null) {
+            setAuthenticationContext(authenticationContext);
+        }
+    }
+
+    /**
+     * Must be called at most once in order to mark the http headers as successfully authenticated.
+     * The intent of the {@link ThreadContext.StoredContext} parameter is to associate the resulting
+     * thread context post authentication, that will later be restored when dispatching the request.
+     */
+    public void setAuthenticationContext(ThreadContext.StoredContext authenticationContext) {
+        this.authenticationContextSetOnce.set(Objects.requireNonNull(authenticationContext));
+    }
+
+    @Override
+    public HttpHeaders copy() {
+        // copy the headers but also STILL CARRY the same validation result
+        return new HttpHeadersWithAuthenticationContext(super.copy(), authenticationContextSetOnce.get());
+    }
+}

+ 25 - 0
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/internal/HttpValidator.java

@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.http.netty4.internal;
+
+import io.netty.channel.Channel;
+import io.netty.handler.codec.http.HttpRequest;
+
+import org.elasticsearch.action.ActionListener;
+
+public interface HttpValidator {
+    /**
+     * An async HTTP request validating function that receives as arguments the initial part of a decoded HTTP request
+     * (which contains all the HTTP headers, but not the body contents), as well as the netty channel that the
+     * request is being received over, and must then call the {@code ActionListener#onResponse} method on the
+     * listener parameter in case the authentication is to be considered successful, or otherwise call
+     * {@code ActionListener#onFailure} and pass the failure exception.
+     */
+    void validate(HttpRequest httpRequest, Channel channel, ActionListener<Void> listener);
+}

+ 85 - 0
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/HttpHeadersAuthenticatorUtilsTests.java

@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.http.netty4;
+
+import io.netty.handler.codec.http.DefaultHttpRequest;
+import io.netty.handler.codec.http.HttpHeaders;
+import io.netty.handler.codec.http.HttpMethod;
+import io.netty.handler.codec.http.HttpVersion;
+
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
+import org.elasticsearch.http.netty4.internal.HttpHeadersWithAuthenticationContext;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
+
+public final class HttpHeadersAuthenticatorUtilsTests extends ESTestCase {
+
+    public void testRemoveHeaderPreservesValidationResult() {
+        final ThreadContext.StoredContext dummyValidationContext = () -> {};
+        final DefaultHttpRequest httpRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/uri");
+        String header1 = "header1";
+        String headerValue1 = "headerValue1";
+        String header2 = "header2";
+        String headerValue2 = "headerValue2";
+        httpRequest.headers().add(header1, headerValue1);
+        httpRequest.headers().add(header2, headerValue2);
+        final DefaultHttpRequest validatableHttpRequest = (DefaultHttpRequest) HttpHeadersAuthenticatorUtils
+            .wrapAsMessageWithAuthenticationContext(httpRequest);
+        boolean validated = randomBoolean();
+        if (validated) {
+            ((HttpHeadersWithAuthenticationContext) validatableHttpRequest.headers()).setAuthenticationContext(dummyValidationContext);
+        }
+        if (randomBoolean()) {
+            validatableHttpRequest.headers().remove("header1");
+            assertThat(validatableHttpRequest.headers().contains("header1"), is(false));
+            assertThat(validatableHttpRequest.headers().contains("header2"), is(true));
+        } else {
+            validatableHttpRequest.headers().remove("header2");
+            assertThat(validatableHttpRequest.headers().contains("header1"), is(true));
+            assertThat(validatableHttpRequest.headers().contains("header2"), is(false));
+        }
+        if (validated) {
+            assertThat(
+                ((HttpHeadersWithAuthenticationContext) validatableHttpRequest.headers()).authenticationContextSetOnce.get(),
+                is(dummyValidationContext)
+            );
+        } else {
+            assertThat(
+                ((HttpHeadersWithAuthenticationContext) validatableHttpRequest.headers()).authenticationContextSetOnce.get(),
+                nullValue()
+            );
+        }
+    }
+
+    public void testCopyHeaderPreservesValidationResult() {
+        final ThreadContext.StoredContext dummyValidationContext = () -> {};
+        final DefaultHttpRequest httpRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/uri");
+        String header = "header";
+        String headerValue = "headerValue";
+        httpRequest.headers().add(header, headerValue);
+        final DefaultHttpRequest validatableHttpRequest = (DefaultHttpRequest) HttpHeadersAuthenticatorUtils
+            .wrapAsMessageWithAuthenticationContext(httpRequest);
+        boolean validated = randomBoolean();
+        if (validated) {
+            ((HttpHeadersWithAuthenticationContext) validatableHttpRequest.headers()).setAuthenticationContext(dummyValidationContext);
+        }
+        HttpHeaders httpHeadersCopy = ((HttpHeadersWithAuthenticationContext) validatableHttpRequest.headers()).copy();
+        if (validated) {
+            assertThat(
+                ((HttpHeadersWithAuthenticationContext) httpHeadersCopy).authenticationContextSetOnce.get(),
+                is(dummyValidationContext)
+            );
+        } else {
+            assertThat(((HttpHeadersWithAuthenticationContext) httpHeadersCopy).authenticationContextSetOnce.get(), nullValue());
+        }
+    }
+}

+ 1 - 1
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java

@@ -88,7 +88,7 @@ public class Netty4BadRequestTests extends ESTestCase {
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             )
         ) {
             httpServerTransport.start();

+ 3 - 9
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderThreadContextTests.java

@@ -9,7 +9,6 @@
 package org.elasticsearch.http.netty4;
 
 import io.netty.buffer.Unpooled;
-import io.netty.channel.Channel;
 import io.netty.channel.ChannelDuplexHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelPromise;
@@ -18,12 +17,10 @@ import io.netty.handler.codec.http.DefaultHttpContent;
 import io.netty.handler.codec.http.DefaultHttpRequest;
 import io.netty.handler.codec.http.DefaultLastHttpContent;
 import io.netty.handler.codec.http.HttpMethod;
-import io.netty.handler.codec.http.HttpRequest;
 import io.netty.handler.codec.http.HttpVersion;
 
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -46,6 +43,7 @@ import static org.hamcrest.Matchers.sameInstance;
  * This also tests that a threading validator cannot fork the following netty pipeline handlers on a different thread.
  */
 public class Netty4HttpHeaderThreadContextTests extends ESTestCase {
+
     private EmbeddedChannel channel;
     private ThreadPool threadPool;
 
@@ -146,11 +144,7 @@ public class Netty4HttpHeaderThreadContextTests extends ESTestCase {
         sendRequestThrough(isValidationSuccessful.get(), validationDone);
     }
 
-    private TriConsumer<HttpRequest, Channel, ActionListener<Void>> getValidator(
-        ExecutorService executorService,
-        AtomicBoolean success,
-        Semaphore validationDone
-    ) {
+    private HttpValidator getValidator(ExecutorService executorService, AtomicBoolean success, Semaphore validationDone) {
         return (httpRequest, channel, listener) -> {
             executorService.submit(() -> {
                 if (randomBoolean()) {

+ 2 - 4
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java

@@ -9,23 +9,21 @@
 package org.elasticsearch.http.netty4;
 
 import io.netty.buffer.Unpooled;
-import io.netty.channel.Channel;
 import io.netty.channel.embedded.EmbeddedChannel;
 import io.netty.handler.codec.http.DefaultHttpContent;
 import io.netty.handler.codec.http.DefaultHttpRequest;
 import io.netty.handler.codec.http.DefaultLastHttpContent;
 import io.netty.handler.codec.http.HttpHeaderNames;
 import io.netty.handler.codec.http.HttpMethod;
-import io.netty.handler.codec.http.HttpRequest;
 import io.netty.handler.codec.http.HttpVersion;
 import io.netty.handler.codec.http.LastHttpContent;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.test.ESTestCase;
 
 import java.util.List;
@@ -57,7 +55,7 @@ public class Netty4HttpHeaderValidatorTests extends ESTestCase {
         channel = new EmbeddedChannel();
         header.set(null);
         listener.set(null);
-        TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator = (httpRequest, channel, validationCompleteListener) -> {
+        HttpValidator validator = (httpRequest, channel, validationCompleteListener) -> {
             header.set(httpRequest);
             listener.set(validationCompleteListener);
         };

+ 1 - 1
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java

@@ -108,7 +108,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             );
         }
 

+ 314 - 10
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java

@@ -38,7 +38,10 @@ import io.netty.handler.codec.http.HttpResponseStatus;
 import io.netty.handler.codec.http.HttpUtil;
 import io.netty.handler.codec.http.HttpVersion;
 
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchSecurityException;
+import org.elasticsearch.ElasticsearchWrapperException;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.network.NetworkAddress;
@@ -53,9 +56,12 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.http.AbstractHttpServerTransportTestCase;
 import org.elasticsearch.http.BindHttpException;
 import org.elasticsearch.http.CorsHandler;
+import org.elasticsearch.http.HttpHeadersValidationException;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.HttpTransportSettings;
 import org.elasticsearch.http.NullDispatcher;
+import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
+import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.rest.ChunkedRestResponseBody;
 import org.elasticsearch.rest.RestChannel;
 import org.elasticsearch.rest.RestRequest;
@@ -74,20 +80,29 @@ import org.junit.Before;
 import java.io.IOException;
 import java.nio.charset.Charset;
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
 import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
 
+import static com.carrotsearch.randomizedtesting.RandomizedTest.getRandom;
 import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
 import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
 import static org.elasticsearch.rest.RestStatus.BAD_REQUEST;
 import static org.elasticsearch.rest.RestStatus.OK;
+import static org.elasticsearch.rest.RestStatus.UNAUTHORIZED;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
 
 /**
  * Tests for the {@link Netty4HttpServerTransport} class.
@@ -178,7 +193,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             )
         ) {
             transport.start();
@@ -230,7 +245,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             )
         ) {
             transport.start();
@@ -251,7 +266,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                     Tracer.NOOP,
                     TLSConfig.noTLS(),
                     null,
-                    randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                    randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
                 )
             ) {
                 BindHttpException bindHttpException = expectThrows(BindHttpException.class, otherTransport::start);
@@ -306,7 +321,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             )
         ) {
             transport.start();
@@ -377,7 +392,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             ) {
                 @Override
                 public ChannelHandler configureServerChannelHandler() {
@@ -385,9 +400,8 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                         this,
                         handlingSettings,
                         TLSConfig.noTLS(),
-                        threadPool.getThreadContext(),
                         null,
-                        randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                        randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
                     ) {
                         @Override
                         protected void initChannel(Channel ch) throws Exception {
@@ -484,7 +498,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             )
         ) {
             transport.start();
@@ -557,7 +571,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             )
         ) {
             transport.start();
@@ -623,7 +637,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
                 Tracer.NOOP,
                 TLSConfig.noTLS(),
                 null,
-                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+                randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
             )
         ) {
             transport.start();
@@ -644,6 +658,280 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
         }
     }
 
+    public void testHttpHeadersSuccessfulValidation() throws InterruptedException {
+        final AtomicReference<HttpMethod> httpMethodReference = new AtomicReference<>();
+        final AtomicReference<String> urlReference = new AtomicReference<>();
+        final AtomicReference<String> requestHeaderReference = new AtomicReference<>();
+        final AtomicReference<String> requestHeaderValueReference = new AtomicReference<>();
+        final AtomicReference<String> contextHeaderReference = new AtomicReference<>();
+        final AtomicReference<String> contextHeaderValueReference = new AtomicReference<>();
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                assertThat(request.getHttpRequest().uri(), is(urlReference.get()));
+                assertThat(request.getHttpRequest().header(requestHeaderReference.get()), is(requestHeaderValueReference.get()));
+                assertThat(request.getHttpRequest().method(), is(translateRequestMethod(httpMethodReference.get())));
+                // validation context is restored
+                assertThat(threadPool.getThreadContext().getHeader(contextHeaderReference.get()), is(contextHeaderValueReference.get()));
+                assertThat(threadPool.getThreadContext().getTransient(contextHeaderReference.get()), is(contextHeaderValueReference.get()));
+                // return some response
+                channel.sendResponse(new RestResponse(OK, RestResponse.TEXT_CONTENT_TYPE, new BytesArray("done")));
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
+                throw new AssertionError("A validated request should not dispatch as bad");
+            }
+        };
+        final HttpValidator httpValidator = (httpRequest, channel, validationListener) -> {
+            // assert that the validator sees the request unaltered
+            assertThat(httpRequest.uri(), is(urlReference.get()));
+            assertThat(httpRequest.headers().get(requestHeaderReference.get()), is(requestHeaderValueReference.get()));
+            assertThat(httpRequest.method(), is(httpMethodReference.get()));
+            // make validation alter the thread context
+            contextHeaderReference.set(randomAlphaOfLengthBetween(4, 8));
+            contextHeaderValueReference.set(randomAlphaOfLengthBetween(4, 8));
+            threadPool.getThreadContext().putHeader(contextHeaderReference.get(), contextHeaderValueReference.get());
+            threadPool.getThreadContext().putTransient(contextHeaderReference.get(), contextHeaderValueReference.get());
+            // validate successfully
+            validationListener.onResponse(null);
+        };
+        try (
+            Netty4HttpServerTransport transport = getTestNetty4HttpServerTransport(
+                dispatcher,
+                httpValidator,
+                (restRequest, threadContext) -> {
+                    // assert the thread context does not yet contain anything that validation set in
+                    assertThat(threadPool.getThreadContext().getHeader(contextHeaderReference.get()), nullValue());
+                    assertThat(threadPool.getThreadContext().getTransient(contextHeaderReference.get()), nullValue());
+                    ThreadContext.StoredContext storedAuthenticatedContext = HttpHeadersAuthenticatorUtils.extractAuthenticationContext(
+                        restRequest.getHttpRequest()
+                    );
+                    assertThat(storedAuthenticatedContext, notNullValue());
+                    // restore validation context
+                    storedAuthenticatedContext.restore();
+                    // assert that now, after restoring the validation context, it does contain what validation put in
+                    assertThat(
+                        threadPool.getThreadContext().getHeader(contextHeaderReference.get()),
+                        is(contextHeaderValueReference.get())
+                    );
+                    assertThat(
+                        threadPool.getThreadContext().getTransient(contextHeaderReference.get()),
+                        is(contextHeaderValueReference.get())
+                    );
+                }
+            )
+        ) {
+            transport.start();
+            final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses());
+            for (HttpMethod httpMethod : List.of(HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE, HttpMethod.PATCH)) {
+                httpMethodReference.set(httpMethod);
+                urlReference.set(
+                    "/"
+                        + randomAlphaOfLengthBetween(4, 8)
+                        + "?X-"
+                        + randomAlphaOfLengthBetween(4, 8)
+                        + "="
+                        + randomAlphaOfLengthBetween(4, 8)
+                );
+                requestHeaderReference.set("X-" + randomAlphaOfLengthBetween(4, 8));
+                requestHeaderValueReference.set(randomAlphaOfLengthBetween(4, 8));
+                try (Netty4HttpClient client = new Netty4HttpClient()) {
+                    FullHttpRequest request = new DefaultFullHttpRequest(
+                        HttpVersion.HTTP_1_1,
+                        httpMethodReference.get(),
+                        urlReference.get()
+                    );
+                    request.headers().set(requestHeaderReference.get(), requestHeaderValueReference.get());
+                    FullHttpResponse response = client.send(remoteAddress.address(), request);
+                    assertThat(response.status(), is(HttpResponseStatus.OK));
+                }
+            }
+        }
+    }
+
+    public void testHttpHeadersFailedValidation() throws InterruptedException {
+        final AtomicReference<HttpMethod> httpMethodReference = new AtomicReference<>();
+        final AtomicReference<String> urlReference = new AtomicReference<>();
+        final AtomicReference<String> headerReference = new AtomicReference<>();
+        final AtomicReference<String> headerValueReference = new AtomicReference<>();
+        final AtomicReference<Exception> validationResultExceptionReference = new AtomicReference<>();
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                throw new AssertionError("Request that failed validation should not be dispatched");
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
+                assertThat(cause, instanceOf(HttpHeadersValidationException.class));
+                assertThat(((ElasticsearchWrapperException) cause).getCause(), is(validationResultExceptionReference.get()));
+                assertThat(channel.request().getHttpRequest().uri(), is(urlReference.get()));
+                assertThat(channel.request().getHttpRequest().header(headerReference.get()), is(headerValueReference.get()));
+                assertThat(channel.request().getHttpRequest().method(), is(translateRequestMethod(httpMethodReference.get())));
+                try {
+                    channel.sendResponse(new RestResponse(channel, (Exception) ((ElasticsearchWrapperException) cause).getCause()));
+                } catch (IOException e) {
+                    throw new AssertionError(e);
+                }
+            }
+        };
+        final HttpValidator failureHeadersValidator = (httpRequest, channel, validationResultListener) -> {
+            // assert that the validator sees the request unaltered
+            assertThat(httpRequest.uri(), is(urlReference.get()));
+            assertThat(httpRequest.headers().get(headerReference.get()), is(headerValueReference.get()));
+            assertThat(httpRequest.method(), is(httpMethodReference.get()));
+            // failed validation
+            validationResultListener.onFailure(validationResultExceptionReference.get());
+        };
+        try (
+            Netty4HttpServerTransport transport = getTestNetty4HttpServerTransport(
+                dispatcher,
+                failureHeadersValidator,
+                (restRequest, threadContext) -> {
+                    throw new AssertionError("Request that failed validation should not be dispatched");
+                }
+            )
+        ) {
+            transport.start();
+            final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses());
+            for (HttpMethod httpMethod : List.of(HttpMethod.GET, HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE, HttpMethod.PATCH)) {
+                httpMethodReference.set(httpMethod);
+                urlReference.set(
+                    "/"
+                        + randomAlphaOfLengthBetween(4, 8)
+                        + "?X-"
+                        + randomAlphaOfLengthBetween(4, 8)
+                        + "="
+                        + randomAlphaOfLengthBetween(4, 8)
+                );
+                validationResultExceptionReference.set(new ElasticsearchSecurityException("Boom", UNAUTHORIZED));
+                try (Netty4HttpClient client = new Netty4HttpClient()) {
+                    FullHttpRequest request = new DefaultFullHttpRequest(
+                        HttpVersion.HTTP_1_1,
+                        httpMethodReference.get(),
+                        urlReference.get()
+                    );
+                    // submit the request with some header custom header
+                    headerReference.set("X-" + randomAlphaOfLengthBetween(4, 8));
+                    headerValueReference.set(randomAlphaOfLengthBetween(4, 8));
+                    request.headers().set(headerReference.get(), headerValueReference.get());
+                    FullHttpResponse response = client.send(remoteAddress.address(), request);
+                    assertThat(response.status(), is(HttpResponseStatus.UNAUTHORIZED));
+                }
+            }
+        }
+    }
+
+    public void testMultipleValidationsOnTheSameChannel() throws InterruptedException {
+        // ensure that there is a single channel active
+        final Settings settings = createBuilderWithPort().put(Netty4HttpServerTransport.SETTING_HTTP_WORKER_COUNT.getKey(), 1).build();
+        final Set<String> okURIs = ConcurrentHashMap.newKeySet();
+        final Set<String> nokURIs = ConcurrentHashMap.newKeySet();
+        final SetOnce<Channel> channelSetOnce = new SetOnce<>();
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                assertThat(okURIs.contains(request.uri()), is(true));
+                // assert validated request is dispatched
+                okURIs.remove(request.uri());
+                channel.sendResponse(new RestResponse(OK, RestResponse.TEXT_CONTENT_TYPE, new BytesArray("dispatch OK")));
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
+                // assert unvalidated request is NOT dispatched
+                assertThat(nokURIs.contains(channel.request().uri()), is(true));
+                nokURIs.remove(channel.request().uri());
+                try {
+                    channel.sendResponse(new RestResponse(channel, (Exception) ((ElasticsearchWrapperException) cause).getCause()));
+                } catch (IOException e) {
+                    throw new AssertionError(e);
+                }
+            }
+        };
+        final HttpValidator headersValidator = (httpPreRequest, channel, validationListener) -> {
+            // assert all validations run on the same channel
+            channelSetOnce.trySet(channel);
+            assertThat(channelSetOnce.get(), is(channel));
+            // some requests are validated while others are not
+            if (httpPreRequest.uri().contains("X-Auth=OK")) {
+                validationListener.onResponse(null);
+            } else if (httpPreRequest.uri().contains("X-Auth=NOK")) {
+                validationListener.onFailure(new ElasticsearchSecurityException("Boom", UNAUTHORIZED));
+            } else {
+                throw new AssertionError("Unrecognized URI");
+            }
+        };
+        try (
+            Netty4HttpServerTransport transport = getTestNetty4HttpServerTransport(
+                settings,
+                dispatcher,
+                headersValidator,
+                (restRequest, threadContext) -> {}
+            )
+        ) {
+            transport.start();
+            final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses());
+            final int totalRequestCount = randomIntBetween(64, 128);
+            for (int requestId = 0; requestId < totalRequestCount; requestId++) {
+                String uri = "/" + randomAlphaOfLengthBetween(4, 8) + "?Request-Id=" + requestId;
+                if (randomBoolean()) {
+                    uri = uri + "&X-Auth=OK";
+                    okURIs.add(uri);
+                } else {
+                    uri = uri + "&X-Auth=NOK";
+                    nokURIs.add(uri);
+                }
+            }
+            List<String> allURIs = new ArrayList<>();
+            allURIs.addAll(okURIs);
+            allURIs.addAll(nokURIs);
+            Collections.shuffle(allURIs, getRandom());
+            assertThat(allURIs.size(), is(totalRequestCount));
+            try (Netty4HttpClient client = new Netty4HttpClient()) {
+                client.get(remoteAddress.address(), allURIs.toArray(new String[0]));
+                // assert all validations have been dispatched (or not) correctly
+                assertThat(okURIs.size(), is(0));
+                assertThat(nokURIs.size(), is(0));
+            }
+        }
+    }
+
+    private Netty4HttpServerTransport getTestNetty4HttpServerTransport(
+        HttpServerTransport.Dispatcher dispatcher,
+        HttpValidator httpValidator,
+        BiConsumer<RestRequest, ThreadContext> populatePerRequestContext
+    ) {
+        return getTestNetty4HttpServerTransport(createSettings(), dispatcher, httpValidator, populatePerRequestContext);
+    }
+
+    private Netty4HttpServerTransport getTestNetty4HttpServerTransport(
+        Settings settings,
+        HttpServerTransport.Dispatcher dispatcher,
+        HttpValidator httpValidator,
+        BiConsumer<RestRequest, ThreadContext> populatePerRequestContext
+    ) {
+        return new Netty4HttpServerTransport(
+            settings,
+            networkService,
+            threadPool,
+            xContentRegistry(),
+            dispatcher,
+            clusterSettings,
+            new SharedGroupFactory(settings),
+            Tracer.NOOP,
+            TLSConfig.noTLS(),
+            null,
+            httpValidator
+        ) {
+            @Override
+            protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
+                populatePerRequestContext.accept(restRequest, threadContext);
+            }
+        };
+    }
+
     private Settings createSettings() {
         return createBuilderWithPort().build();
     }
@@ -651,4 +939,20 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
     private Settings.Builder createBuilderWithPort() {
         return Settings.builder().put(HttpTransportSettings.SETTING_HTTP_PORT.getKey(), getPortRange());
     }
+
+    private static RestRequest.Method translateRequestMethod(HttpMethod httpMethod) {
+        if (httpMethod == HttpMethod.GET) return RestRequest.Method.GET;
+
+        if (httpMethod == HttpMethod.POST) return RestRequest.Method.POST;
+
+        if (httpMethod == HttpMethod.PUT) return RestRequest.Method.PUT;
+
+        if (httpMethod == HttpMethod.DELETE) return RestRequest.Method.DELETE;
+
+        if (httpMethod == HttpMethod.PATCH) {
+            return RestRequest.Method.PATCH;
+        }
+
+        throw new IllegalArgumentException("Unexpected http method: " + httpMethod);
+    }
 }

+ 7 - 1
server/src/main/java/org/elasticsearch/ElasticsearchException.java

@@ -1834,7 +1834,13 @@ public class ElasticsearchException extends RuntimeException implements ToXConte
             167,
             TransportVersion.V_8_5_0
         ),
-        DOCUMENT_PARSING_EXCEPTION(DocumentParsingException.class, DocumentParsingException::new, 168, TransportVersion.V_8_8_0);
+        DOCUMENT_PARSING_EXCEPTION(DocumentParsingException.class, DocumentParsingException::new, 168, TransportVersion.V_8_8_0),
+        HTTP_HEADERS_VALIDATION_EXCEPTION(
+            org.elasticsearch.http.HttpHeadersValidationException.class,
+            org.elasticsearch.http.HttpHeadersValidationException::new,
+            169,
+            TransportVersion.V_8_9_0
+        );
 
         final Class<? extends ElasticsearchException> exceptionClass;
         final CheckedFunction<StreamInput, ? extends ElasticsearchException, IOException> constructor;

+ 27 - 0
server/src/main/java/org/elasticsearch/http/HttpHeadersValidationException.java

@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.http;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchWrapperException;
+import org.elasticsearch.common.io.stream.StreamInput;
+
+import java.io.IOException;
+
+public final class HttpHeadersValidationException extends ElasticsearchException implements ElasticsearchWrapperException {
+
+    public HttpHeadersValidationException(Exception cause) {
+        super(cause);
+    }
+
+    public HttpHeadersValidationException(StreamInput in) throws IOException {
+        super(in);
+    }
+
+}

+ 8 - 1
server/src/main/java/org/elasticsearch/rest/RestController.java

@@ -25,6 +25,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.RestApiVersion;
 import org.elasticsearch.core.Streams;
+import org.elasticsearch.http.HttpHeadersValidationException;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.rest.RestHandler.Route;
@@ -331,7 +332,13 @@ public class RestController implements HttpServerTransport.Dispatcher {
             } else {
                 e = new ElasticsearchException(cause);
             }
-            channel.sendResponse(new RestResponse(channel, BAD_REQUEST, e));
+            // unless it's a http headers validation error, we consider any exceptions encountered so far during request processing
+            // to be a problem of invalid/malformed request (hence the RestStatus#BAD_REQEST (400) HTTP response code)
+            if (e instanceof HttpHeadersValidationException) {
+                channel.sendResponse(new RestResponse(channel, (Exception) e.getCause()));
+            } else {
+                channel.sendResponse(new RestResponse(channel, BAD_REQUEST, e));
+            }
         } catch (final IOException e) {
             if (cause != null) {
                 e.addSuppressed(cause);

+ 2 - 0
server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java

@@ -48,6 +48,7 @@ import org.elasticsearch.core.PathUtils;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.env.ShardLockObtainFailedException;
 import org.elasticsearch.health.node.action.HealthNodeNotDiscoveredException;
+import org.elasticsearch.http.HttpHeadersValidationException;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.engine.RecoveryEngineException;
 import org.elasticsearch.index.mapper.DocumentParsingException;
@@ -827,6 +828,7 @@ public class ExceptionSerializationTests extends ESTestCase {
         ids.put(166, HealthNodeNotDiscoveredException.class);
         ids.put(167, UnsupportedAggregationOnDownsampledIndex.class);
         ids.put(168, DocumentParsingException.class);
+        ids.put(169, HttpHeadersValidationException.class);
 
         Map<Class<? extends ElasticsearchException>, Integer> reverse = new HashMap<>();
         for (Map.Entry<Integer, Class<? extends ElasticsearchException>> entry : ids.entrySet()) {

+ 42 - 0
server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java

@@ -27,6 +27,7 @@ import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.ByteArray;
 import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.rest.ChunkedRestResponseBody;
@@ -56,12 +57,14 @@ import java.nio.charset.StandardCharsets;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.ArgumentMatchers.any;
@@ -322,6 +325,45 @@ public class DefaultRestChannelTests extends ESTestCase {
         }
     }
 
+    public void testResponseHeadersFiltering() {
+        final HttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        final RestRequest request = RestRequest.request(parserConfig(), httpRequest, httpChannel);
+        final AtomicReference<HttpResponse> responseReference = new AtomicReference<>();
+        final DefaultRestChannel channel = new DefaultRestChannel(
+            httpChannel,
+            httpRequest,
+            request,
+            bigArrays,
+            HttpHandlingSettings.fromSettings(Settings.EMPTY),
+            threadPool.getThreadContext(),
+            CorsHandler.fromSettings(Settings.EMPTY),
+            httpTracer,
+            tracer
+        );
+        doAnswer(invocationOnMock -> {
+            ActionListener<?> listener = invocationOnMock.getArgument(1);
+            listener.onResponse(null);
+            HttpResponse response = invocationOnMock.getArgument(0);
+            responseReference.set(response);
+            return null;
+        }).when(httpChannel).sendResponse(any(HttpResponse.class), anyActionListener());
+        for (RestResponse response : List.of(
+            new RestResponse(RestStatus.UNAUTHORIZED, "whatever"),
+            new RestResponse(RestStatus.FORBIDDEN, "whatever")
+        )) {
+            try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().newStoredContext()) {
+                threadPool.getThreadContext().addResponseHeader("X-elastic-product", "some product response header");
+                threadPool.getThreadContext().addResponseHeader("Warning", "some product response header");
+                String someRandomResponseHeader = "some-random-response-header-" + randomAlphaOfLength(8);
+                threadPool.getThreadContext().addResponseHeader(someRandomResponseHeader, "should transpire to http response");
+                channel.sendResponse(response);
+                assertThat(responseReference.get().containsHeader(someRandomResponseHeader), is(true));
+                assertThat(responseReference.get().containsHeader("X-elastic-product"), is(false));
+                assertThat(responseReference.get().containsHeader("Warning"), is(false));
+            }
+        }
+    }
+
     public void testUnsupportedHttpMethod() {
         final boolean close = randomBoolean();
         final HttpRequest.HttpVersion httpVersion = close ? HttpRequest.HttpVersion.HTTP_1_0 : HttpRequest.HttpVersion.HTTP_1_1;

+ 22 - 0
server/src/test/java/org/elasticsearch/rest/RestControllerTests.java

@@ -8,6 +8,7 @@
 
 package org.elasticsearch.rest;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.bytes.BytesArray;
@@ -24,6 +25,7 @@ import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.RestApiVersion;
+import org.elasticsearch.http.HttpHeadersValidationException;
 import org.elasticsearch.http.HttpInfo;
 import org.elasticsearch.http.HttpRequest;
 import org.elasticsearch.http.HttpResponse;
@@ -620,6 +622,26 @@ public class RestControllerTests extends ESTestCase {
         assertThat(channel.getRestResponse().content().utf8ToString(), containsString("unknown cause"));
     }
 
+    public void testDispatchBadRequestWithValidationException() {
+        final RestStatus status = randomFrom(RestStatus.values());
+        final Exception exception = new ElasticsearchStatusException("bad bad exception", status);
+        final FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
+
+        // it's always a 400 bad request when dispatching "regular" {@code ElasticsearchException}
+        AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.BAD_REQUEST);
+        assertFalse(channel.getSendResponseCalled());
+        restController.dispatchBadRequest(channel, client.threadPool().getThreadContext(), exception);
+        assertTrue(channel.getSendResponseCalled());
+        assertThat(channel.getRestResponse().content().utf8ToString(), containsString("bad bad exception"));
+
+        // but {@code HttpHeadersValidationException} do carry over the rest response code
+        channel = new AssertingChannel(fakeRestRequest, true, status);
+        assertFalse(channel.getSendResponseCalled());
+        restController.dispatchBadRequest(channel, client.threadPool().getThreadContext(), new HttpHeadersValidationException(exception));
+        assertTrue(channel.getSendResponseCalled());
+        assertThat(channel.getRestResponse().content().utf8ToString(), containsString("bad bad exception"));
+    }
+
     public void testFavicon() {
         final FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withMethod(GET)
             .withPath("/favicon.ico")

+ 45 - 0
server/src/test/java/org/elasticsearch/rest/RestResponseTests.java

@@ -9,6 +9,7 @@
 package org.elasticsearch.rest;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ResourceAlreadyExistsException;
@@ -26,6 +27,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.rest.FakeRestRequest;
 import org.elasticsearch.transport.RemoteTransportException;
 import org.elasticsearch.xcontent.MediaType;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentType;
@@ -40,6 +42,7 @@ import static org.elasticsearch.ElasticsearchExceptionTests.assertDeepEquals;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.notNullValue;
 
@@ -118,6 +121,48 @@ public class RestResponseTests extends ESTestCase {
             "stack_trace":"org.elasticsearch.ElasticsearchException$1: an error occurred reading data"""));
     }
 
+    public void testAuthenticationFailedNoStackTrace() throws IOException {
+        for (Exception authnException : List.of(
+            new ElasticsearchSecurityException("failed authn", RestStatus.UNAUTHORIZED),
+            new ElasticsearchSecurityException("failed authn", RestStatus.UNAUTHORIZED, new ElasticsearchException("cause"))
+        )) {
+            for (RestRequest request : List.of(
+                new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(),
+                new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(Map.of("error_trace", Boolean.toString(true))).build()
+            )) {
+                for (RestChannel channel : List.of(new SimpleExceptionRestChannel(request), new DetailedExceptionRestChannel(request))) {
+                    RestResponse response = new RestResponse(channel, authnException);
+                    assertThat(response.status(), is(RestStatus.UNAUTHORIZED));
+                    assertThat(response.content().utf8ToString(), not(containsString(ElasticsearchException.STACK_TRACE)));
+                }
+            }
+        }
+    }
+
+    public void testStackTrace() throws IOException {
+        for (Exception exception : List.of(new ElasticsearchException("dummy"), new IllegalArgumentException("dummy"))) {
+            for (RestRequest request : List.of(
+                new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(),
+                new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(Map.of("error_trace", Boolean.toString(true))).build()
+            )) {
+                for (RestChannel channel : List.of(new SimpleExceptionRestChannel(request), new DetailedExceptionRestChannel(request))) {
+                    RestResponse response = new RestResponse(channel, exception);
+                    if (exception instanceof ElasticsearchException) {
+                        assertThat(response.status(), is(RestStatus.INTERNAL_SERVER_ERROR));
+                    } else {
+                        assertThat(response.status(), is(RestStatus.BAD_REQUEST));
+                    }
+                    boolean traceExists = request.paramAsBoolean("error_trace", false) && channel.detailedErrorsEnabled();
+                    if (traceExists) {
+                        assertThat(response.content().utf8ToString(), containsString(ElasticsearchException.STACK_TRACE));
+                    } else {
+                        assertThat(response.content().utf8ToString(), not(containsString(ElasticsearchException.STACK_TRACE)));
+                    }
+                }
+            }
+        }
+    }
+
     public void testGuessRootCause() throws IOException {
         RestRequest request = new FakeRestRequest();
         {

+ 1 - 0
x-pack/plugin/security/src/main/java/module-info.java

@@ -40,6 +40,7 @@ module org.elasticsearch.security {
 
     requires com.nimbusds.jose.jwt;
     requires io.netty.common;
+    requires io.netty.codec.http;
     requires io.netty.handler;
     requires io.netty.transport;
     requires jopt.simple;

+ 69 - 21
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java

@@ -6,9 +6,12 @@
  */
 package org.elasticsearch.xpack.security;
 
+import io.netty.channel.Channel;
+
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequest;
@@ -49,11 +52,11 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.env.NodeMetadata;
-import org.elasticsearch.http.HttpChannel;
 import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpServerTransport;
-import org.elasticsearch.http.netty4.Netty4HttpHeaderValidator;
 import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
+import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
+import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -366,6 +369,7 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
@@ -1623,7 +1627,7 @@ public class Security extends Plugin
             final boolean ssl = HTTP_SSL_ENABLED.get(settings);
             final SSLService sslService = getSslService();
             final SslConfiguration sslConfiguration;
-            final BiConsumer<HttpChannel, ThreadContext> populateClientCertificate;
+            final BiConsumer<Channel, ThreadContext> populateClientCertificate;
             if (ssl) {
                 sslConfiguration = sslService.getHttpTransportSSLConfiguration();
                 if (SSLService.isConfigurationValidForServerUsage(sslConfiguration) == false) {
@@ -1641,7 +1645,23 @@ public class Security extends Plugin
                 sslConfiguration = null;
                 populateClientCertificate = (channel, threadContext) -> {};
             }
-            return new Netty4HttpServerTransport(
+            final AuthenticationService authenticationService = this.authcService.get();
+            final ThreadContext threadContext = this.threadContext.get();
+            final HttpValidator httpValidator = (httpRequest, channel, listener) -> {
+                HttpPreRequest httpPreRequest = HttpHeadersAuthenticatorUtils.asHttpPreRequest(httpRequest);
+                // step 1: Populate the thread context with credentials and any other HTTP request header values (eg run-as) that the
+                // authentication process looks for while doing its duty.
+                perRequestThreadContext.accept(httpPreRequest, threadContext);
+                populateClientCertificate.accept(channel, threadContext);
+                RemoteHostHeader.process(channel, threadContext);
+                // step 2: Run authentication on the now properly prepared thread-context.
+                // This inspects and modifies the thread context.
+                authenticationService.authenticate(
+                    httpPreRequest,
+                    ActionListener.wrap(ignored -> listener.onResponse(null), listener::onFailure)
+                );
+            };
+            return getHttpServerTransportWithHeadersValidator(
                 settings,
                 networkService,
                 threadPool,
@@ -1652,29 +1672,57 @@ public class Security extends Plugin
                 tracer,
                 new TLSConfig(sslConfiguration, sslService::createSSLEngine),
                 acceptPredicate,
-                Netty4HttpHeaderValidator.NOOP_VALIDATOR
-            ) {
-                @Override
-                protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
-                    perRequestThreadContext.accept(restRequest.getHttpRequest(), threadContext);
-                    populateClientCertificate.accept(restRequest.getHttpChannel(), threadContext);
-                    RemoteHostHeader.process(restRequest, threadContext);
-                }
-            };
+                httpValidator
+            );
         });
         return httpTransports;
     }
 
+    // "public" so it can be used in tests
+    public static Netty4HttpServerTransport getHttpServerTransportWithHeadersValidator(
+        Settings settings,
+        NetworkService networkService,
+        ThreadPool threadPool,
+        NamedXContentRegistry xContentRegistry,
+        HttpServerTransport.Dispatcher dispatcher,
+        ClusterSettings clusterSettings,
+        SharedGroupFactory sharedGroupFactory,
+        Tracer tracer,
+        TLSConfig tlsConfig,
+        @Nullable AcceptChannelHandler.AcceptPredicate acceptPredicate,
+        HttpValidator httpValidator
+    ) {
+        return new Netty4HttpServerTransport(
+            settings,
+            networkService,
+            threadPool,
+            xContentRegistry,
+            dispatcher,
+            clusterSettings,
+            sharedGroupFactory,
+            tracer,
+            tlsConfig,
+            acceptPredicate,
+            Objects.requireNonNull(httpValidator)
+        ) {
+            @Override
+            protected void populatePerRequestThreadContext(RestRequest restRequest, ThreadContext threadContext) {
+                ThreadContext.StoredContext authenticationThreadContext = HttpHeadersAuthenticatorUtils.extractAuthenticationContext(
+                    restRequest.getHttpRequest()
+                );
+                if (authenticationThreadContext != null) {
+                    authenticationThreadContext.restore();
+                } else {
+                    // this is an unexpected internal error condition where {@code Netty4HttpHeaderValidator} does not work correctly
+                    throw new ElasticsearchSecurityException("Request is not authenticated");
+                }
+            }
+        };
+    }
+
     @Override
     public UnaryOperator<RestHandler> getRestHandlerInterceptor(ThreadContext threadContext) {
-        return handler -> new SecurityRestFilter(
-            enabled,
-            threadContext,
-            authcService.get(),
-            secondayAuthc.get(),
-            auditTrailService.get(),
-            handler
-        );
+        return handler -> new SecurityRestFilter(enabled, threadContext, secondayAuthc.get(), auditTrailService.get(), handler);
     }
 
     @Override

+ 4 - 0
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/AuthenticationService.java

@@ -6,6 +6,8 @@
  */
 package org.elasticsearch.xpack.security.authc;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.cache.Cache;
@@ -66,6 +68,8 @@ public class AuthenticationService {
         Property.NodeScope
     );
 
+    private static final Logger logger = LogManager.getLogger(AuthenticationService.class);
+
     private final Realms realms;
     private final AuditTrailService auditTrailService;
     private final AuthenticationFailureHandler failureHandler;

+ 1 - 0
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/AuthenticatorChain.java

@@ -73,6 +73,7 @@ class AuthenticatorChain {
         // Check whether authentication is an operator user and mark the threadContext if necessary
         // before returning the authentication object
         final ActionListener<Authentication> listener = originalListener.map(authentication -> {
+            assert authentication != null;
             operatorPrivilegesService.maybeMarkOperatorUser(authentication, context.getThreadContext());
             return authentication;
         });

+ 8 - 11
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java

@@ -6,8 +6,9 @@
  */
 package org.elasticsearch.xpack.security.rest;
 
+import io.netty.channel.Channel;
+
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.rest.RestRequest;
 
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
@@ -17,16 +18,16 @@ public class RemoteHostHeader {
     static final String KEY = "_rest_remote_address";
 
     /**
-     * Extracts the remote address from the given rest request and puts in the request context. This will
-     * then be copied to the subsequent action requests.
+     * Extracts the remote address from the given netty channel and puts it in the request context. This will
+     * then be copied to the subsequent action handler contexts.
      */
-    public static void process(RestRequest request, ThreadContext threadContext) {
-        threadContext.putTransient(KEY, request.getHttpChannel().getRemoteAddress());
+    public static void process(Channel channel, ThreadContext threadContext) {
+        threadContext.putTransient(KEY, channel.remoteAddress());
     }
 
     /**
-     * Extracts the rest remote address from the message context. If not found, returns {@code null}. transport
-     * messages that were created by rest handlers, should have this in their context.
+     * Extracts the rest remote address from the message context. If not found, returns {@code null}.
+     * Transport messages that were created by rest handlers should have this in their context.
      */
     public static InetSocketAddress restRemoteAddress(ThreadContext threadContext) {
         SocketAddress address = threadContext.getTransient(KEY);
@@ -35,8 +36,4 @@ public class RemoteHostHeader {
         }
         return null;
     }
-
-    public static void putRestRemoteAddress(ThreadContext threadContext, SocketAddress address) {
-        threadContext.putTransient(KEY, address);
-    }
 }

+ 5 - 16
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java

@@ -19,7 +19,6 @@ import org.elasticsearch.rest.RestRequest.Method;
 import org.elasticsearch.rest.RestRequestFilter;
 import org.elasticsearch.rest.RestResponse;
 import org.elasticsearch.xpack.security.audit.AuditTrailService;
-import org.elasticsearch.xpack.security.authc.AuthenticationService;
 import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator;
 
 import java.util.List;
@@ -31,7 +30,6 @@ public class SecurityRestFilter implements RestHandler {
     private static final Logger logger = LogManager.getLogger(SecurityRestFilter.class);
 
     private final RestHandler restHandler;
-    private final AuthenticationService authenticationService;
     private final SecondaryAuthenticator secondaryAuthenticator;
     private final AuditTrailService auditTrailService;
     private final boolean enabled;
@@ -40,14 +38,12 @@ public class SecurityRestFilter implements RestHandler {
     public SecurityRestFilter(
         boolean enabled,
         ThreadContext threadContext,
-        AuthenticationService authenticationService,
         SecondaryAuthenticator secondaryAuthenticator,
         AuditTrailService auditTrailService,
         RestHandler restHandler
     ) {
         this.enabled = enabled;
         this.threadContext = threadContext;
-        this.authenticationService = authenticationService;
         this.secondaryAuthenticator = secondaryAuthenticator;
         this.auditTrailService = auditTrailService;
         this.restHandler = restHandler;
@@ -76,19 +72,12 @@ public class SecurityRestFilter implements RestHandler {
         }
 
         final RestRequest wrappedRequest = maybeWrapRestRequest(request);
-        authenticationService.authenticate(wrappedRequest.getHttpRequest(), ActionListener.wrap(authentication -> {
-            if (authentication == null) {
-                logger.trace("No authentication available for REST request [{}]", request.uri());
-            } else {
-                logger.trace("Authenticated REST request [{}] as {}", request.uri(), authentication);
+        auditTrailService.get().authenticationSuccess(wrappedRequest);
+        secondaryAuthenticator.authenticateAndAttachToContext(wrappedRequest, ActionListener.wrap(secondaryAuthentication -> {
+            if (secondaryAuthentication != null) {
+                logger.trace("Found secondary authentication {} in REST request [{}]", secondaryAuthentication, request.uri());
             }
-            auditTrailService.get().authenticationSuccess(wrappedRequest);
-            secondaryAuthenticator.authenticateAndAttachToContext(wrappedRequest, ActionListener.wrap(secondaryAuthentication -> {
-                if (secondaryAuthentication != null) {
-                    logger.trace("Found secondary authentication {} in REST request [{}]", secondaryAuthentication, request.uri());
-                }
-                doHandleRequest(request, channel, client);
-            }, e -> handleException(request, channel, e)));
+            doHandleRequest(request, channel, client);
         }, e -> handleException(request, channel, e)));
     }
 

+ 7 - 14
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SSLEngineUtils.java

@@ -13,8 +13,6 @@ import io.netty.handler.ssl.SslHandler;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.util.Supplier;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.http.HttpChannel;
-import org.elasticsearch.http.netty4.Netty4HttpChannel;
 import org.elasticsearch.transport.TcpChannel;
 import org.elasticsearch.transport.netty4.Netty4TcpChannel;
 import org.elasticsearch.xpack.security.authc.pki.PkiRealm;
@@ -29,9 +27,9 @@ public class SSLEngineUtils {
 
     private SSLEngineUtils() {}
 
-    public static void extractClientCertificates(Logger logger, ThreadContext threadContext, HttpChannel httpChannel) {
-        SSLEngine sslEngine = getSSLEngine(httpChannel);
-        extract(logger, threadContext, sslEngine, httpChannel);
+    public static void extractClientCertificates(Logger logger, ThreadContext threadContext, Channel channel) {
+        SSLEngine sslEngine = getSSLEngine(channel);
+        extract(logger, threadContext, sslEngine, channel);
     }
 
     public static void extractClientCertificates(Logger logger, ThreadContext threadContext, TcpChannel tcpChannel) {
@@ -39,15 +37,10 @@ public class SSLEngineUtils {
         extract(logger, threadContext, sslEngine, tcpChannel);
     }
 
-    public static SSLEngine getSSLEngine(HttpChannel httpChannel) {
-        if (httpChannel instanceof Netty4HttpChannel) {
-            Channel nettyChannel = ((Netty4HttpChannel) httpChannel).getNettyChannel();
-            SslHandler handler = nettyChannel.pipeline().get(SslHandler.class);
-            assert handler != null : "Must have SslHandler";
-            return handler.engine();
-        } else {
-            throw new AssertionError("Unknown channel class type: " + httpChannel.getClass());
-        }
+    public static SSLEngine getSSLEngine(Channel channel) {
+        SslHandler handler = channel.pipeline().get(SslHandler.class);
+        assert handler != null : "Must have SslHandler";
+        return handler.engine();
     }
 
     public static SSLEngine getSSLEngine(TcpChannel tcpChannel) {

+ 5 - 1
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailFilterTests.java

@@ -6,6 +6,8 @@
  */
 package org.elasticsearch.xpack.security.audit.logfile;
 
+import io.netty.channel.Channel;
+
 import org.apache.logging.log4j.Level;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.TransportVersion;
@@ -3009,7 +3011,9 @@ public class LoggingAuditTrailFilterTests extends ESTestCase {
                 remoteAddress(buildNewFakeTransportAddress().address());
             }
             if (randomBoolean()) {
-                RemoteHostHeader.putRestRemoteAddress(threadContext, new InetSocketAddress(forge("localhost", "127.0.0.1"), 1234));
+                Channel mockChannel = mock(Channel.class);
+                when(mockChannel.remoteAddress()).thenReturn(new InetSocketAddress(forge("localhost", "127.0.0.1"), 1234));
+                RemoteHostHeader.process(mockChannel, threadContext);
             }
         }
 

+ 33 - 31
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java

@@ -6,6 +6,8 @@
  */
 package org.elasticsearch.xpack.security.audit.logfile;
 
+import io.netty.channel.Channel;
+
 import org.apache.logging.log4j.Level;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.core.layout.PatternLayout;
@@ -1551,10 +1553,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
             forge("_hostname", randomBoolean() ? "127.0.0.1" : "::1"),
             randomIntBetween(9200, 9300)
         );
-        final Tuple<RestContent, RestRequest> tuple = prepareRestContent("_uri", address);
-        final String expectedMessage = tuple.v1().expectedMessage();
+        final Tuple<Channel, RestRequest> tuple = prepareRestContent("_uri", address);
         final RestRequest request = tuple.v2();
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
 
         final String requestId = randomRequestId();
         auditTrail.anonymousAccessDenied(requestId, request.getHttpRequest());
@@ -1650,11 +1651,10 @@ public class LoggingAuditTrailTests extends ESTestCase {
             forge("_hostname", randomBoolean() ? "127.0.0.1" : "::1"),
             randomIntBetween(9200, 9300)
         );
-        final Tuple<RestContent, RestRequest> tuple = prepareRestContent("_uri", address, params);
-        final String expectedMessage = tuple.v1().expectedMessage();
+        final Tuple<Channel, RestRequest> tuple = prepareRestContent("_uri", address, params);
         final RestRequest request = tuple.v2();
         final AuthenticationToken authToken = createAuthenticationToken();
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
 
         final String requestId = randomRequestId();
         auditTrail.authenticationFailed(requestId, authToken, request.getHttpRequest());
@@ -1696,10 +1696,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
             forge("_hostname", randomBoolean() ? "127.0.0.1" : "::1"),
             randomIntBetween(9200, 9300)
         );
-        final Tuple<RestContent, RestRequest> tuple = prepareRestContent("_uri", address, params);
-        final String expectedMessage = tuple.v1().expectedMessage();
+        final Tuple<Channel, RestRequest> tuple = prepareRestContent("_uri", address, params);
         final RestRequest request = tuple.v2();
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
 
         final String requestId = randomRequestId();
         auditTrail.authenticationFailed(requestId, request.getHttpRequest());
@@ -1767,12 +1766,11 @@ public class LoggingAuditTrailTests extends ESTestCase {
             forge("_hostname", randomBoolean() ? "127.0.0.1" : "::1"),
             randomIntBetween(9200, 9300)
         );
-        final Tuple<RestContent, RestRequest> tuple = prepareRestContent("_uri", address, params);
-        final String expectedMessage = tuple.v1().expectedMessage();
+        final Tuple<Channel, RestRequest> tuple = prepareRestContent("_uri", address, params);
         final RestRequest request = tuple.v2();
         final AuthenticationToken authToken = mockToken();
         final String realm = randomAlphaOfLengthBetween(1, 6);
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
         final String requestId = randomRequestId();
         auditTrail.authenticationFailed(requestId, realm, authToken, request.getHttpRequest());
         assertEmptyLog(logger);
@@ -2153,10 +2151,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
             forge("_hostname", randomBoolean() ? "127.0.0.1" : "::1"),
             randomIntBetween(9200, 9300)
         );
-        final Tuple<RestContent, RestRequest> tuple = prepareRestContent("_uri", address, params);
-        final String expectedMessage = tuple.v1().expectedMessage();
+        final Tuple<Channel, RestRequest> tuple = prepareRestContent("_uri", address, params);
         final RestRequest request = tuple.v2();
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
         final String requestId = randomRequestId();
         auditTrail.tamperedRequest(requestId, request.getHttpRequest());
         final MapBuilder<String, String> checkedFields = new MapBuilder<>(commonFields);
@@ -2431,14 +2428,13 @@ public class LoggingAuditTrailTests extends ESTestCase {
             forge("_hostname", randomBoolean() ? "127.0.0.1" : "::1"),
             randomIntBetween(9200, 9300)
         );
-        final Tuple<RestContent, RestRequest> tuple = prepareRestContent("_uri", address, params);
-        final String expectedMessage = tuple.v1().expectedMessage();
+        final Tuple<Channel, RestRequest> tuple = prepareRestContent("_uri", address, params);
         final RestRequest request = tuple.v2();
         String requestId = AuditUtil.generateRequestId(threadContext);
         MapBuilder<String, String> checkedFields = new MapBuilder<>(commonFields);
         Authentication authentication = createAuthentication();
         authentication.writeToContext(threadContext);
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
 
         // event by default disabled
         auditTrail.authenticationSuccess(request);
@@ -2456,8 +2452,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
             .put(LoggingAuditTrail.REQUEST_METHOD_FIELD_NAME, request.method().toString())
             .put(LoggingAuditTrail.REQUEST_ID_FIELD_NAME, requestId)
             .put(LoggingAuditTrail.URL_PATH_FIELD_NAME, "_uri");
-        if (includeRequestBody && Strings.hasLength(expectedMessage)) {
-            checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, expectedMessage);
+        if (includeRequestBody && Strings.hasLength(request.content())) {
+            checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, request.content().utf8ToString());
         }
         if (params.isEmpty() == false) {
             checkedFields.put(LoggingAuditTrail.URL_QUERY_FIELD_NAME, "foo=bar&evac=true");
@@ -2475,7 +2471,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
         authentication = createApiKeyAuthenticationAndMaybeWithRunAs(authentication);
         authentication.writeToContext(threadContext);
         checkedFields = new MapBuilder<>(commonFields);
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
         auditTrail.authenticationSuccess(request);
         checkedFields.put(LoggingAuditTrail.EVENT_TYPE_FIELD_NAME, LoggingAuditTrail.REST_ORIGIN_FIELD_VALUE)
             .put(LoggingAuditTrail.EVENT_ACTION_FIELD_NAME, "authentication_success")
@@ -2485,8 +2481,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
             .put(LoggingAuditTrail.REQUEST_METHOD_FIELD_NAME, request.method().toString())
             .put(LoggingAuditTrail.REQUEST_ID_FIELD_NAME, requestId)
             .put(LoggingAuditTrail.URL_PATH_FIELD_NAME, "_uri");
-        if (includeRequestBody && Strings.hasLength(expectedMessage)) {
-            checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, expectedMessage);
+        if (includeRequestBody && Strings.hasLength(request.content())) {
+            checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, request.getHttpRequest().content().utf8ToString());
         }
         if (params.isEmpty() == false) {
             checkedFields.put(LoggingAuditTrail.URL_QUERY_FIELD_NAME, "foo=bar&evac=true");
@@ -2504,7 +2500,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
         authentication = AuthenticationTestHelper.builder().realm().build(false).runAs(new User(randomAlphaOfLengthBetween(3, 8)), null);
         authentication.writeToContext(threadContext);
         checkedFields = new MapBuilder<>(commonFields);
-        RemoteHostHeader.process(request, threadContext);
+        RemoteHostHeader.process(tuple.v1(), threadContext);
         auditTrail.authenticationSuccess(request);
         checkedFields.put(LoggingAuditTrail.EVENT_TYPE_FIELD_NAME, LoggingAuditTrail.REST_ORIGIN_FIELD_VALUE)
             .put(LoggingAuditTrail.EVENT_ACTION_FIELD_NAME, "authentication_success")
@@ -2514,8 +2510,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
             .put(LoggingAuditTrail.REQUEST_METHOD_FIELD_NAME, request.method().toString())
             .put(LoggingAuditTrail.REQUEST_ID_FIELD_NAME, requestId)
             .put(LoggingAuditTrail.URL_PATH_FIELD_NAME, "_uri");
-        if (includeRequestBody && Strings.hasLength(expectedMessage)) {
-            checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, expectedMessage);
+        if (includeRequestBody && Strings.hasLength(request.content().utf8ToString())) {
+            checkedFields.put(LoggingAuditTrail.REQUEST_BODY_FIELD_NAME, request.content().utf8ToString());
         }
         if (params.isEmpty() == false) {
             checkedFields.put(LoggingAuditTrail.URL_QUERY_FIELD_NAME, "foo=bar&evac=true");
@@ -2817,11 +2813,11 @@ public class LoggingAuditTrailTests extends ESTestCase {
         assertThat("Logger is not empty", CapturingLogger.isEmpty(logger.getName()), is(true));
     }
 
-    protected Tuple<RestContent, RestRequest> prepareRestContent(String uri, InetSocketAddress remoteAddress) {
+    protected Tuple<Channel, RestRequest> prepareRestContent(String uri, InetSocketAddress remoteAddress) {
         return prepareRestContent(uri, remoteAddress, Collections.emptyMap());
     }
 
-    private Tuple<RestContent, RestRequest> prepareRestContent(String uri, InetSocketAddress remoteAddress, Map<String, String> params) {
+    private Tuple<Channel, RestRequest> prepareRestContent(String uri, InetSocketAddress remoteAddress, Map<String, String> params) {
         final RestContent content = randomFrom(RestContent.values());
         final FakeRestRequest.Builder builder = new Builder(NamedXContentRegistry.EMPTY);
         if (content.hasContent()) {
@@ -2842,7 +2838,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
         builder.withRemoteAddress(remoteAddress);
         builder.withParams(params);
         builder.withMethod(randomFrom(RestRequest.Method.values()));
-        return new Tuple<>(content, builder.build());
+        Channel channel = mock(Channel.class);
+        when(channel.remoteAddress()).thenReturn(remoteAddress);
+        return new Tuple<>(channel, builder.build());
     }
 
     /** creates address without any lookups. hostname can be null, for missing */
@@ -2912,7 +2910,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
                 }
             }
             if (randomBoolean()) {
-                RemoteHostHeader.putRestRemoteAddress(threadContext, new InetSocketAddress(forge("localhost", "127.0.0.1"), 1234));
+                Channel mockChannel = mock(Channel.class);
+                when(mockChannel.remoteAddress()).thenReturn(new InetSocketAddress(forge("localhost", "127.0.0.1"), 1234));
+                RemoteHostHeader.process(mockChannel, threadContext);
             }
         }
 
@@ -2931,7 +2931,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
                 remoteAddress(buildNewFakeTransportAddress().address());
             }
             if (randomBoolean()) {
-                RemoteHostHeader.putRestRemoteAddress(threadContext, new InetSocketAddress(forge("localhost", "127.0.0.1"), 1234));
+                Channel mockChannel = mock(Channel.class);
+                when(mockChannel.remoteAddress()).thenReturn(new InetSocketAddress(forge("localhost", "127.0.0.1"), 1234));
+                RemoteHostHeader.process(mockChannel, threadContext);
             }
         }
 

+ 2 - 125
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java

@@ -9,14 +9,12 @@ package org.elasticsearch.xpack.security.rest;
 import com.nimbusds.jose.util.StandardCharset;
 
 import org.apache.lucene.util.SetOnce;
-import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.http.HttpChannel;
-import org.elasticsearch.http.HttpPreRequest;
 import org.elasticsearch.http.HttpRequest;
 import org.elasticsearch.license.TestUtils;
 import org.elasticsearch.license.XPackLicenseState;
@@ -24,18 +22,14 @@ import org.elasticsearch.rest.RestChannel;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.RestRequestFilter;
-import org.elasticsearch.rest.RestResponse;
-import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.SecuritySettingsSourceField;
 import org.elasticsearch.test.rest.FakeRestRequest;
 import org.elasticsearch.xcontent.DeprecationHandler;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.xpack.core.security.SecurityContext;
 import org.elasticsearch.xpack.core.security.authc.Authentication;
-import org.elasticsearch.xpack.core.security.authc.Authentication.RealmRef;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
 import org.elasticsearch.xpack.core.security.authc.support.SecondaryAuthentication;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
@@ -44,7 +38,6 @@ import org.elasticsearch.xpack.security.audit.AuditTrailService;
 import org.elasticsearch.xpack.security.authc.AuthenticationService;
 import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator;
 import org.junit.Before;
-import org.mockito.ArgumentCaptor;
 
 import java.util.Base64;
 import java.util.Collections;
@@ -54,10 +47,7 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
-import static org.elasticsearch.xpack.core.security.support.Exceptions.authenticationError;
-import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasItem;
-import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.sameInstance;
@@ -85,14 +75,7 @@ public class SecurityRestFilterTests extends ESTestCase {
         restHandler = mock(RestHandler.class);
         threadContext = new ThreadContext(Settings.EMPTY);
         secondaryAuthenticator = new SecondaryAuthenticator(Settings.EMPTY, threadContext, authcService, new AuditTrailService(null, null));
-        filter = new SecurityRestFilter(
-            true,
-            threadContext,
-            authcService,
-            secondaryAuthenticator,
-            new AuditTrailService(null, null),
-            restHandler
-        );
+        filter = new SecurityRestFilter(true, threadContext, secondaryAuthenticator, new AuditTrailService(null, null), restHandler);
     }
 
     public void testProcess() throws Exception {
@@ -158,109 +141,13 @@ public class SecurityRestFilterTests extends ESTestCase {
     }
 
     public void testProcessWithSecurityDisabled() throws Exception {
-        filter = new SecurityRestFilter(
-            false,
-            threadContext,
-            authcService,
-            secondaryAuthenticator,
-            mock(AuditTrailService.class),
-            restHandler
-        );
+        filter = new SecurityRestFilter(false, threadContext, secondaryAuthenticator, mock(AuditTrailService.class), restHandler);
         RestRequest request = mock(RestRequest.class);
         filter.handleRequest(request, channel, null);
         verify(restHandler).handleRequest(request, channel, null);
         verifyNoMoreInteractions(channel, authcService);
     }
 
-    public void testProcessAuthenticationFailedNoTrace() throws Exception {
-        filter = new SecurityRestFilter(
-            true,
-            threadContext,
-            authcService,
-            secondaryAuthenticator,
-            mock(AuditTrailService.class),
-            restHandler
-        );
-        testProcessAuthenticationFailed(
-            randomBoolean()
-                ? authenticationError("failed authn")
-                : authenticationError("failed authn with " + "cause", new ElasticsearchException("cause")),
-            RestStatus.UNAUTHORIZED,
-            true,
-            true,
-            false
-        );
-        testProcessAuthenticationFailed(
-            randomBoolean()
-                ? authenticationError("failed authn")
-                : authenticationError("failed authn with " + "cause", new ElasticsearchException("cause")),
-            RestStatus.UNAUTHORIZED,
-            true,
-            false,
-            false
-        );
-        testProcessAuthenticationFailed(
-            randomBoolean()
-                ? authenticationError("failed authn")
-                : authenticationError("failed authn with " + "cause", new ElasticsearchException("cause")),
-            RestStatus.UNAUTHORIZED,
-            false,
-            true,
-            false
-        );
-        testProcessAuthenticationFailed(
-            randomBoolean()
-                ? authenticationError("failed authn")
-                : authenticationError("failed authn with " + "cause", new ElasticsearchException("cause")),
-            RestStatus.UNAUTHORIZED,
-            false,
-            false,
-            false
-        );
-        testProcessAuthenticationFailed(new ElasticsearchException("dummy"), RestStatus.INTERNAL_SERVER_ERROR, false, false, false);
-        testProcessAuthenticationFailed(new IllegalArgumentException("dummy"), RestStatus.BAD_REQUEST, true, false, false);
-        testProcessAuthenticationFailed(new ElasticsearchException("dummy"), RestStatus.INTERNAL_SERVER_ERROR, false, true, false);
-        testProcessAuthenticationFailed(new IllegalArgumentException("dummy"), RestStatus.BAD_REQUEST, true, true, true);
-    }
-
-    private void testProcessAuthenticationFailed(
-        Exception authnException,
-        RestStatus expectedRestStatus,
-        boolean errorTrace,
-        boolean detailedErrorsEnabled,
-        boolean traceExists
-    ) throws Exception {
-        RestRequest request;
-        if (errorTrace != ElasticsearchException.REST_EXCEPTION_SKIP_STACK_TRACE_DEFAULT == false || randomBoolean()) {
-            request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(
-                Map.of("error_trace", Boolean.toString(errorTrace))
-            ).build();
-        } else {
-            // sometimes do not fill in the default value
-            request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
-        }
-        doAnswer((i) -> {
-            ActionListener<?> callback = (ActionListener<?>) i.getArguments()[1];
-            callback.onFailure(authnException);
-            return Void.TYPE;
-        }).when(authcService).authenticate(eq(request.getHttpRequest()), anyActionListener());
-        RestChannel channel = mock(RestChannel.class);
-        when(channel.detailedErrorsEnabled()).thenReturn(detailedErrorsEnabled);
-        when(channel.request()).thenReturn(request);
-        when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder());
-        filter.handleRequest(request, channel, null);
-        ArgumentCaptor<RestResponse> response = ArgumentCaptor.forClass(RestResponse.class);
-        verify(channel).sendResponse(response.capture());
-        RestResponse restResponse = response.getValue();
-        assertThat(restResponse.status(), is(expectedRestStatus));
-        if (traceExists) {
-            assertThat(restResponse.content().utf8ToString(), containsString(ElasticsearchException.STACK_TRACE));
-        } else {
-            assertThat(restResponse.content().utf8ToString(), not(containsString(ElasticsearchException.STACK_TRACE)));
-        }
-        verifyNoMoreInteractions(restHandler);
-    }
-
     public void testProcessOptionsMethod() throws Exception {
         RestRequest request = mock(RestRequest.class);
         when(request.method()).thenReturn(RestRequest.Method.OPTIONS);
@@ -288,14 +175,6 @@ public class SecurityRestFilterTests extends ESTestCase {
                 return Collections.singleton("password");
             }
         };
-        SetOnce<HttpPreRequest> authcServiceRequest = new SetOnce<>();
-        doAnswer((i) -> {
-            @SuppressWarnings("unchecked")
-            ActionListener<Authentication> callback = (ActionListener<Authentication>) i.getArguments()[1];
-            authcServiceRequest.set((HttpPreRequest) i.getArguments()[0]);
-            callback.onResponse(AuthenticationTestHelper.builder().realmRef(new RealmRef("test", "test", "t")).build(false));
-            return Void.TYPE;
-        }).when(authcService).authenticate(any(HttpRequest.class), anyActionListener());
         AuditTrail auditTrail = mock(AuditTrail.class);
         XPackLicenseState licenseState = TestUtils.newTestLicenseState();
         SetOnce<RestRequest> auditTrailRequest = new SetOnce<>();
@@ -306,7 +185,6 @@ public class SecurityRestFilterTests extends ESTestCase {
         filter = new SecurityRestFilter(
             true,
             threadContext,
-            authcService,
             secondaryAuthenticator,
             new AuditTrailService(auditTrail, licenseState),
             restHandler
@@ -327,7 +205,6 @@ public class SecurityRestFilterTests extends ESTestCase {
         assertEquals(SecuritySettingsSourceField.TEST_PASSWORD, original.get("password"));
         assertEquals("bar", original.get("foo"));
 
-        assertEquals(restRequest.getHttpRequest(), authcServiceRequest.get());
         assertNotEquals(restRequest, auditTrailRequest.get());
         assertNotEquals(restRequest.content(), auditTrailRequest.get().content());
 

+ 0 - 154
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterWarningHeadersTests.java

@@ -1,154 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.security.rest;
-
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.collect.MapBuilder;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.rest.RestChannel;
-import org.elasticsearch.rest.RestHandler;
-import org.elasticsearch.rest.RestRequest;
-import org.elasticsearch.rest.RestResponse;
-import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.rest.FakeRestRequest;
-import org.elasticsearch.xcontent.NamedXContentRegistry;
-import org.elasticsearch.xcontent.json.JsonXContent;
-import org.elasticsearch.xpack.core.security.authc.Authentication;
-import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
-import org.elasticsearch.xpack.security.audit.AuditTrailService;
-import org.elasticsearch.xpack.security.authc.AuthenticationService;
-import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator;
-import org.junit.Before;
-import org.mockito.ArgumentCaptor;
-
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-
-import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-public class SecurityRestFilterWarningHeadersTests extends ESTestCase {
-    private ThreadContext threadContext;
-    private AuthenticationService authcService;
-    private SecondaryAuthenticator secondaryAuthenticator;
-    private RestHandler restHandler;
-
-    @Override
-    protected boolean enableWarningsCheck() {
-        return false;
-    }
-
-    @Before
-    public void init() throws Exception {
-        authcService = mock(AuthenticationService.class);
-        restHandler = mock(RestHandler.class);
-        threadContext = new ThreadContext(Settings.EMPTY);
-        secondaryAuthenticator = new SecondaryAuthenticator(Settings.EMPTY, threadContext, authcService, new AuditTrailService(null, null));
-    }
-
-    public void testResponseHeadersOnFailure() throws Exception {
-        MapBuilder<String, List<String>> headers = new MapBuilder<>();
-        headers.put("Warning", Collections.singletonList("Some warning header"));
-        headers.put("X-elastic-product", Collections.singletonList("Some product header"));
-        Map<String, List<String>> afterHeaders;
-
-        // only remove the response headers for 401 and 403
-        afterHeaders = testProcessAuthenticationFailed(RestStatus.BAD_REQUEST, headers);
-        assertEquals(afterHeaders.size(), 2);
-        afterHeaders = testProcessAuthenticationFailed(RestStatus.INTERNAL_SERVER_ERROR, headers);
-        assertEquals(afterHeaders.size(), 2);
-        afterHeaders = testProcessAuthenticationFailed(RestStatus.UNAUTHORIZED, headers);
-        assertEquals(afterHeaders.size(), 0);
-        afterHeaders = testProcessAuthenticationFailed(RestStatus.FORBIDDEN, headers);
-        assertEquals(afterHeaders.size(), 0);
-
-        // only remove the response headers for 401 and 403
-        afterHeaders = testProcessRestHandlingFailed(RestStatus.BAD_REQUEST, headers);
-        assertEquals(afterHeaders.size(), 2);
-        afterHeaders = testProcessRestHandlingFailed(RestStatus.INTERNAL_SERVER_ERROR, headers);
-        assertEquals(afterHeaders.size(), 2);
-        afterHeaders = testProcessRestHandlingFailed(RestStatus.UNAUTHORIZED, headers);
-        assertEquals(afterHeaders.size(), 0);
-        afterHeaders = testProcessRestHandlingFailed(RestStatus.FORBIDDEN, headers);
-        assertEquals(afterHeaders.size(), 0);
-    }
-
-    private Map<String, List<String>> testProcessRestHandlingFailed(RestStatus restStatus, MapBuilder<String, List<String>> headers)
-        throws Exception {
-        RestChannel channel = mock(RestChannel.class);
-        SecurityRestFilter filter = new SecurityRestFilter(
-            true,
-            threadContext,
-            authcService,
-            secondaryAuthenticator,
-            new AuditTrailService(null, null),
-            restHandler
-        );
-        RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
-        Authentication primaryAuthentication = AuthenticationTestHelper.builder().build();
-        doAnswer(i -> {
-            final Object[] arguments = i.getArguments();
-            @SuppressWarnings("unchecked")
-            ActionListener<Authentication> callback = (ActionListener<Authentication>) arguments[arguments.length - 1];
-            callback.onResponse(primaryAuthentication);
-            return null;
-        }).when(authcService).authenticate(eq(request.getHttpRequest()), anyActionListener());
-        Authentication secondaryAuthentication = AuthenticationTestHelper.builder().build();
-        doAnswer(i -> {
-            final Object[] arguments = i.getArguments();
-            @SuppressWarnings("unchecked")
-            ActionListener<Authentication> callback = (ActionListener<Authentication>) arguments[arguments.length - 1];
-            callback.onResponse(secondaryAuthentication);
-            return null;
-        }).when(authcService).authenticate(eq(request.getHttpRequest()), eq(false), anyActionListener());
-        doThrow(new ElasticsearchStatusException("Rest handling failed", restStatus, "")).when(restHandler)
-            .handleRequest(request, channel, null);
-        when(channel.request()).thenReturn(request);
-        when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder());
-        filter.handleRequest(request, channel, null);
-        ArgumentCaptor<RestResponse> response = ArgumentCaptor.forClass(RestResponse.class);
-        verify(channel).sendResponse(response.capture());
-        RestResponse restResponse = response.getValue();
-        return restResponse.filterHeaders(headers.immutableMap());
-    }
-
-    private Map<String, List<String>> testProcessAuthenticationFailed(RestStatus restStatus, MapBuilder<String, List<String>> headers)
-        throws Exception {
-        RestChannel channel = mock(RestChannel.class);
-        SecurityRestFilter filter = new SecurityRestFilter(
-            true,
-            threadContext,
-            authcService,
-            secondaryAuthenticator,
-            mock(AuditTrailService.class),
-            restHandler
-        );
-        RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
-        doAnswer((i) -> {
-            ActionListener<?> callback = (ActionListener<?>) i.getArguments()[1];
-            callback.onFailure(new ElasticsearchStatusException("Authentication failed", restStatus, ""));
-            return Void.TYPE;
-        }).when(authcService).authenticate(eq(request.getHttpRequest()), anyActionListener());
-        when(channel.request()).thenReturn(request);
-        when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder());
-        filter.handleRequest(request, channel, null);
-        ArgumentCaptor<RestResponse> response = ArgumentCaptor.forClass(RestResponse.class);
-        verify(channel).sendResponse(response.capture());
-        RestResponse restResponse = response.getValue();
-        return restResponse.filterHeaders(headers.immutableMap());
-    }
-}

+ 200 - 8
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransportTests.java

@@ -6,38 +6,64 @@
  */
 package org.elasticsearch.xpack.security.transport.netty4;
 
+import io.netty.buffer.ByteBufUtil;
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http.DefaultFullHttpRequest;
+import io.netty.handler.codec.http.DefaultHttpRequest;
+import io.netty.handler.codec.http.DefaultLastHttpContent;
+import io.netty.handler.codec.http.HttpMessage;
+import io.netty.handler.codec.http.HttpMethod;
+import io.netty.handler.codec.http.HttpResponseStatus;
 import io.netty.handler.ssl.SslHandler;
 
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.settings.MockSecureSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.ssl.SslClientAuthenticationMode;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.http.AbstractHttpServerTransportTestCase;
+import org.elasticsearch.http.HttpRequest;
+import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.NullDispatcher;
-import org.elasticsearch.http.netty4.Netty4HttpHeaderValidator;
+import org.elasticsearch.http.netty4.Netty4HttpResponse;
 import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
+import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
+import org.elasticsearch.http.netty4.internal.HttpHeadersWithAuthenticationContext;
+import org.elasticsearch.rest.RestChannel;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.test.rest.FakeRestRequest;
+import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
 import org.elasticsearch.transport.netty4.SharedGroupFactory;
 import org.elasticsearch.transport.netty4.TLSConfig;
 import org.elasticsearch.xpack.core.XPackSettings;
 import org.elasticsearch.xpack.core.ssl.SSLService;
+import org.elasticsearch.xpack.security.Security;
 import org.junit.Before;
 
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Path;
 import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
 
 import javax.net.ssl.SSLEngine;
 
+import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
+import static org.elasticsearch.transport.Transports.TEST_MOCK_TRANSPORT_THREAD_PREFIX;
 import static org.elasticsearch.xpack.security.transport.netty4.SimpleSecurityNetty4ServerTransportTests.randomCapitalization;
 import static org.hamcrest.Matchers.arrayContaining;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Mockito.mock;
 
 public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTransportTestCase {
@@ -79,7 +105,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
             Tracer.NOOP,
             new TLSConfig(sslService.getHttpTransportSSLConfiguration(), sslService::createSSLEngine),
             null,
-            randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+            randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
         );
         ChannelHandler handler = transport.configureServerChannelHandler();
         final EmbeddedChannel ch = new EmbeddedChannel(handler);
@@ -106,7 +132,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
             Tracer.NOOP,
             new TLSConfig(sslService.getHttpTransportSSLConfiguration(), sslService::createSSLEngine),
             null,
-            randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+            randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
         );
         ChannelHandler handler = transport.configureServerChannelHandler();
         final EmbeddedChannel ch = new EmbeddedChannel(handler);
@@ -133,7 +159,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
             Tracer.NOOP,
             new TLSConfig(sslService.getHttpTransportSSLConfiguration(), sslService::createSSLEngine),
             null,
-            randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+            randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
         );
         ChannelHandler handler = transport.configureServerChannelHandler();
         final EmbeddedChannel ch = new EmbeddedChannel(handler);
@@ -160,7 +186,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
             Tracer.NOOP,
             new TLSConfig(sslService.getHttpTransportSSLConfiguration(), sslService::createSSLEngine),
             null,
-            randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+            randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
         );
         ChannelHandler handler = transport.configureServerChannelHandler();
         final EmbeddedChannel ch = new EmbeddedChannel(handler);
@@ -182,7 +208,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
             Tracer.NOOP,
             new TLSConfig(sslService.getHttpTransportSSLConfiguration(), sslService::createSSLEngine),
             null,
-            randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+            randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
         );
         ChannelHandler handler = transport.configureServerChannelHandler();
         EmbeddedChannel ch = new EmbeddedChannel(handler);
@@ -205,7 +231,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
             Tracer.NOOP,
             new TLSConfig(sslService.getHttpTransportSSLConfiguration(), sslService::createSSLEngine),
             null,
-            randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+            randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
         );
         handler = transport.configureServerChannelHandler();
         ch = new EmbeddedChannel(handler);
@@ -237,8 +263,174 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
             Tracer.NOOP,
             new TLSConfig(sslService.getHttpTransportSSLConfiguration(), sslService::createSSLEngine),
             null,
-            randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
+            randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
         );
         assertNotNull(transport.configureServerChannelHandler());
     }
+
+    public void testAuthnContextWrapping() throws Exception {
+        final Settings settings = Settings.builder().put(env.settings()).build();
+        final AtomicReference<HttpRequest> dispatchedHttpRequestReference = new AtomicReference<>();
+        final String header = "TEST-" + randomAlphaOfLength(8);
+        final String headerValue = "TEST-" + randomAlphaOfLength(8);
+        final String transientHeader = "TEST-" + randomAlphaOfLength(8);
+        final String transientHeaderValue = "TEST-" + randomAlphaOfLength(8);
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                // STEP 2: store the dispatched request, which should be wrapping the context
+                dispatchedHttpRequestReference.set(request.getHttpRequest());
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
+                logger.error(() -> "--> Unexpected bad request [" + FakeRestRequest.requestToString(channel.request()) + "]", cause);
+                throw new AssertionError("Unexpected bad request");
+            }
+
+        };
+        final ThreadPool testThreadPool = new TestThreadPool(TEST_MOCK_TRANSPORT_THREAD_PREFIX);
+        try (
+            Netty4HttpServerTransport transport = Security.getHttpServerTransportWithHeadersValidator(
+                settings,
+                new NetworkService(List.of()),
+                testThreadPool,
+                xContentRegistry(),
+                dispatcher,
+                randomClusterSettings(),
+                new SharedGroupFactory(settings),
+                Tracer.NOOP,
+                TLSConfig.noTLS(),
+                null,
+                (httpPreRequest, channel, listener) -> {
+                    // STEP 1: amend the thread context during authentication
+                    testThreadPool.getThreadContext().putHeader(header, headerValue);
+                    testThreadPool.getThreadContext().putTransient(transientHeader, transientHeaderValue);
+                    listener.onResponse(null);
+                }
+            )
+        ) {
+            final ChannelHandler handler = transport.configureServerChannelHandler();
+            final EmbeddedChannel ch = new EmbeddedChannel(handler);
+            // remove these pipeline handlers as they interfere in the test scenario
+            for (String pipelineHandlerName : ch.pipeline().names()) {
+                if (pipelineHandlerName.equals("decoder")
+                    || pipelineHandlerName.equals("encoder")
+                    || pipelineHandlerName.equals("encoder_compress")
+                    || pipelineHandlerName.equals("chunked_writer")) {
+                    ch.pipeline().remove(pipelineHandlerName);
+                }
+            }
+            var writeFuture = testThreadPool.generic().submit(() -> {
+                ch.writeInbound(
+                    HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(
+                        new DefaultHttpRequest(HTTP_1_1, HttpMethod.GET, "/wrapped_request")
+                    )
+                );
+                ch.writeInbound(new DefaultLastHttpContent());
+                ch.flushInbound();
+            });
+            writeFuture.get();
+            // STEP 3: assert the wrapped context
+            var storedAuthnContext = HttpHeadersAuthenticatorUtils.extractAuthenticationContext(dispatchedHttpRequestReference.get());
+            assertThat(storedAuthnContext, notNullValue());
+            try (var ignored = testThreadPool.getThreadContext().stashContext()) {
+                assertThat(testThreadPool.getThreadContext().getHeader(header), nullValue());
+                assertThat(testThreadPool.getThreadContext().getTransientHeaders().get(transientHeader), nullValue());
+                storedAuthnContext.restore();
+                assertThat(testThreadPool.getThreadContext().getHeader(header), is(headerValue));
+                assertThat(testThreadPool.getThreadContext().getTransientHeaders().get(transientHeader), is(transientHeaderValue));
+            }
+        } finally {
+            testThreadPool.shutdownNow();
+        }
+    }
+
+    public void testHttpHeaderAuthnFaultyHeaderValidator() throws Exception {
+        final Settings settings = Settings.builder().put(env.settings()).build();
+        final ThreadPool testThreadPool = new TestThreadPool(TEST_MOCK_TRANSPORT_THREAD_PREFIX);
+        try (
+            Netty4HttpServerTransport transport = Security.getHttpServerTransportWithHeadersValidator(
+                settings,
+                new NetworkService(List.of()),
+                testThreadPool,
+                xContentRegistry(),
+                new NullDispatcher(),
+                randomClusterSettings(),
+                new SharedGroupFactory(settings),
+                Tracer.NOOP,
+                TLSConfig.noTLS(),
+                null,
+                (httpPreRequest, channel, listener) -> listener.onResponse(null)
+            )
+        ) {
+            final ChannelHandler handler = transport.configureServerChannelHandler();
+            final EmbeddedChannel ch = new EmbeddedChannel(handler);
+            // remove these pipeline handlers as they interfere in the test scenario
+            for (String pipelineHandlerName : ch.pipeline().names()) {
+                if (pipelineHandlerName.equals("decoder")
+                    || pipelineHandlerName.equals("header_validator") // ALSO REMOVE VALIDATOR so requests are not "validatable"
+                    || pipelineHandlerName.equals("encoder")
+                    || pipelineHandlerName.equals("encoder_compress")
+                    || pipelineHandlerName.equals("chunked_writer")) {
+                    ch.pipeline().remove(pipelineHandlerName);
+                }
+            }
+            // this tests a request that cannot be authenticated, but somehow passed authentication
+            // this is the case of an erroneous internal state
+            var writeFuture = testThreadPool.generic().submit(() -> {
+                ch.writeInbound(new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/unauthenticable_request"));
+                ch.flushInbound();
+            });
+            writeFuture.get();
+            ch.flushOutbound();
+            Netty4HttpResponse response = ch.readOutbound();
+            assertThat(response.status(), is(HttpResponseStatus.INTERNAL_SERVER_ERROR));
+            String responseContentString = new String(ByteBufUtil.getBytes(response.content()), StandardCharsets.UTF_8);
+            assertThat(
+                responseContentString,
+                containsString("\"type\":\"security_exception\",\"reason\":\"Request is not authenticated\"")
+            );
+            // this tests a request that CAN be authenticated, but that, somehow, has not been
+            writeFuture = testThreadPool.generic().submit(() -> {
+                ch.writeInbound(
+                    HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(
+                        new DefaultHttpRequest(HTTP_1_1, HttpMethod.GET, "/_request")
+                    )
+                );
+                ch.writeInbound(new DefaultLastHttpContent());
+                ch.flushInbound();
+            });
+            writeFuture.get();
+            ch.flushOutbound();
+            response = ch.readOutbound();
+            assertThat(response.status(), is(HttpResponseStatus.INTERNAL_SERVER_ERROR));
+            responseContentString = new String(ByteBufUtil.getBytes(response.content()), StandardCharsets.UTF_8);
+            assertThat(
+                responseContentString,
+                containsString("\"type\":\"security_exception\",\"reason\":\"Request is not authenticated\"")
+            );
+            // this tests the case where authentication passed and the request is to be dispatched, BUT that the authentication context
+            // cannot be instated before dispatching the request
+            writeFuture = testThreadPool.generic().submit(() -> {
+                HttpMessage authenticableMessage = HttpHeadersAuthenticatorUtils.wrapAsMessageWithAuthenticationContext(
+                    new DefaultHttpRequest(HTTP_1_1, HttpMethod.GET, "/unauthenticated_request")
+                );
+                ((HttpHeadersWithAuthenticationContext) authenticableMessage.headers()).setAuthenticationContext(() -> {
+                    throw new ElasticsearchException("Boom");
+                });
+                ch.writeInbound(authenticableMessage);
+                ch.writeInbound(new DefaultLastHttpContent());
+                ch.flushInbound();
+            });
+            writeFuture.get();
+            ch.flushOutbound();
+            response = ch.readOutbound();
+            assertThat(response.status(), is(HttpResponseStatus.INTERNAL_SERVER_ERROR));
+            responseContentString = new String(ByteBufUtil.getBytes(response.content()), StandardCharsets.UTF_8);
+            assertThat(responseContentString, containsString("\"type\":\"exception\",\"reason\":\"Boom\""));
+        } finally {
+            testThreadPool.shutdownNow();
+        }
+    }
 }