Browse Source

Remove lots of redundant ref-counting from transport pipeline (#123390) (#123554)

We can do with a whole lot less ref-counting: by only incrementing ref-counts where
ownership is unclear, and leaving counts untouched on obvious "moves", we avoid a lot
of contention and speed up the logic in general.

Co-authored-by: Dimitris Rempapis <dimitris.rempapis@elastic.co>
Armin Braun 2 months ago
parent
commit
b25542fd14

+ 2 - 3
modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageInboundHandler.java

@@ -14,7 +14,6 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
-import org.elasticsearch.common.bytes.ReleasableBytesReference;
 import org.elasticsearch.common.network.ThreadWatchdog;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.transport.InboundPipeline;
@@ -51,8 +50,8 @@ public class Netty4MessageInboundHandler extends ChannelInboundHandlerAdapter {
         final ByteBuf buffer = (ByteBuf) msg;
         Netty4TcpChannel channel = ctx.channel().attr(Netty4Transport.CHANNEL_KEY).get();
         activityTracker.startActivity();
-        try (ReleasableBytesReference reference = Netty4Utils.toReleasableBytesReference(buffer)) {
-            pipeline.handleBytes(channel, reference);
+        try {
+            pipeline.handleBytes(channel, Netty4Utils.toReleasableBytesReference(buffer));
         } finally {
             activityTracker.stopActivity();
         }

+ 8 - 12
modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyByteBufSizer.java

@@ -12,12 +12,10 @@ package org.elasticsearch.transport.netty4;
 import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
-import io.netty.handler.codec.MessageToMessageDecoder;
-
-import java.util.List;
+import io.netty.channel.ChannelInboundHandlerAdapter;
 
 @ChannelHandler.Sharable
-public class NettyByteBufSizer extends MessageToMessageDecoder<ByteBuf> {
+public class NettyByteBufSizer extends ChannelInboundHandlerAdapter {
 
     public static final NettyByteBufSizer INSTANCE = new NettyByteBufSizer();
 
@@ -26,14 +24,12 @@ public class NettyByteBufSizer extends MessageToMessageDecoder<ByteBuf> {
     }
 
     @Override
-    protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List<Object> out) {
-        int readableBytes = buf.readableBytes();
-        if (buf.capacity() >= 1024) {
-            ByteBuf resized = buf.discardReadBytes().capacity(readableBytes);
-            assert resized.readableBytes() == readableBytes;
-            out.add(resized.retain());
-        } else {
-            out.add(buf.retain());
+    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+        if (msg instanceof ByteBuf buf && buf.capacity() >= 1024) {
+            int readableBytes = buf.readableBytes();
+            buf = buf.discardReadBytes().capacity(readableBytes);
+            assert buf.readableBytes() == readableBytes;
         }
+        ctx.fireChannelRead(msg);
     }
 }

+ 18 - 20
server/src/main/java/org/elasticsearch/transport/InboundDecoder.java

@@ -18,12 +18,12 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.recycler.Recycler;
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 
 import java.io.IOException;
 import java.io.StreamCorruptedException;
-import java.util.function.Consumer;
 
 public class InboundDecoder implements Releasable {
 
@@ -53,7 +53,7 @@ public class InboundDecoder implements Releasable {
         this.channelType = channelType;
     }
 
-    public int decode(ReleasableBytesReference reference, Consumer<Object> fragmentConsumer) throws IOException {
+    public int decode(ReleasableBytesReference reference, CheckedConsumer<Object, IOException> fragmentConsumer) throws IOException {
         ensureOpen();
         try {
             return internalDecode(reference, fragmentConsumer);
@@ -63,7 +63,8 @@ public class InboundDecoder implements Releasable {
         }
     }
 
-    public int internalDecode(ReleasableBytesReference reference, Consumer<Object> fragmentConsumer) throws IOException {
+    public int internalDecode(ReleasableBytesReference reference, CheckedConsumer<Object, IOException> fragmentConsumer)
+        throws IOException {
         if (isOnHeader()) {
             int messageLength = TcpTransport.readMessageLength(reference);
             if (messageLength == -1) {
@@ -104,25 +105,28 @@ public class InboundDecoder implements Releasable {
             }
             int remainingToConsume = totalNetworkSize - bytesConsumed;
             int maxBytesToConsume = Math.min(reference.length(), remainingToConsume);
-            ReleasableBytesReference retainedContent;
-            if (maxBytesToConsume == remainingToConsume) {
-                retainedContent = reference.retainedSlice(0, maxBytesToConsume);
-            } else {
-                retainedContent = reference.retain();
-            }
-
             int bytesConsumedThisDecode = 0;
             if (decompressor != null) {
-                bytesConsumedThisDecode += decompress(retainedContent);
+                bytesConsumedThisDecode += decompressor.decompress(
+                    maxBytesToConsume == remainingToConsume ? reference.slice(0, maxBytesToConsume) : reference
+                );
                 bytesConsumed += bytesConsumedThisDecode;
                 ReleasableBytesReference decompressed;
                 while ((decompressed = decompressor.pollDecompressedPage(isDone())) != null) {
-                    fragmentConsumer.accept(decompressed);
+                    try (var buf = decompressed) {
+                        fragmentConsumer.accept(buf);
+                    }
                 }
             } else {
                 bytesConsumedThisDecode += maxBytesToConsume;
                 bytesConsumed += maxBytesToConsume;
-                fragmentConsumer.accept(retainedContent);
+                if (maxBytesToConsume == remainingToConsume) {
+                    try (ReleasableBytesReference retained = reference.retainedSlice(0, maxBytesToConsume)) {
+                        fragmentConsumer.accept(retained);
+                    }
+                } else {
+                    fragmentConsumer.accept(reference);
+                }
             }
             if (isDone()) {
                 finishMessage(fragmentConsumer);
@@ -138,7 +142,7 @@ public class InboundDecoder implements Releasable {
         cleanDecodeState();
     }
 
-    private void finishMessage(Consumer<Object> fragmentConsumer) {
+    private void finishMessage(CheckedConsumer<Object, IOException> fragmentConsumer) throws IOException {
         cleanDecodeState();
         fragmentConsumer.accept(END_CONTENT);
     }
@@ -154,12 +158,6 @@ public class InboundDecoder implements Releasable {
         }
     }
 
-    private int decompress(ReleasableBytesReference content) throws IOException {
-        try (content) {
-            return decompressor.decompress(content);
-        }
-    }
-
     private boolean isDone() {
         return bytesConsumed == totalNetworkSize;
     }

+ 53 - 80
server/src/main/java/org/elasticsearch/transport/InboundPipeline.java

@@ -11,18 +11,17 @@ package org.elasticsearch.transport;
 
 import org.elasticsearch.common.bytes.CompositeBytesReference;
 import org.elasticsearch.common.bytes.ReleasableBytesReference;
+import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 
 import java.io.IOException;
 import java.util.ArrayDeque;
-import java.util.ArrayList;
 import java.util.function.BiConsumer;
 import java.util.function.LongSupplier;
 
 public class InboundPipeline implements Releasable {
 
-    private static final ThreadLocal<ArrayList<Object>> fragmentList = ThreadLocal.withInitial(ArrayList::new);
     private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true);
 
     private final LongSupplier relativeTimeInMillis;
@@ -56,81 +55,74 @@ public class InboundPipeline implements Releasable {
 
     public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException {
         if (uncaughtException != null) {
+            reference.close();
             throw new IllegalStateException("Pipeline state corrupted by uncaught exception", uncaughtException);
         }
         try {
-            doHandleBytes(channel, reference);
+            channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong());
+            statsTracker.markBytesRead(reference.length());
+            if (isClosed) {
+                reference.close();
+                return;
+            }
+            pending.add(reference);
+            doHandleBytes(channel);
         } catch (Exception e) {
             uncaughtException = e;
             throw e;
         }
     }
 
-    public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException {
-        channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong());
-        statsTracker.markBytesRead(reference.length());
-        pending.add(reference.retain());
-
-        final ArrayList<Object> fragments = fragmentList.get();
-        boolean continueHandling = true;
-
-        while (continueHandling && isClosed == false) {
-            boolean continueDecoding = true;
-            while (continueDecoding && pending.isEmpty() == false) {
-                try (ReleasableBytesReference toDecode = getPendingBytes()) {
-                    final int bytesDecoded = decoder.decode(toDecode, fragments::add);
-                    if (bytesDecoded != 0) {
-                        releasePendingBytes(bytesDecoded);
-                        if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
-                            continueDecoding = false;
-                        }
-                    } else {
-                        continueDecoding = false;
-                    }
+    private void doHandleBytes(TcpChannel channel) throws IOException {
+        do {
+            CheckedConsumer<Object, IOException> decodeConsumer = f -> forwardFragment(channel, f);
+            int bytesDecoded = decoder.decode(pending.peekFirst(), decodeConsumer);
+            if (bytesDecoded == 0 && pending.size() > 1) {
+                final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
+                int index = 0;
+                for (ReleasableBytesReference pendingReference : pending) {
+                    bytesReferences[index] = pendingReference.retain();
+                    ++index;
+                }
+                try (
+                    ReleasableBytesReference toDecode = new ReleasableBytesReference(
+                        CompositeBytesReference.of(bytesReferences),
+                        () -> Releasables.closeExpectNoException(bytesReferences)
+                    )
+                ) {
+                    bytesDecoded = decoder.decode(toDecode, decodeConsumer);
                 }
             }
-
-            if (fragments.isEmpty()) {
-                continueHandling = false;
+            if (bytesDecoded != 0) {
+                releasePendingBytes(bytesDecoded);
             } else {
-                try {
-                    forwardFragments(channel, fragments);
-                } finally {
-                    for (Object fragment : fragments) {
-                        if (fragment instanceof ReleasableBytesReference) {
-                            ((ReleasableBytesReference) fragment).close();
-                        }
-                    }
-                    fragments.clear();
-                }
+                break;
             }
-        }
+        } while (pending.isEmpty() == false);
     }
 
-    private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments) throws IOException {
-        for (Object fragment : fragments) {
-            if (fragment instanceof Header) {
-                headerReceived((Header) fragment);
-            } else if (fragment instanceof Compression.Scheme) {
-                assert aggregator.isAggregating();
-                aggregator.updateCompressionScheme((Compression.Scheme) fragment);
-            } else if (fragment == InboundDecoder.PING) {
-                assert aggregator.isAggregating() == false;
-                messageHandler.accept(channel, PING_MESSAGE);
-            } else if (fragment == InboundDecoder.END_CONTENT) {
-                assert aggregator.isAggregating();
-                InboundMessage aggregated = aggregator.finishAggregation();
-                try {
-                    statsTracker.markMessageReceived();
-                    messageHandler.accept(channel, aggregated);
-                } finally {
-                    aggregated.decRef();
-                }
-            } else {
-                assert aggregator.isAggregating();
-                assert fragment instanceof ReleasableBytesReference;
-                aggregator.aggregate((ReleasableBytesReference) fragment);
+    private void forwardFragment(TcpChannel channel, Object fragment) throws IOException {
+        if (fragment instanceof Header) {
+            headerReceived((Header) fragment);
+        } else if (fragment instanceof Compression.Scheme) {
+            assert aggregator.isAggregating();
+            aggregator.updateCompressionScheme((Compression.Scheme) fragment);
+        } else if (fragment == InboundDecoder.PING) {
+            assert aggregator.isAggregating() == false;
+            messageHandler.accept(channel, PING_MESSAGE);
+        } else if (fragment == InboundDecoder.END_CONTENT) {
+            assert aggregator.isAggregating();
+            InboundMessage aggregated = aggregator.finishAggregation();
+            try {
+                statsTracker.markMessageReceived();
+                messageHandler.accept(channel, aggregated);
+            } finally {
+                aggregated.decRef();
             }
+        } else {
+            assert aggregator.isAggregating();
+            assert fragment instanceof ReleasableBytesReference;
+            aggregator.aggregate((ReleasableBytesReference) fragment);
         }
     }
 
@@ -139,25 +131,6 @@ public class InboundPipeline implements Releasable {
         aggregator.headerReceived(header);
     }
 
-    private static boolean endOfMessage(Object fragment) {
-        return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
-    }
-
-    private ReleasableBytesReference getPendingBytes() {
-        if (pending.size() == 1) {
-            return pending.peekFirst().retain();
-        } else {
-            final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
-            int index = 0;
-            for (ReleasableBytesReference pendingReference : pending) {
-                bytesReferences[index] = pendingReference.retain();
-                ++index;
-            }
-            final Releasable releasable = () -> Releasables.closeExpectNoException(bytesReferences);
-            return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable);
-        }
-    }
-
     private void releasePendingBytes(int bytesConsumed) {
         int bytesToRelease = bytesConsumed;
         while (bytesToRelease != 0) {

+ 6 - 3
server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java

@@ -107,8 +107,6 @@ public class InboundDecoderTests extends ESTestCase {
             assertEquals(messageBytes, content);
             // Ref count is incremented since the bytes are forwarded as a fragment
             assertTrue(releasable2.hasReferences());
-            releasable2.decRef();
-            assertTrue(releasable2.hasReferences());
             assertTrue(releasable2.decRef());
             assertEquals(InboundDecoder.END_CONTENT, endMarker);
         }
@@ -423,7 +421,12 @@ public class InboundDecoderTests extends ESTestCase {
 
             final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed);
             final ReleasableBytesReference releasable2 = wrapAsReleasable(bytes2);
-            int bytesConsumed2 = decoder.decode(releasable2, fragments::add);
+            int bytesConsumed2 = decoder.decode(releasable2, e -> {
+                fragments.add(e);
+                if (e instanceof ReleasableBytesReference reference) {
+                    reference.retain();
+                }
+            });
             assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2);
 
             final Object compressionScheme = fragments.get(0);

+ 10 - 12
server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java

@@ -165,12 +165,11 @@ public class InboundPipelineTests extends ESTestCase {
                     final int remainingBytes = networkBytes.length() - currentOffset;
                     final int bytesToRead = Math.min(randomIntBetween(1, 32 * 1024), remainingBytes);
                     final BytesReference slice = networkBytes.slice(currentOffset, bytesToRead);
-                    try (ReleasableBytesReference reference = new ReleasableBytesReference(slice, () -> {})) {
-                        toRelease.add(reference);
-                        bytesReceived += reference.length();
-                        pipeline.handleBytes(channel, reference);
-                        currentOffset += bytesToRead;
-                    }
+                    ReleasableBytesReference reference = new ReleasableBytesReference(slice, () -> {});
+                    toRelease.add(reference);
+                    bytesReceived += reference.length();
+                    pipeline.handleBytes(channel, reference);
+                    currentOffset += bytesToRead;
                 }
 
                 final int messages = expected.size();
@@ -294,13 +293,12 @@ public class InboundPipelineTests extends ESTestCase {
             final Releasable releasable = () -> bodyReleased.set(true);
             final int from = totalHeaderSize - 1;
             final BytesReference partHeaderPartBody = reference.slice(from, reference.length() - from - 1);
-            try (ReleasableBytesReference slice = new ReleasableBytesReference(partHeaderPartBody, releasable)) {
-                pipeline.handleBytes(new FakeTcpChannel(), slice);
-            }
+            pipeline.handleBytes(new FakeTcpChannel(), new ReleasableBytesReference(partHeaderPartBody, releasable));
             assertFalse(bodyReleased.get());
-            try (ReleasableBytesReference slice = new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable)) {
-                pipeline.handleBytes(new FakeTcpChannel(), slice);
-            }
+            pipeline.handleBytes(
+                new FakeTcpChannel(),
+                new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable)
+            );
             assertTrue(bodyReleased.get());
         }
     }