Browse Source

Fix memory leak when double invoking RestChannel.sendResponse (#89873)

When using the resource handling channel we must make sure that
if we (by what is IMO a bug) try to double invoke it after
having already sent a response (or tried to do so) we at least
release the memory in the channel's outbound buffer.
Otherwise we will leak any memory from it that was used to create
the now failing to send `RestResponse`.
Armin Braun 3 years ago
parent
commit
7c67116adc

+ 5 - 0
docs/changelog/89873.yaml

@@ -0,0 +1,5 @@
+pr: 89873
+summary: Fix memory leak when double invoking `RestChannel.sendResponse`
+area: Network
+type: bug
+issues: []

+ 2 - 5
server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java

@@ -186,11 +186,8 @@ public abstract class AbstractRestChannel implements RestChannel {
         return bytesOut;
     }
 
-    /**
-     * Releases the current output buffer for this channel. Must be called after the buffer derived from {@link #bytesOutput} is no longer
-     * needed.
-     */
-    protected final void releaseOutputBuffer() {
+    @Override
+    public final void releaseOutputBuffer() {
         if (bytesOut != null) {
             try {
                 bytesOut.close();

+ 6 - 0
server/src/main/java/org/elasticsearch/rest/RestChannel.java

@@ -39,6 +39,12 @@ public interface RestChannel {
 
     BytesStream bytesOutput();
 
+    /**
+     * Releases the current output buffer for this channel. Must be called after the buffer derived from {@link #bytesOutput} is no longer
+     * needed.
+     */
+    void releaseOutputBuffer();
+
     RestRequest request();
 
     /**

+ 15 - 2
server/src/main/java/org/elasticsearch/rest/RestController.java

@@ -725,6 +725,11 @@ public class RestController implements HttpServerTransport.Dispatcher {
             return delegate.bytesOutput();
         }
 
+        @Override
+        public void releaseOutputBuffer() {
+            delegate.releaseOutputBuffer();
+        }
+
         @Override
         public RestRequest request() {
             return delegate.request();
@@ -737,8 +742,16 @@ public class RestController implements HttpServerTransport.Dispatcher {
 
         @Override
         public void sendResponse(RestResponse response) {
-            close();
-            delegate.sendResponse(response);
+            boolean success = false;
+            try {
+                close();
+                delegate.sendResponse(response);
+                success = true;
+            } finally {
+                if (success == false) {
+                    releaseOutputBuffer();
+                }
+            }
         }
 
         private void close() {

+ 30 - 1
server/src/test/java/org/elasticsearch/rest/RestControllerTests.java

@@ -13,11 +13,14 @@ import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
+import org.elasticsearch.common.io.stream.BytesStream;
+import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.RestApiVersion;
@@ -28,11 +31,13 @@ import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.HttpStats;
 import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService;
 import org.elasticsearch.rest.RestHandler.Route;
+import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.client.NoOpNodeClient;
 import org.elasticsearch.test.rest.FakeRestRequest;
 import org.elasticsearch.tracing.Tracer;
+import org.elasticsearch.transport.BytesRefRecycler;
 import org.elasticsearch.usage.UsageService;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -361,6 +366,30 @@ public class RestControllerTests extends ESTestCase {
         assertEquals(0, inFlightRequestsBreaker.getUsed());
     }
 
+    public void testDispatchRequestAddsAndFreesBytesOnlyOnceOnErrorDuringSend() {
+        int contentLength = Math.toIntExact(BREAKER_LIMIT.getBytes());
+        String content = randomAlphaOfLength((int) Math.round(contentLength / inFlightRequestsBreaker.getOverhead()));
+        // use a real recycler that tracks leaks and create some content bytes in the test handler to check for leaks
+        final BytesRefRecycler recycler = new BytesRefRecycler(new MockPageCacheRecycler(Settings.EMPTY));
+        restController.registerHandler(
+            new Route(GET, "/foo"),
+            (request, c, client) -> new RestToXContentListener<>(c).onResponse((b, p) -> b.startObject().endObject())
+        );
+        // we will produce an error in the rest handler and one more when sending the error response
+        RestRequest request = testRestRequest("/foo", content, XContentType.JSON);
+        ExceptionThrowingChannel channel = new ExceptionThrowingChannel(request, true) {
+            @Override
+            protected BytesStream newBytesOutput() {
+                return new RecyclerBytesStreamOutput(recycler);
+            }
+        };
+
+        restController.dispatchRequest(request, channel, client.threadPool().getThreadContext());
+
+        assertEquals(0, inFlightRequestsBreaker.getTrippedCount());
+        assertEquals(0, inFlightRequestsBreaker.getUsed());
+    }
+
     public void testDispatchRequestLimitsBytes() {
         int contentLength = BREAKER_LIMIT.bytesAsInt() + 1;
         String content = randomAlphaOfLength((int) Math.round(contentLength / inFlightRequestsBreaker.getOverhead()));
@@ -964,7 +993,7 @@ public class RestControllerTests extends ESTestCase {
         }
     }
 
-    private static final class ExceptionThrowingChannel extends AbstractRestChannel {
+    private static class ExceptionThrowingChannel extends AbstractRestChannel {
 
         protected ExceptionThrowingChannel(RestRequest request, boolean detailedErrorsEnabled) {
             super(request, detailedErrorsEnabled);