Browse Source

Remove first `FlowControlHandler` from HTTP pipeline (#128099)

Today we have a `FlowControlHandler` near the top of the Netty HTTP
pipeline in order to hold back a request body while validating the
request headers. This is inefficient since once we've validated the
headers we can handle the body chunks as fast as they arrive, needing no
more flow control. Moreover today we always fork the validation
completion back onto the event loop, forcing any available chunks to be
buffered in the `FlowControlHandler`.

This commit moves the flow-control mechanism into
`Netty4HttpHeaderValidator` itself so that we can bypass it on validated
message bodies. Morever in the (common) case that validation completes
immediately, e.g. because the credentials are available in cache, then
with this commit we skip the flow-control-related buffering entirely.
David Turner 4 months ago
parent
commit
c3a1d58e25

+ 5 - 0
docs/changelog/128099.yaml

@@ -0,0 +1,5 @@
+pr: 128099
+summary: Remove first `FlowControlHandler` from HTTP pipeline
+area: Network
+type: enhancement
+issues: []

+ 109 - 60
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidator.java

@@ -15,19 +15,23 @@ import io.netty.handler.codec.DecoderResult;
 import io.netty.handler.codec.http.HttpContent;
 import io.netty.handler.codec.http.HttpObject;
 import io.netty.handler.codec.http.HttpRequest;
+import io.netty.util.ReferenceCounted;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.transport.Transports;
 
+import java.util.ArrayDeque;
+
 public class Netty4HttpHeaderValidator extends ChannelDuplexHandler {
 
     private final HttpValidator validator;
     private final ThreadContext threadContext;
-    private State state;
+    private State state = State.PASSING;
+    private final ArrayDeque<Object> buffer = new ArrayDeque<>();
 
     public Netty4HttpHeaderValidator(HttpValidator validator, ThreadContext threadContext) {
         this.validator = validator;
@@ -36,80 +40,125 @@ public class Netty4HttpHeaderValidator extends ChannelDuplexHandler {
 
     @Override
     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+        if (state == State.VALIDATING || buffer.size() > 0) {
+            // there's already some buffered messages that need to be processed before this one, so queue this one up behind them
+            buffer.offerLast(msg);
+            return;
+        }
+
         assert msg instanceof HttpObject;
-        var httpObject = (HttpObject) msg;
+        final var httpObject = (HttpObject) msg;
         if (httpObject.decoderResult().isFailure()) {
             ctx.fireChannelRead(httpObject); // pass-through for decoding failures
+        } else if (msg instanceof HttpRequest httpRequest) {
+            validate(ctx, httpRequest);
+        } else if (state == State.PASSING) {
+            assert msg instanceof HttpContent;
+            ctx.fireChannelRead(msg);
         } else {
-            if (msg instanceof HttpRequest request) {
-                validate(ctx, request);
-            } else {
-                assert msg instanceof HttpContent;
-                var content = (HttpContent) msg;
-                if (state == State.DROPPING) {
-                    content.release();
-                    ctx.read();
-                } else {
-                    assert state == State.PASSING : "unexpected content before validation completed";
-                    ctx.fireChannelRead(content);
-                }
-            }
+            assert state == State.DROPPING : state;
+            assert msg instanceof HttpContent;
+            final var httpContent = (HttpContent) msg;
+            httpContent.release();
+            ctx.read();
         }
     }
 
     @Override
-    public void read(ChannelHandlerContext ctx) throws Exception {
-        // until validation is completed we can ignore read calls,
-        // once validation is finished HttpRequest will be fired and downstream can read from there
-        if (state != State.VALIDATING) {
-            ctx.read();
-        }
+    public void channelReadComplete(ChannelHandlerContext ctx) {
+        if (buffer.size() == 0) {
+            ctx.fireChannelReadComplete();
+        } // else we're buffering messages so will manage the read-complete messages ourselves
     }
 
-    void validate(ChannelHandlerContext ctx, HttpRequest request) {
-        assert Transports.assertDefaultThreadContext(threadContext);
-        state = State.VALIDATING;
-        ActionListener.run(
-            // this prevents thread-context changes to propagate to the validation listener
-            // atm, the validation listener submits to the event loop executor, which doesn't know about the ES thread-context,
-            // so this is just a defensive play, in case the code inside the listener changes to not use the event loop executor
-            ActionListener.assertOnce(
-                new ContextPreservingActionListener<Void>(
-                    threadContext.wrapRestorable(threadContext.newStoredContext()),
-                    new ActionListener<>() {
-                        @Override
-                        public void onResponse(Void unused) {
-                            handleValidationResult(ctx, request, null);
-                        }
-
-                        @Override
-                        public void onFailure(Exception e) {
-                            handleValidationResult(ctx, request, e);
-                        }
+    @Override
+    public void read(ChannelHandlerContext ctx) throws Exception {
+        assert ctx.channel().eventLoop().inEventLoop();
+        if (state != State.VALIDATING) {
+            if (buffer.size() > 0) {
+                final var message = buffer.pollFirst();
+                if (message instanceof HttpRequest httpRequest) {
+                    if (httpRequest.decoderResult().isFailure()) {
+                        ctx.fireChannelRead(message); // pass-through for decoding failures
+                        ctx.fireChannelReadComplete(); // downstream will have to call read() again when it's ready
+                    } else {
+                        validate(ctx, httpRequest);
                     }
-                )
-            ),
-            listener -> {
-                // this prevents thread-context changes to propagate beyond the validation, as netty worker threads are reused
-                try (ThreadContext.StoredContext ignore = threadContext.newStoredContext()) {
-                    validator.validate(request, ctx.channel(), listener);
+                } else {
+                    assert message instanceof HttpContent;
+                    assert state == State.PASSING : state; // DROPPING releases any buffered chunks up-front
+                    ctx.fireChannelRead(message);
+                    ctx.fireChannelReadComplete(); // downstream will have to call read() again when it's ready
                 }
+            } else {
+                ctx.read();
             }
-        );
+        }
     }
 
-    void handleValidationResult(ChannelHandlerContext ctx, HttpRequest request, @Nullable Exception validationError) {
-        assert Transports.assertDefaultThreadContext(threadContext);
-        // Always explicitly dispatch back to the event loop to prevent reentrancy concerns if we are still on event loop
-        ctx.channel().eventLoop().execute(() -> {
-            if (validationError != null) {
-                request.setDecoderResult(DecoderResult.failure(validationError));
-                state = State.DROPPING;
-            } else {
-                state = State.PASSING;
+    void validate(ChannelHandlerContext ctx, HttpRequest httpRequest) {
+        final var validationResultListener = new ValidationResultListener(ctx, httpRequest);
+        SubscribableListener.newForked(validationResultListener::doValidate)
+            .addListener(
+                validationResultListener,
+                // dispatch back to event loop unless validation completed already in which case we can just continue on this thread
+                // straight away, avoiding the need to buffer any subsequent messages
+                ctx.channel().eventLoop(),
+                null
+            );
+    }
+
+    private class ValidationResultListener implements ActionListener<Void> {
+
+        private final ChannelHandlerContext ctx;
+        private final HttpRequest httpRequest;
+
+        ValidationResultListener(ChannelHandlerContext ctx, HttpRequest httpRequest) {
+            this.ctx = ctx;
+            this.httpRequest = httpRequest;
+        }
+
+        void doValidate(ActionListener<Void> listener) {
+            assert Transports.assertDefaultThreadContext(threadContext);
+            assert ctx.channel().eventLoop().inEventLoop();
+            assert state == State.PASSING || state == State.DROPPING : state;
+            state = State.VALIDATING;
+            try (var ignore = threadContext.newEmptyContext()) {
+                validator.validate(
+                    httpRequest,
+                    ctx.channel(),
+                    new ContextPreservingActionListener<>(threadContext::newEmptyContext, listener)
+                );
             }
-            ctx.fireChannelRead(request);
-        });
+        }
+
+        @Override
+        public void onResponse(Void unused) {
+            assert Transports.assertDefaultThreadContext(threadContext);
+            assert ctx.channel().eventLoop().inEventLoop();
+            assert state == State.VALIDATING : state;
+            state = State.PASSING;
+            fireChannelRead();
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            assert Transports.assertDefaultThreadContext(threadContext);
+            assert ctx.channel().eventLoop().inEventLoop();
+            assert state == State.VALIDATING : state;
+            httpRequest.setDecoderResult(DecoderResult.failure(e));
+            state = State.DROPPING;
+            while (buffer.isEmpty() == false && buffer.peekFirst() instanceof HttpRequest == false) {
+                assert buffer.peekFirst() instanceof HttpContent;
+                ((ReferenceCounted) buffer.pollFirst()).release();
+            }
+            fireChannelRead();
+        }
+
+        private void fireChannelRead() {
+            ctx.fireChannelRead(httpRequest);
+            ctx.fireChannelReadComplete(); // downstream needs to read() again
+        }
     }
 
     private enum State {

+ 0 - 1
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java

@@ -371,7 +371,6 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
             ch.pipeline().addLast("decoder", decoder); // parses the HTTP bytes request into HTTP message pieces
 
             // from this point in pipeline every handler must call ctx or channel #read() when ready to process next HTTP part
-            ch.pipeline().addLast(new FlowControlHandler());
             if (Assertions.ENABLED) {
                 // missing reads are hard to catch, but we can detect absence of reads within interval
                 long missingReadIntervalMs = 10_000;

+ 175 - 12
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpHeaderValidatorTests.java

@@ -23,20 +23,27 @@ import io.netty.handler.codec.http.HttpRequest;
 import io.netty.handler.codec.http.HttpRequestDecoder;
 import io.netty.handler.codec.http.HttpVersion;
 import io.netty.handler.codec.http.LastHttpContent;
-import io.netty.handler.flow.FlowControlHandler;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.netty4.internal.HttpValidator;
 import org.elasticsearch.test.ESTestCase;
 
+import java.util.ArrayDeque;
+import java.util.Objects;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 
+import static org.hamcrest.Matchers.instanceOf;
+
 public class Netty4HttpHeaderValidatorTests extends ESTestCase {
     private EmbeddedChannel channel;
     private BlockingQueue<ValidationRequest> validatorRequestQueue;
+    private HttpValidator httpValidator = (httpRequest, channel, listener) -> validatorRequestQueue.add(
+        new ValidationRequest(httpRequest, channel, listener)
+    );
 
     @Override
     public void setUp() throws Exception {
@@ -44,7 +51,7 @@ public class Netty4HttpHeaderValidatorTests extends ESTestCase {
         validatorRequestQueue = new LinkedBlockingQueue<>();
         channel = new EmbeddedChannel(
             new Netty4HttpHeaderValidator(
-                (httpRequest, channel, listener) -> validatorRequestQueue.add(new ValidationRequest(httpRequest, channel, listener)),
+                (httpRequest, channel, listener) -> httpValidator.validate(httpRequest, channel, listener),
                 new ThreadContext(Settings.EMPTY)
             )
         );
@@ -70,12 +77,42 @@ public class Netty4HttpHeaderValidatorTests extends ESTestCase {
     }
 
     public void testDecoderFailurePassThrough() {
-        for (var i = 0; i < 1000; i++) {
-            var httpRequest = newHttpRequest();
-            httpRequest.setDecoderResult(DecoderResult.failure(new Exception("bad")));
-            channel.writeInbound(httpRequest);
-            assertEquals(httpRequest, channel.readInbound());
+        // send a valid request so that the buffer is nonempty
+        final var validRequest = newHttpRequest();
+        channel.writeInbound(validRequest);
+        channel.writeInbound(newLastHttpContent());
+
+        // follow it with an invalid request which should be buffered
+        final var invalidHttpRequest1 = newHttpRequest();
+        invalidHttpRequest1.setDecoderResult(DecoderResult.failure(new Exception("simulated decoder failure 1")));
+        channel.writeInbound(invalidHttpRequest1);
+
+        // handle the first request
+        if (randomBoolean()) {
+            Objects.requireNonNull(validatorRequestQueue.poll()).listener().onResponse(null);
+            channel.runPendingTasks();
+            assertSame(validRequest, channel.readInbound());
+            channel.read();
+            asInstanceOf(LastHttpContent.class, channel.readInbound()).release();
+        } else {
+            Objects.requireNonNull(validatorRequestQueue.poll()).listener().onFailure(new Exception("simulated validation failure"));
+            channel.runPendingTasks();
+            assertSame(validRequest, channel.readInbound());
         }
+
+        // handle the second request, which is read from the buffer and passed on without validation
+        assertNull(channel.readInbound());
+        channel.read();
+        assertSame(invalidHttpRequest1, channel.readInbound());
+
+        // send another invalid request which is passed straight through
+        final var invalidHttpRequest2 = newHttpRequest();
+        invalidHttpRequest2.setDecoderResult(DecoderResult.failure(new Exception("simulated decoder failure 2")));
+        channel.writeInbound(invalidHttpRequest2);
+        if (randomBoolean()) {
+            channel.read(); // optional read
+        }
+        assertSame(invalidHttpRequest2, channel.readInbound());
     }
 
     /**
@@ -121,10 +158,8 @@ public class Netty4HttpHeaderValidatorTests extends ESTestCase {
     }
 
     public void testIgnoreReadWhenValidating() {
-        channel.pipeline().addFirst(new FlowControlHandler()); // catch all inbound messages
-
         channel.writeInbound(newHttpRequest());
-        channel.writeInbound(newLastHttpContent()); // should hold by flow-control-handler
+        channel.writeInbound(newLastHttpContent());
         assertNull("nothing should pass yet", channel.readInbound());
 
         channel.read();
@@ -143,8 +178,7 @@ public class Netty4HttpHeaderValidatorTests extends ESTestCase {
         asInstanceOf(LastHttpContent.class, channel.readInbound()).release();
     }
 
-    public void testWithFlowControlAndAggregator() {
-        channel.pipeline().addFirst(new FlowControlHandler());
+    public void testWithAggregator() {
         channel.pipeline().addLast(new Netty4HttpAggregator(8192, (req) -> true, new HttpRequestDecoder()));
 
         channel.writeInbound(newHttpRequest());
@@ -162,5 +196,134 @@ public class Netty4HttpHeaderValidatorTests extends ESTestCase {
         asInstanceOf(FullHttpRequest.class, channel.readInbound()).release();
     }
 
+    public void testBufferPipelinedRequestsWhenValidating() {
+        final var expectedChunks = new ArrayDeque<HttpContent>();
+        expectedChunks.addLast(newHttpContent());
+
+        // write one full request and one incomplete request received all at once
+        channel.writeInbound(newHttpRequest());
+        channel.writeInbound(newLastHttpContent());
+        channel.writeInbound(newHttpRequest());
+        channel.writeInbound(expectedChunks.peekLast());
+        assertNull("nothing should pass yet", channel.readInbound());
+
+        if (randomBoolean()) {
+            channel.read();
+        }
+        var validationRequest = validatorRequestQueue.poll();
+        assertNotNull(validationRequest);
+
+        channel.read();
+        assertNull("should ignore read while validating", channel.readInbound());
+
+        validationRequest.listener().onResponse(null);
+        channel.runPendingTasks();
+        assertTrue("http request should pass", channel.readInbound() instanceof HttpRequest);
+        assertNull("content should not pass yet, need explicit read", channel.readInbound());
+
+        channel.read();
+        asInstanceOf(LastHttpContent.class, channel.readInbound()).release();
+
+        // should have started to validate the next request
+        channel.read();
+        assertNull("should ignore read while validating", channel.readInbound());
+        Objects.requireNonNull(validatorRequestQueue.poll()).listener().onResponse(null);
+
+        channel.runPendingTasks();
+        assertThat("next http request should pass", channel.readInbound(), instanceOf(HttpRequest.class));
+
+        // another chunk received and is buffered, nothing is sent downstream
+        expectedChunks.addLast(newHttpContent());
+        channel.writeInbound(expectedChunks.peekLast());
+        assertNull(channel.readInbound());
+        assertFalse(channel.hasPendingTasks());
+
+        // the first chunk is now emitted on request
+        channel.read();
+        var nextChunk = asInstanceOf(HttpContent.class, channel.readInbound());
+        assertSame(nextChunk, expectedChunks.pollFirst());
+        nextChunk.release();
+        assertNull(channel.readInbound());
+        assertFalse(channel.hasPendingTasks());
+
+        // and the second chunk
+        channel.read();
+        nextChunk = asInstanceOf(HttpContent.class, channel.readInbound());
+        assertSame(nextChunk, expectedChunks.pollFirst());
+        nextChunk.release();
+        assertNull(channel.readInbound());
+        assertFalse(channel.hasPendingTasks());
+
+        // buffer is now drained, no more chunks available
+        if (randomBoolean()) {
+            channel.read(); // optional read
+        }
+        assertNull(channel.readInbound());
+        assertTrue(expectedChunks.isEmpty());
+        assertFalse(channel.hasPendingTasks());
+
+        // subsequent chunks are passed straight through without another read()
+        expectedChunks.addLast(newHttpContent());
+        channel.writeInbound(expectedChunks.peekLast());
+        nextChunk = asInstanceOf(HttpContent.class, channel.readInbound());
+        assertSame(nextChunk, expectedChunks.pollFirst());
+        nextChunk.release();
+        assertNull(channel.readInbound());
+        assertFalse(channel.hasPendingTasks());
+    }
+
+    public void testDropChunksOnValidationFailure() {
+        // write an incomplete request which will be marked as invalid
+        channel.writeInbound(newHttpRequest());
+        channel.writeInbound(newHttpContent());
+        assertNull("nothing should pass yet", channel.readInbound());
+
+        var validationRequest = validatorRequestQueue.poll();
+        assertNotNull(validationRequest);
+        validationRequest.listener().onFailure(new Exception("simulated validation failure"));
+
+        // failed request is passed downstream
+        channel.runPendingTasks();
+        var inboundRequest = asInstanceOf(HttpRequest.class, channel.readInbound());
+        assertTrue(inboundRequest.decoderResult().isFailure());
+        assertEquals("simulated validation failure", inboundRequest.decoderResult().cause().getMessage());
+
+        // chunk is not emitted (the buffer is now drained)
+        assertNull(channel.readInbound());
+        if (randomBoolean()) {
+            channel.read();
+            assertNull(channel.readInbound());
+        }
+
+        // next chunk is also not emitted (it is released on receipt, not buffered)
+        channel.writeInbound(newLastHttpContent());
+        assertNull(channel.readInbound());
+        if (randomBoolean()) {
+            channel.read();
+            assertNull(channel.readInbound());
+        }
+        assertFalse(channel.hasPendingTasks());
+
+        // next request triggers validation again
+        final var nextRequest = newHttpRequest();
+        channel.writeInbound(nextRequest);
+        Objects.requireNonNull(validatorRequestQueue.poll()).listener().onResponse(null);
+        channel.runPendingTasks();
+
+        if (randomBoolean()) {
+            channel.read(); // optional read
+        }
+        assertSame(nextRequest, channel.readInbound());
+        assertFalse(channel.hasPendingTasks());
+    }
+
+    public void testInlineValidationDoesNotFork() {
+        httpValidator = (httpRequest, channel, listener) -> listener.onResponse(null);
+        final var httpRequest = newHttpRequest();
+        channel.writeInbound(httpRequest);
+        assertFalse(channel.hasPendingTasks());
+        assertSame(httpRequest, channel.readInbound());
+    }
+
     record ValidationRequest(HttpRequest request, Channel channel, ActionListener<Void> listener) {}
 }

+ 8 - 4
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java

@@ -63,6 +63,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
@@ -120,6 +121,7 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.emptyIterable;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.in;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.iterableWithSize;
@@ -976,7 +978,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
         final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
             @Override
             public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
-                assertThat(okURIs.contains(request.uri()), is(true));
+                assertThat(request.uri(), in(okURIs));
                 // assert validated request is dispatched
                 okURIs.remove(request.uri());
                 channel.sendResponse(new RestResponse(OK, RestResponse.TEXT_CONTENT_TYPE, new BytesArray("dispatch OK")));
@@ -985,7 +987,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
             @Override
             public void dispatchBadRequest(final RestChannel channel, final ThreadContext threadContext, final Throwable cause) {
                 // assert unvalidated request is NOT dispatched
-                assertThat(nokURIs.contains(channel.request().uri()), is(true));
+                assertThat(channel.request().uri(), in(nokURIs));
                 nokURIs.remove(channel.request().uri());
                 try {
                     channel.sendResponse(new RestResponse(channel, (Exception) ((ElasticsearchWrapperException) cause).getCause()));
@@ -1000,9 +1002,11 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
             assertThat(channelSetOnce.get(), is(channel));
             // some requests are validated while others are not
             if (httpPreRequest.uri().contains("X-Auth=OK")) {
-                validationListener.onResponse(null);
+                randomFrom(EsExecutors.DIRECT_EXECUTOR_SERVICE, channel.eventLoop()).execute(() -> validationListener.onResponse(null));
             } else if (httpPreRequest.uri().contains("X-Auth=NOK")) {
-                validationListener.onFailure(new ElasticsearchSecurityException("Boom", UNAUTHORIZED));
+                randomFrom(EsExecutors.DIRECT_EXECUTOR_SERVICE, channel.eventLoop()).execute(
+                    () -> validationListener.onFailure(new ElasticsearchSecurityException("Boom", UNAUTHORIZED))
+                );
             } else {
                 throw new AssertionError("Unrecognized URI");
             }