Browse Source

Retain ref to requests when running ActionFilterChain (#104000)

`ActionFilter` implementations may be async, so we have to keep the
request alive while the chain is running.

Closes #103952
David Turner 1 year ago
parent
commit
51521e2bea

+ 15 - 4
server/src/main/java/org/elasticsearch/action/support/TransportAction.java

@@ -15,6 +15,8 @@ import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskManager;
 
@@ -58,8 +60,13 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
             listener = new TaskResultStoringActionListener<>(taskManager, task, listener);
         }
 
-        RequestFilterChain<Request, Response> requestFilterChain = new RequestFilterChain<>(this, logger);
-        requestFilterChain.proceed(task, actionName, request, listener);
+        // Note on request refcounting: we can be sure that either we get to the end of the chain (and execute the actual action) or
+        // we complete the response listener and short-circuit the outer chain, so we release our request ref on both paths, using
+        // Releasables#releaseOnce to avoid a double-release.
+        request.mustIncRef();
+        final var releaseRef = Releasables.releaseOnce(request::decRef);
+        RequestFilterChain<Request, Response> requestFilterChain = new RequestFilterChain<>(this, logger, releaseRef);
+        requestFilterChain.proceed(task, actionName, request, ActionListener.runBefore(listener, releaseRef::close));
     }
 
     protected abstract void doExecute(Task task, Request request, ActionListener<Response> listener);
@@ -71,10 +78,12 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
         private final TransportAction<Request, Response> action;
         private final AtomicInteger index = new AtomicInteger();
         private final Logger logger;
+        private final Releasable releaseRef;
 
-        private RequestFilterChain(TransportAction<Request, Response> action, Logger logger) {
+        private RequestFilterChain(TransportAction<Request, Response> action, Logger logger, Releasable releaseRef) {
             this.action = action;
             this.logger = logger;
+            this.releaseRef = releaseRef;
         }
 
         @Override
@@ -84,7 +93,9 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
                 if (i < this.action.filters.length) {
                     this.action.filters[i].apply(task, actionName, request, listener, this);
                 } else if (i == this.action.filters.length) {
-                    this.action.doExecute(task, request, listener);
+                    try (releaseRef) {
+                        this.action.doExecute(task, request, listener);
+                    }
                 } else {
                     listener.onFailure(new IllegalStateException("proceed was called too many times"));
                 }

+ 207 - 0
server/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainRefCountingTests.java

@@ -0,0 +1,207 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionRunnable;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.core.AbstractRefCounted;
+import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.plugins.ActionPlugin;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.test.ESSingleNodeTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.LeakTracker;
+import org.elasticsearch.transport.TransportService;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.CountDownLatch;
+
+public class TransportActionFilterChainRefCountingTests extends ESSingleNodeTestCase {
+
+    @Override
+    protected Collection<Class<? extends Plugin>> getPlugins() {
+        return List.of(TestPlugin.class);
+    }
+
+    static final ActionType<Response> TYPE = ActionType.localOnly("test:action");
+
+    public void testAsyncActionFilterRefCounting() {
+        final var countDownLatch = new CountDownLatch(2);
+        final var request = new Request();
+        try {
+            client().execute(TYPE, request, ActionListener.<Response>running(countDownLatch::countDown).delegateResponse((delegate, e) -> {
+                // _If_ we got an exception then it must be an ElasticsearchException with message "short-circuit failure", i.e. we're
+                // checking that nothing else can go wrong here. But it's also ok for everything to succeed too, in which case we countDown
+                // the latch without running this block.
+                assertEquals("short-circuit failure", asInstanceOf(ElasticsearchException.class, e).getMessage());
+                delegate.onResponse(null);
+            }));
+        } finally {
+            request.decRef();
+        }
+        request.addCloseListener(ActionListener.running(countDownLatch::countDown));
+        safeAwait(countDownLatch);
+    }
+
+    public static class TestPlugin extends Plugin implements ActionPlugin {
+
+        private ThreadPool threadPool;
+
+        @Override
+        public Collection<?> createComponents(PluginServices services) {
+            threadPool = services.threadPool();
+            return List.of();
+        }
+
+        @Override
+        public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
+            return List.of(new ActionHandler<>(TYPE, TestAction.class));
+        }
+
+        @Override
+        public List<ActionFilter> getActionFilters() {
+            return randomSubsetOf(
+                List.of(
+                    new TestAsyncActionFilter(threadPool),
+                    new TestAsyncActionFilter(threadPool),
+                    new TestAsyncMappedActionFilter(threadPool),
+                    new TestAsyncMappedActionFilter(threadPool)
+                )
+            );
+        }
+    }
+
+    private static class TestAsyncActionFilter implements ActionFilter {
+
+        private final ThreadPool threadPool;
+        private final int order = randomInt();
+
+        private TestAsyncActionFilter(ThreadPool threadPool) {
+            this.threadPool = Objects.requireNonNull(threadPool);
+        }
+
+        @Override
+        public int order() {
+            return order;
+        }
+
+        @Override
+        public <Req extends ActionRequest, Rsp extends ActionResponse> void apply(
+            Task task,
+            String action,
+            Req request,
+            ActionListener<Rsp> listener,
+            ActionFilterChain<Req, Rsp> chain
+        ) {
+            if (action.equals(TYPE.name())) {
+                randomFrom(EsExecutors.DIRECT_EXECUTOR_SERVICE, threadPool.generic()).execute(new AbstractRunnable() {
+                    @Override
+                    public void onFailure(Exception e) {
+                        fail(e);
+                    }
+
+                    @Override
+                    protected void doRun() {
+                        assertTrue(request.hasReferences());
+                        if (randomBoolean()) {
+                            chain.proceed(task, action, request, listener);
+                        } else {
+                            listener.onFailure(new ElasticsearchException("short-circuit failure"));
+                        }
+                    }
+                });
+            } else {
+                chain.proceed(task, action, request, listener);
+            }
+        }
+    }
+
+    private static class TestAsyncMappedActionFilter extends TestAsyncActionFilter implements MappedActionFilter {
+
+        private TestAsyncMappedActionFilter(ThreadPool threadPool) {
+            super(threadPool);
+        }
+
+        @Override
+        public String actionName() {
+            return TYPE.name();
+        }
+    }
+
+    public static class TestAction extends TransportAction<Request, Response> {
+
+        private final ThreadPool threadPool;
+
+        @Inject
+        public TestAction(TransportService transportService, ActionFilters actionFilters) {
+            super(TYPE.name(), actionFilters, transportService.getTaskManager());
+            threadPool = transportService.getThreadPool();
+        }
+
+        @Override
+        protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
+            request.mustIncRef();
+            threadPool.generic().execute(ActionRunnable.supply(ActionListener.runBefore(listener, request::decRef), () -> {
+                assert request.hasReferences();
+                return new Response();
+            }));
+        }
+    }
+
+    private static class Request extends ActionRequest {
+        private final SubscribableListener<Void> closeListeners = new SubscribableListener<>();
+        private final RefCounted refs = LeakTracker.wrap(AbstractRefCounted.of(() -> closeListeners.onResponse(null)));
+
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public void incRef() {
+            refs.incRef();
+        }
+
+        @Override
+        public boolean tryIncRef() {
+            return refs.tryIncRef();
+        }
+
+        @Override
+        public boolean decRef() {
+            return refs.decRef();
+        }
+
+        @Override
+        public boolean hasReferences() {
+            return refs.hasReferences();
+        }
+
+        void addCloseListener(ActionListener<Void> listener) {
+            closeListeners.addListener(listener);
+        }
+    }
+
+    private static class Response extends ActionResponse {
+        @Override
+        public void writeTo(StreamOutput out) {}
+    }
+}

+ 9 - 11
server/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java

@@ -135,8 +135,6 @@ public class TransportActionFilterChainTests extends ESTestCase {
     }
 
     public void testTooManyContinueProcessingRequest() throws InterruptedException {
-        final int additionalContinueCount = randomInt(10);
-
         RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() {
             @Override
             public <Request extends ActionRequest, Response extends ActionResponse> void execute(
@@ -146,15 +144,18 @@ public class TransportActionFilterChainTests extends ESTestCase {
                 ActionListener<Response> listener,
                 ActionFilterChain<Request, Response> actionFilterChain
             ) {
-                for (int i = 0; i <= additionalContinueCount; i++) {
-                    actionFilterChain.proceed(task, action, request, listener);
-                }
+                // expected proceed() call:
+                actionFilterChain.proceed(task, action, request, listener);
+
+                // extra, invalid, proceed() call:
+                actionFilterChain.proceed(task, action, request, listener);
             }
         });
 
         Set<ActionFilter> filters = new HashSet<>();
         filters.add(testFilter);
 
+        final CountDownLatch latch = new CountDownLatch(2);
         String actionName = randomAlphaOfLength(randomInt(30));
         ActionFilters actionFilters = new ActionFilters(filters);
         TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(
@@ -164,18 +165,16 @@ public class TransportActionFilterChainTests extends ESTestCase {
         ) {
             @Override
             protected void doExecute(Task task, TestRequest request, ActionListener<TestResponse> listener) {
-                listener.onResponse(new TestResponse());
+                latch.countDown();
             }
         };
 
-        final CountDownLatch latch = new CountDownLatch(additionalContinueCount + 1);
-        final AtomicInteger responses = new AtomicInteger();
         final List<Throwable> failures = new CopyOnWriteArrayList<>();
 
         ActionTestUtils.execute(transportAction, null, new TestRequest(), new LatchedActionListener<>(new ActionListener<>() {
             @Override
             public void onResponse(TestResponse testResponse) {
-                responses.incrementAndGet();
+                fail("should not complete listener");
             }
 
             @Override
@@ -191,8 +190,7 @@ public class TransportActionFilterChainTests extends ESTestCase {
         assertThat(testFilter.runs.get(), equalTo(1));
         assertThat(testFilter.lastActionName, equalTo(actionName));
 
-        assertThat(responses.get(), equalTo(1));
-        assertThat(failures.size(), equalTo(additionalContinueCount));
+        assertThat(failures.size(), equalTo(1));
         for (Throwable failure : failures) {
             assertThat(failure, instanceOf(IllegalStateException.class));
         }