Browse Source

Use underlying ByteBuf refCount for ReleasableBytesReference (#116211) (#116278)

Mikhail Berezovskiy 11 months ago
parent
commit
87117c6d7f

+ 5 - 0
docs/changelog/116211.yaml

@@ -0,0 +1,5 @@
+pr: 116211
+summary: Use underlying `ByteBuf` `refCount` for `ReleasableBytesReference`
+area: Network
+type: bug
+issues: []

+ 1 - 35
modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageInboundHandler.java

@@ -14,10 +14,8 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.bytes.ReleasableBytesReference;
 import org.elasticsearch.common.network.ThreadWatchdog;
-import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.transport.InboundPipeline;
 import org.elasticsearch.transport.Transports;
@@ -52,9 +50,8 @@ public class Netty4MessageInboundHandler extends ChannelInboundHandlerAdapter {
 
         final ByteBuf buffer = (ByteBuf) msg;
         Netty4TcpChannel channel = ctx.channel().attr(Netty4Transport.CHANNEL_KEY).get();
-        final BytesReference wrapped = Netty4Utils.toBytesReference(buffer);
         activityTracker.startActivity();
-        try (ReleasableBytesReference reference = new ReleasableBytesReference(wrapped, new ByteBufRefCounted(buffer))) {
+        try (ReleasableBytesReference reference = Netty4Utils.toReleasableBytesReference(buffer)) {
             pipeline.handleBytes(channel, reference);
         } finally {
             activityTracker.stopActivity();
@@ -81,35 +78,4 @@ public class Netty4MessageInboundHandler extends ChannelInboundHandlerAdapter {
         super.channelInactive(ctx);
     }
 
-    private record ByteBufRefCounted(ByteBuf buffer) implements RefCounted {
-
-        @Override
-        public void incRef() {
-            buffer.retain();
-        }
-
-        @Override
-        public boolean tryIncRef() {
-            if (hasReferences() == false) {
-                return false;
-            }
-            try {
-                buffer.retain();
-            } catch (RuntimeException e) {
-                assert hasReferences() == false;
-                return false;
-            }
-            return true;
-        }
-
-        @Override
-        public boolean decRef() {
-            return buffer.release();
-        }
-
-        @Override
-        public boolean hasReferences() {
-            return buffer.refCnt() > 0;
-        }
-    }
 }

+ 45 - 1
modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Utils.java

@@ -32,6 +32,7 @@ import org.elasticsearch.common.recycler.Recycler;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.core.Booleans;
+import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.http.HttpBody;
 import org.elasticsearch.transport.TransportException;
@@ -130,8 +131,51 @@ public class Netty4Utils {
         }
     }
 
+    /**
+     * Wrap Netty's {@link ByteBuf} into {@link ReleasableBytesReference} and delegating reference count to ByteBuf.
+     */
     public static ReleasableBytesReference toReleasableBytesReference(final ByteBuf buffer) {
-        return new ReleasableBytesReference(toBytesReference(buffer), buffer::release);
+        return new ReleasableBytesReference(toBytesReference(buffer), toRefCounted(buffer));
+    }
+
+    static ByteBufRefCounted toRefCounted(final ByteBuf buf) {
+        return new ByteBufRefCounted(buf);
+    }
+
+    record ByteBufRefCounted(ByteBuf buffer) implements RefCounted {
+
+        public int refCnt() {
+            return buffer.refCnt();
+        }
+
+        @Override
+        public void incRef() {
+            buffer.retain();
+        }
+
+        @Override
+        public boolean tryIncRef() {
+            if (hasReferences() == false) {
+                return false;
+            }
+            try {
+                buffer.retain();
+            } catch (RuntimeException e) {
+                assert hasReferences() == false;
+                return false;
+            }
+            return true;
+        }
+
+        @Override
+        public boolean decRef() {
+            return buffer.release();
+        }
+
+        @Override
+        public boolean hasReferences() {
+            return buffer.refCnt() > 0;
+        }
     }
 
     public static HttpBody.Full fullHttpBodyFrom(final ByteBuf buf) {

+ 38 - 0
modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4UtilsTests.java

@@ -11,6 +11,7 @@ package org.elasticsearch.transport.netty4;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.buffer.Unpooled;
 
 import org.apache.lucene.util.BytesRef;
@@ -68,6 +69,43 @@ public class Netty4UtilsTests extends ESTestCase {
         assertArrayEquals(BytesReference.toBytes(ref), BytesReference.toBytes(bytesReference));
     }
 
+    /**
+     * Test that wrapped reference counted object from netty reflects correct counts in ES RefCounted
+     */
+    public void testToRefCounted() {
+        var buf = PooledByteBufAllocator.DEFAULT.buffer(1);
+        assertEquals(1, buf.refCnt());
+
+        var refCounted = Netty4Utils.toRefCounted(buf);
+        assertEquals(1, refCounted.refCnt());
+
+        buf.retain();
+        assertEquals(2, refCounted.refCnt());
+
+        refCounted.incRef();
+        assertEquals(3, refCounted.refCnt());
+        assertEquals(buf.refCnt(), refCounted.refCnt());
+
+        refCounted.decRef();
+        assertEquals(2, refCounted.refCnt());
+        assertEquals(buf.refCnt(), refCounted.refCnt());
+        assertTrue(refCounted.hasReferences());
+
+        refCounted.decRef();
+        refCounted.decRef();
+        assertFalse(refCounted.hasReferences());
+    }
+
+    /**
+     * Ensures that released ByteBuf cannot be accessed from ReleasableBytesReference
+     */
+    public void testToReleasableBytesReferenceThrowOnByteBufRelease() {
+        var buf = PooledByteBufAllocator.DEFAULT.buffer(1);
+        var relBytes = Netty4Utils.toReleasableBytesReference(buf);
+        buf.release();
+        assertThrows(AssertionError.class, () -> relBytes.get(0));
+    }
+
     private BytesReference getRandomizedBytesReference(int length) throws IOException {
         // we know bytes stream output always creates a paged bytes reference, we use it to create randomized content
         ReleasableBytesStreamOutput out = new ReleasableBytesStreamOutput(length, bigarrays);