Browse Source

Fix TransportResponse reference counting in DirectResponseChannel (#91289)

In #76474 we fixed a circuit breaker leak in TransportActionProxy 
by incrementing a reference on the TransportResponse that is later 
decremented by the OutboundHandler.

This works well for all cases except when the request targets the 
node which is also the proxy node. In that case the reference is 
incremented but will never be decremented as the local execution 
(using TransportService#localNodeConnection and 
DirectResponseChannel) bypasses the OutboundHandler.

This change fixes the ref counting by also decrementing the 
TransportResponse in DirectResponseChannel.

This will also have the consequence to correctly decrement used 
bytes of the request circuit breaker when 
GetCcrRestoreFileChunkResponse are executed on a node that 
is also a proxy node.
Tanguy Leroux 3 years ago
parent
commit
163f218078

+ 5 - 0
docs/changelog/91289.yaml

@@ -0,0 +1,5 @@
+pr: 91289
+summary: Fix `TransportActionProxy` for local execution
+area: Network
+type: bug
+issues: []

+ 35 - 25
server/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -1411,33 +1411,43 @@ public class TransportService extends AbstractLifecycleComponent
 
         @Override
         public void sendResponse(TransportResponse response) throws IOException {
-            service.onResponseSent(requestId, action, response);
-            try (var shutdownBlock = service.pendingDirectHandlers.withRef()) {
-                if (shutdownBlock == null) {
-                    // already shutting down, the handler will be completed by sendRequestInternal or doStop
-                    return;
-                }
-                final TransportResponseHandler<?> handler = service.responseHandlers.onResponseReceived(requestId, service);
-                if (handler == null) {
-                    // handler already completed, likely by a timeout which is logged elsewhere
-                    return;
-                }
-                final String executor = handler.executor();
-                if (ThreadPool.Names.SAME.equals(executor)) {
-                    processResponse(handler, response);
-                } else {
-                    threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
-                        @Override
-                        protected void doRun() {
-                            processResponse(handler, response);
-                        }
+            try {
+                service.onResponseSent(requestId, action, response);
+                try (var shutdownBlock = service.pendingDirectHandlers.withRef()) {
+                    if (shutdownBlock == null) {
+                        // already shutting down, the handler will be completed by sendRequestInternal or doStop
+                        return;
+                    }
+                    final TransportResponseHandler<?> handler = service.responseHandlers.onResponseReceived(requestId, service);
+                    if (handler == null) {
+                        // handler already completed, likely by a timeout which is logged elsewhere
+                        return;
+                    }
+                    final String executor = handler.executor();
+                    if (ThreadPool.Names.SAME.equals(executor)) {
+                        processResponse(handler, response);
+                    } else {
+                        response.incRef();
+                        threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
+                            @Override
+                            protected void doRun() {
+                                processResponse(handler, response);
+                            }
 
-                        @Override
-                        public String toString() {
-                            return "delivery of response to [" + requestId + "][" + action + "]: " + response;
-                        }
-                    });
+                            @Override
+                            public void onAfter() {
+                                response.decRef();
+                            }
+
+                            @Override
+                            public String toString() {
+                                return "delivery of response to [" + requestId + "][" + action + "]: " + response;
+                            }
+                        });
+                    }
                 }
+            } finally {
+                response.decRef();
             }
         }
 

+ 92 - 4
server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java

@@ -14,7 +14,9 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
@@ -27,8 +29,10 @@ import org.junit.Before;
 import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.notNullValue;
 
 public class TransportActionProxyTests extends ESTestCase {
     protected ThreadPool threadPool;
@@ -76,8 +80,9 @@ public class TransportActionProxyTests extends ESTestCase {
     public void testSendMessage() throws InterruptedException {
         serviceA.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> {
             assertEquals(request.sourceNode, "TS_A");
-            SimpleTestResponse response = new SimpleTestResponse("TS_A");
+            final SimpleTestResponse response = new SimpleTestResponse("TS_A");
             channel.sendResponse(response);
+            assertThat(response.hasReferences(), equalTo(false));
         });
         final boolean cancellable = randomBoolean();
         TransportActionProxy.registerProxyAction(serviceA, "internal:test", cancellable, SimpleTestResponse::new);
@@ -86,21 +91,24 @@ public class TransportActionProxyTests extends ESTestCase {
         serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> {
             assertThat(task instanceof CancellableTask, equalTo(cancellable));
             assertEquals(request.sourceNode, "TS_A");
-            SimpleTestResponse response = new SimpleTestResponse("TS_B");
+            final SimpleTestResponse response = new SimpleTestResponse("TS_B");
             channel.sendResponse(response);
+            assertThat(response.hasReferences(), equalTo(false));
         });
         TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new);
         AbstractSimpleTransportTestCase.connectToNode(serviceB, nodeC);
         serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> {
             assertThat(task instanceof CancellableTask, equalTo(cancellable));
             assertEquals(request.sourceNode, "TS_A");
-            SimpleTestResponse response = new SimpleTestResponse("TS_C");
+            final SimpleTestResponse response = new SimpleTestResponse("TS_C");
             channel.sendResponse(response);
+            assertThat(response.hasReferences(), equalTo(false));
         });
 
         TransportActionProxy.registerProxyAction(serviceC, "internal:test", cancellable, SimpleTestResponse::new);
 
-        CountDownLatch latch = new CountDownLatch(1);
+        final CountDownLatch latch = new CountDownLatch(1);
+        // Node A -> Node B -> Node C
         serviceA.sendRequest(
             nodeB,
             TransportActionProxy.getProxyAction("internal:test"),
@@ -133,6 +141,61 @@ public class TransportActionProxyTests extends ESTestCase {
         latch.await();
     }
 
+    public void testSendLocalRequest() throws InterruptedException {
+        final AtomicReference<SimpleTestResponse> response = new AtomicReference<>();
+        final boolean cancellable = randomBoolean();
+        serviceB.registerRequestHandler(
+            "internal:test",
+            randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC),
+            SimpleTestRequest::new,
+            (request, channel, task) -> {
+                assertThat(task instanceof CancellableTask, equalTo(cancellable));
+                assertEquals(request.sourceNode, "TS_A");
+                final SimpleTestResponse responseB = new SimpleTestResponse("TS_B");
+                channel.sendResponse(responseB);
+                response.set(responseB);
+            }
+        );
+        TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new);
+        AbstractSimpleTransportTestCase.connectToNode(serviceA, nodeB);
+
+        final CountDownLatch latch = new CountDownLatch(1);
+        // Node A -> Proxy Node B (Local execution)
+        serviceA.sendRequest(
+            nodeB,
+            TransportActionProxy.getProxyAction("internal:test"),
+            TransportActionProxy.wrapRequest(nodeB, new SimpleTestRequest("TS_A", cancellable)), // Request
+            new TransportResponseHandler<SimpleTestResponse>() {
+                @Override
+                public SimpleTestResponse read(StreamInput in) throws IOException {
+                    return new SimpleTestResponse(in);
+                }
+
+                @Override
+                public void handleResponse(SimpleTestResponse response) {
+                    try {
+                        assertEquals("TS_B", response.targetNode);
+                    } finally {
+                        latch.countDown();
+                    }
+                }
+
+                @Override
+                public void handleException(TransportException exp) {
+                    try {
+                        throw new AssertionError(exp);
+                    } finally {
+                        latch.countDown();
+                    }
+                }
+            }
+        );
+        latch.await();
+
+        assertThat(response.get(), notNullValue());
+        assertThat(response.get().hasReferences(), equalTo(false));
+    }
+
     public void testException() throws InterruptedException {
         boolean cancellable = randomBoolean();
         serviceA.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> {
@@ -230,7 +293,12 @@ public class TransportActionProxyTests extends ESTestCase {
     }
 
     public static class SimpleTestResponse extends TransportResponse {
+
         final String targetNode;
+        final RefCounted refCounted = new AbstractRefCounted() {
+            @Override
+            protected void closeInternal() {}
+        };
 
         SimpleTestResponse(String targetNode) {
             this.targetNode = targetNode;
@@ -245,6 +313,26 @@ public class TransportActionProxyTests extends ESTestCase {
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(targetNode);
         }
+
+        @Override
+        public void incRef() {
+            refCounted.incRef();
+        }
+
+        @Override
+        public boolean tryIncRef() {
+            return refCounted.tryIncRef();
+        }
+
+        @Override
+        public boolean decRef() {
+            return refCounted.decRef();
+        }
+
+        @Override
+        public boolean hasReferences() {
+            return refCounted.hasReferences();
+        }
     }
 
     public void testGetAction() {