Browse Source

Introduce SubscribableListener#addTimeout (#95762)

We already have a couple of places where we use a `SubscribableListener`
to race between some task and a timeout. This pattern is more generally
useful, so this commit moves it into `SubscribableListener` itself.
David Turner 2 years ago
parent
commit
5091acebdb

+ 12 - 18
server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/get/TransportGetTaskAction.java

@@ -9,7 +9,6 @@
 package org.elasticsearch.action.admin.cluster.node.tasks.get;
 
 import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.ElasticsearchTimeoutException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
@@ -162,23 +161,18 @@ public class TransportGetTaskAction extends HandledTransportAction<GetTaskReques
                     ),
                     () -> taskManager.unregisterRemovedTaskListener(removedTaskListener)
                 );
-                if (future.isDone()) {
-                    // The task has already finished, we can run the completion listener in the same thread
-                    waitedForCompletionListener.onResponse(null);
-                } else {
-                    future.addListener(
-                        new ContextPreservingActionListener<>(
-                            threadPool.getThreadContext().newRestorableContext(false),
-                            waitedForCompletionListener
-                        )
-                    );
-                    var failByTimeout = threadPool.schedule(
-                        () -> future.onFailure(new ElasticsearchTimeoutException("Timed out waiting for completion of task")),
-                        requireNonNullElse(request.getTimeout(), DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT),
-                        ThreadPool.Names.SAME
-                    );
-                    future.addListener(ActionListener.running(failByTimeout::cancel));
-                }
+
+                future.addListener(
+                    new ContextPreservingActionListener<>(
+                        threadPool.getThreadContext().newRestorableContext(false),
+                        waitedForCompletionListener
+                    )
+                );
+                future.addTimeout(
+                    requireNonNullElse(request.getTimeout(), DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT),
+                    threadPool,
+                    ThreadPool.Names.SAME
+                );
             } else {
                 TaskInfo info = runningTask.taskInfo(clusterService.localNode().getId(), true);
                 listener.onResponse(new GetTaskResponse(new TaskResult(false, info)));

+ 14 - 23
server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/list/TransportListTasksAction.java

@@ -8,14 +8,12 @@
 
 package org.elasticsearch.action.admin.cluster.node.tasks.list;
 
-import org.elasticsearch.ElasticsearchTimeoutException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.ListenableActionFuture;
-import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.action.support.tasks.TransportTasksAction;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
@@ -123,27 +121,20 @@ public class TransportListTasksAction extends TransportTasksAction<Task, ListTas
             removalRefs.decRef();
             collectionComplete.set(true);
 
-            if (future.isDone()) {
-                // No tasks to wait, we can run nodeOperation in the management pool
-                allMatchedTasksRemovedListener.onResponse(null);
-            } else {
-                final var threadPool = clusterService.threadPool();
-                future.addListener(
-                    new ThreadedActionListener<>(
-                        threadPool.executor(ThreadPool.Names.MANAGEMENT),
-                        new ContextPreservingActionListener<>(
-                            threadPool.getThreadContext().newRestorableContext(false),
-                            allMatchedTasksRemovedListener
-                        )
-                    )
-                );
-                var cancellable = threadPool.schedule(
-                    () -> future.onFailure(new ElasticsearchTimeoutException("Timed out waiting for completion of tasks")),
-                    requireNonNullElse(request.getTimeout(), DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT),
-                    ThreadPool.Names.SAME
-                );
-                future.addListener(ActionListener.running(cancellable::cancel));
-            }
+            final var threadPool = clusterService.threadPool();
+            future.addListener(
+                new ContextPreservingActionListener<>(
+                    threadPool.getThreadContext().newRestorableContext(false),
+                    allMatchedTasksRemovedListener
+                ),
+                threadPool.executor(ThreadPool.Names.MANAGEMENT),
+                null
+            );
+            future.addTimeout(
+                requireNonNullElse(request.getTimeout(), DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT),
+                threadPool,
+                ThreadPool.Names.SAME
+            );
         } else {
             super.processTasks(request, operation, nodeOperation);
         }

+ 32 - 0
server/src/main/java/org/elasticsearch/action/support/SubscribableListener.java

@@ -10,6 +10,7 @@ package org.elasticsearch.action.support;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchTimeoutException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
@@ -17,6 +18,8 @@ import org.elasticsearch.common.util.concurrent.ListenableFuture;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.util.concurrent.UncategorizedExecutionException;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.threadpool.ThreadPool;
 
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
@@ -290,4 +293,33 @@ public class SubscribableListener<T> implements ActionListener<T> {
             }
         }
     }
+
+    /**
+     * Adds a timeout to this listener, such that if the timeout elapses before the listener is completed then it will be completed with an
+     * {@link ElasticsearchTimeoutException}.
+     * <p>
+     * The process which is racing against this timeout should stop and clean up promptly when the timeout occurs to avoid unnecessary
+     * work. For instance, it could check that the race is not lost by calling {@link #isDone} whenever appropriate, or it could subscribe
+     * another listener which performs any necessary cleanup steps.
+     * <p>
+     * If this listener is already done when this method is called then nothing is scheduled. If scheduling the timeout throws an
+     * exception then that exception is passed to {@link #onFailure} of this listener instead.
+     *
+     * @param timeout         how long to wait before failing this listener with an {@link ElasticsearchTimeoutException}
+     * @param threadPool      the {@link ThreadPool} whose {@code schedule} method is used to schedule the timeout task
+     * @param timeoutExecutor the executor name handed to {@code threadPool.schedule} for the timeout task, e.g.
+     *                        {@code ThreadPool.Names.SAME}
+     */
+    public void addTimeout(TimeValue timeout, ThreadPool threadPool, String timeoutExecutor) {
+        if (isDone()) {
+            // Nothing left to race against, so do not schedule anything.
+            return;
+        }
+        // Schedule the timeout, and subscribe a listener which cancels the scheduled task once this listener is completed.
+        addListener(ActionListener.running(scheduleTimeout(timeout, threadPool, timeoutExecutor)));
+    }
+
+    /**
+     * Schedules a task on the given thread pool which, after {@code timeout}, calls {@link #onFailure} on this listener with an
+     * {@link ElasticsearchTimeoutException} whose message records the timeout both as given and in milliseconds.
+     *
+     * @param timeout         the delay after which the timeout task runs
+     * @param threadPool      the {@link ThreadPool} used to schedule the timeout task
+     * @param timeoutExecutor the executor name handed to {@code threadPool.schedule}
+     * @return a {@link Runnable} which cancels the scheduled task, or a no-op {@link Runnable} if scheduling threw an exception (in
+     *         which case that exception has already been passed to {@link #onFailure})
+     */
+    private Runnable scheduleTimeout(TimeValue timeout, ThreadPool threadPool, String timeoutExecutor) {
+        try {
+            final var cancellable = threadPool.schedule(
+                () -> onFailure(new ElasticsearchTimeoutException(Strings.format("timed out after [%s/%dms]", timeout, timeout.millis()))),
+                timeout,
+                timeoutExecutor
+            );
+            return cancellable::cancel;
+        } catch (Exception e) {
+            // Scheduling itself failed, so there is no task to cancel: report the scheduling failure via this listener instead.
+            onFailure(e);
+            return () -> {};
+        }
+    }
 }

+ 81 - 0
server/src/test/java/org/elasticsearch/action/support/SubscribableListenerTests.java

@@ -9,10 +9,13 @@
 package org.elasticsearch.action.support;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchTimeoutException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 
@@ -26,6 +29,7 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.IntFunction;
 
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.instanceOf;
 
 public class SubscribableListenerTests extends ESTestCase {
 
@@ -292,4 +296,81 @@ public class SubscribableListenerTests extends ESTestCase {
 
         assertTrue(completion.get());
     }
+
+    // Checks that a listener with a 30s timeout is failed with an ElasticsearchTimeoutException when nothing completes it before the
+    // timeout elapses, and that the subscribed listener is notified of this exactly once.
+    public void testTimeoutBeforeCompletion() {
+        final var deterministicTaskQueue = new DeterministicTaskQueue();
+        final var threadPool = deterministicTaskQueue.getThreadPool();
+
+        final var headerName = "test-header-name";
+        final var headerValue = randomAlphaOfLength(10);
+
+        final var timedOut = new AtomicBoolean();
+        final var listener = new SubscribableListener<Void>();
+        listener.addListener(new ActionListener<>() {
+            @Override
+            public void onResponse(Void unused) {
+                fail("should not execute");
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                assertThat(e, instanceOf(ElasticsearchTimeoutException.class));
+                assertEquals("timed out after [30s/30000ms]", e.getMessage());
+                // The header is only set around the addTimeout call below, so this expects the timeout to be delivered with that
+                // thread context in place.
+                assertEquals(headerValue, threadPool.getThreadContext().getHeader(headerName));
+                // compareAndSet fails the test if this listener is notified more than once.
+                assertTrue(timedOut.compareAndSet(false, true));
+            }
+        });
+        // Put the header into a stashed context which is only in place while addTimeout is called.
+        try (var ignored = threadPool.getThreadContext().stashContext()) {
+            threadPool.getThreadContext().putHeader(headerName, headerValue);
+            listener.addTimeout(TimeValue.timeValueSeconds(30), threadPool, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC));
+        }
+
+        // Sometimes also complete the listener successfully, but only strictly after the 30s timeout; the subscribed listener fails the
+        // test if it ever sees this response.
+        if (randomBoolean()) {
+            deterministicTaskQueue.scheduleAt(
+                deterministicTaskQueue.getCurrentTimeMillis() + randomLongBetween(
+                    TimeValue.timeValueSeconds(30).millis() + 1,
+                    TimeValue.timeValueSeconds(60).millis()
+                ),
+                () -> listener.onResponse(null)
+            );
+        }
+
+        // Nothing has run yet, so the listener must still be pending until the task queue is drained.
+        assertFalse(timedOut.get());
+        assertFalse(listener.isDone());
+        deterministicTaskQueue.runAllTasksInTimeOrder();
+        assertTrue(timedOut.get());
+        assertTrue(listener.isDone());
+    }
+
+    // Checks that a listener which is completed successfully before its 30s timeout elapses notifies the subscribed listener of the
+    // response exactly once, and never of a failure.
+    public void testCompletionBeforeTimeout() {
+        final var deterministicTaskQueue = new DeterministicTaskQueue();
+        final var threadPool = deterministicTaskQueue.getThreadPool();
+
+        final var complete = new AtomicBoolean();
+        final var listener = new SubscribableListener<Void>();
+        listener.addListener(new ActionListener<>() {
+            @Override
+            public void onResponse(Void unused) {
+                // compareAndSet fails the test if this listener is notified more than once.
+                assertTrue(complete.compareAndSet(false, true));
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                fail("should not fail");
+            }
+        });
+        listener.addTimeout(TimeValue.timeValueSeconds(30), threadPool, randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC));
+
+        // Complete the listener at a random time strictly before the 30s timeout would fire.
+        deterministicTaskQueue.scheduleAt(
+            deterministicTaskQueue.getCurrentTimeMillis() + randomLongBetween(0, TimeValue.timeValueSeconds(30).millis() - 1),
+            () -> listener.onResponse(null)
+        );
+
+        // Nothing has run yet, so the listener must still be pending until the task queue is drained.
+        assertFalse(complete.get());
+        assertFalse(listener.isDone());
+        deterministicTaskQueue.runAllTasksInTimeOrder();
+        assertTrue(complete.get());
+        assertTrue(listener.isDone());
+    }
+
 }