Przeglądaj źródła

Deserialize responses on the handling thread-pool (#91367)

This is the start of moving message deserialization off of the transport
threads where possible. This PR introduces the basic facilities to
ref-count transport message instances and to fork their deserialization,
which already provides some tangible benefits to transport thread
latencies.

We have to fork for large messages (which are mostly responses) in
scenarios where responses can grow beyond O(1M), as not forking introduces
unmanageable latency on the transport pool when e.g. deserializing an
O(100M) cluster state or a similarly sized search response.
Armin Braun 2 lat temu
rodzic
commit
9139dd9e1d

+ 5 - 0
docs/changelog/91367.yaml

@@ -0,0 +1,5 @@
+pr: 91367
+summary: Deserialize responses on the handling thread-pool
+area: Network
+type: enhancement
+issues: []

+ 2 - 2
server/src/main/java/org/elasticsearch/transport/InboundAggregator.java

@@ -119,7 +119,7 @@ public class InboundAggregator implements Releasable {
                 checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl);
             }
             if (isShortCircuited()) {
-                aggregated.close();
+                aggregated.decRef();
                 success = true;
                 return new InboundMessage(aggregated.getHeader(), aggregationException);
             } else {
@@ -130,7 +130,7 @@ public class InboundAggregator implements Releasable {
         } finally {
             resetCurrentAggregation();
             if (success == false) {
-                aggregated.close();
+                aggregated.decRef();
             }
         }
     }

+ 71 - 39
server/src/main/java/org/elasticsearch/transport/InboundHandler.java

@@ -19,6 +19,8 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.network.HandlingTimeTracker;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.threadpool.ThreadPool;
 
@@ -133,34 +135,17 @@ public class InboundHandler {
                 }
                 // ignore if its null, the service logs it
                 if (responseHandler != null) {
-                    final StreamInput streamInput;
                     if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) {
-                        streamInput = namedWriteableStream(message.openOrGetStreamInput());
+                        final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput());
                         assertRemoteVersion(streamInput, header.getVersion());
                         if (header.isError()) {
-                            handlerResponseError(streamInput, responseHandler);
+                            handlerResponseError(streamInput, message, responseHandler);
                         } else {
-                            handleResponse(remoteAddress, streamInput, responseHandler);
-                        }
-                        // Check the entire message has been read
-                        final int nextByte = streamInput.read();
-                        // calling read() is useful to make sure the message is fully read, even if there is an EOS marker
-                        if (nextByte != -1) {
-                            final IllegalStateException exception = new IllegalStateException(
-                                "Message not fully read (response) for requestId ["
-                                    + requestId
-                                    + "], handler ["
-                                    + responseHandler
-                                    + "], error ["
-                                    + header.isError()
-                                    + "]; resetting"
-                            );
-                            assert ignoreDeserializationErrors : exception;
-                            throw exception;
+                            handleResponse(remoteAddress, streamInput, responseHandler, message);
                         }
                     } else {
                         assert header.isError() == false;
-                        handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler);
+                        handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, message);
                     }
                 }
             }
@@ -189,6 +174,26 @@ public class InboundHandler {
         }
     }
 
+    private void verifyResponseReadFully(Header header, TransportResponseHandler<?> responseHandler, StreamInput streamInput)
+        throws IOException {
+        // Check the entire message has been read
+        final int nextByte = streamInput.read();
+        // calling read() is useful to make sure the message is fully read, even if there is an EOS marker
+        if (nextByte != -1) {
+            final IllegalStateException exception = new IllegalStateException(
+                "Message not fully read (response) for requestId ["
+                    + header.getRequestId()
+                    + "], handler ["
+                    + responseHandler
+                    + "], error ["
+                    + header.isError()
+                    + "]; resetting"
+            );
+            assert ignoreDeserializationErrors : exception;
+            throw exception;
+        }
+    }
+
     private <T extends TransportRequest> void handleRequest(TcpChannel channel, Header header, InboundMessage message) throws IOException {
         final String action = header.getActionName();
         final long requestId = header.getRequestId();
@@ -335,10 +340,49 @@ public class InboundHandler {
     private <T extends TransportResponse> void handleResponse(
         InetSocketAddress remoteAddress,
         final StreamInput stream,
-        final TransportResponseHandler<T> handler
+        final TransportResponseHandler<T> handler,
+        final InboundMessage inboundMessage
+    ) {
+        final String executor = handler.executor();
+        if (ThreadPool.Names.SAME.equals(executor)) {
+            // no need to provide a buffer release here, we never escape the buffer when handling directly
+            doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
+        } else {
+            inboundMessage.incRef();
+            // release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
+            final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
+            threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
+                @Override
+                protected void doRun() {
+                    doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), releaseBuffer);
+                }
+
+                @Override
+                public void onAfter() {
+                    Releasables.closeExpectNoException(releaseBuffer);
+                }
+            });
+        }
+    }
+
+    /**
+     *
+     * @param handler response handler
+     * @param remoteAddress remote address that the message was sent from
+     * @param stream bytes stream for reading the message
+     * @param header message header
+     * @param releaseResponseBuffer releasable that will be released once the message has been read from the {@code stream}
+     * @param <T> response message type
+     */
+    private <T extends TransportResponse> void doHandleResponse(
+        TransportResponseHandler<T> handler,
+        InetSocketAddress remoteAddress,
+        final StreamInput stream,
+        final Header header,
+        Releasable releaseResponseBuffer
     ) {
         final T response;
-        try {
+        try (releaseResponseBuffer) {
             response = handler.read(stream);
             response.remoteAddress(remoteAddress);
         } catch (Exception e) {
@@ -348,24 +392,11 @@ public class InboundHandler {
             );
             logger.warn(() -> "Failed to deserialize response from [" + remoteAddress + "]", serializationException);
             assert ignoreDeserializationErrors : e;
-            handleException(handler, serializationException);
+            doHandleException(handler, serializationException);
             return;
         }
-        final String executor = handler.executor();
-        if (ThreadPool.Names.SAME.equals(executor)) {
-            doHandleResponse(handler, response);
-        } else {
-            threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
-                @Override
-                protected void doRun() {
-                    doHandleResponse(handler, response);
-                }
-            });
-        }
-    }
-
-    private static <T extends TransportResponse> void doHandleResponse(TransportResponseHandler<T> handler, T response) {
         try {
+            verifyResponseReadFully(header, handler, stream);
             handler.handleResponse(response);
         } catch (Exception e) {
             doHandleException(handler, new ResponseHandlerFailureTransportException(e));
@@ -374,10 +405,11 @@ public class InboundHandler {
         }
     }
 
-    private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
+    private void handlerResponseError(StreamInput stream, InboundMessage message, final TransportResponseHandler<?> handler) {
         Exception error;
         try {
             error = stream.readException();
+            verifyResponseReadFully(message.getHeader(), handler, stream);
         } catch (Exception e) {
             error = new TransportSerializationException(
                 "Failed to deserialize exception response from stream for handler [" + handler + "]",

+ 9 - 7
server/src/main/java/org/elasticsearch/transport/InboundMessage.java

@@ -11,13 +11,14 @@ package org.elasticsearch.transport;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.bytes.ReleasableBytesReference;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Releasable;
 
 import java.io.IOException;
 import java.util.Objects;
 
-public class InboundMessage implements Releasable {
+public class InboundMessage extends AbstractRefCounted {
 
     private final Header header;
     private final ReleasableBytesReference content;
@@ -82,6 +83,7 @@ public class InboundMessage implements Releasable {
 
     public StreamInput openOrGetStreamInput() throws IOException {
         assert isPing == false && content != null;
+        assert hasReferences();
         if (streamInput == null) {
             streamInput = content.streamInput();
             streamInput.setVersion(header.getVersion());
@@ -90,7 +92,12 @@ public class InboundMessage implements Releasable {
     }
 
     @Override
-    public void close() {
+    public String toString() {
+        return "InboundMessage{" + header + "}";
+    }
+
+    @Override
+    protected void closeInternal() {
         try {
             IOUtils.close(streamInput, content, breakerRelease);
         } catch (Exception e) {
@@ -98,9 +105,4 @@ public class InboundMessage implements Releasable {
             throw new ElasticsearchException(e);
         }
     }
-
-    @Override
-    public String toString() {
-        return "InboundMessage{" + header + "}";
-    }
 }

+ 4 - 1
server/src/main/java/org/elasticsearch/transport/InboundPipeline.java

@@ -144,9 +144,12 @@ public class InboundPipeline implements Releasable {
                 messageHandler.accept(channel, PING_MESSAGE);
             } else if (fragment == InboundDecoder.END_CONTENT) {
                 assert aggregator.isAggregating();
-                try (InboundMessage aggregated = aggregator.finishAggregation()) {
+                InboundMessage aggregated = aggregator.finishAggregation();
+                try {
                     statsTracker.markMessageReceived();
                     messageHandler.accept(channel, aggregated);
+                } finally {
+                    aggregated.decRef();
                 }
             } else {
                 assert aggregator.isAggregating();

+ 1 - 1
server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java

@@ -93,7 +93,7 @@ public class InboundAggregatorTests extends ESTestCase {
         for (ReleasableBytesReference reference : references) {
             assertTrue(reference.hasReferences());
         }
-        aggregated.close();
+        aggregated.decRef();
         for (ReleasableBytesReference reference : references) {
             assertFalse(reference.hasReferences());
         }