Explorar el Código

Release master service task on timeout (#97711)

Today when a task in a master service queue times out it remains in the
queue until its batch is eventually processed. This might mean we retain
a reference to the task (and therefore any dependent listeners and other
objects) for an extended period of time after it has failed. This commit
adjusts the timeout handling to drop the reference to the task when it
completes.
David Turner hace 2 años
padre
commit
23df96cdbe

+ 5 - 0
docs/changelog/97711.yaml

@@ -0,0 +1,5 @@
+pr: 97711
+summary: Release master service task on timeout
+area: Cluster Coordination
+type: bug
+issues: []

+ 37 - 35
server/src/main/java/org/elasticsearch/cluster/service/MasterService.java

@@ -64,9 +64,9 @@ import java.util.Objects;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.LongSupplier;
 import java.util.function.Supplier;
@@ -1425,18 +1425,16 @@ public class MasterService extends AbstractLifecycleComponent {
         );
     }
 
-    private static class TaskTimeoutHandler extends AbstractRunnable {
+    private static class TaskTimeoutHandler<T extends ClusterStateTaskListener> extends AbstractRunnable {
 
         private final TimeValue timeout;
         private final String source;
-        private final AtomicBoolean executed;
-        private final ClusterStateTaskListener listener;
+        private final AtomicReference<T> taskHolder; // atomically read and set to null by at most one of {execute, timeout}
 
-        private TaskTimeoutHandler(TimeValue timeout, String source, AtomicBoolean executed, ClusterStateTaskListener listener) {
+        private TaskTimeoutHandler(TimeValue timeout, String source, AtomicReference<T> taskHolder) {
             this.timeout = timeout;
             this.source = source;
-            this.executed = executed;
-            this.listener = listener;
+            this.taskHolder = taskHolder;
         }
 
         @Override
@@ -1463,10 +1461,17 @@ public class MasterService extends AbstractLifecycleComponent {
         }
 
         private void completeTask(Exception e) {
-            if (executed.compareAndSet(false, true)) {
-                listener.onFailure(e);
+            final var task = taskHolder.getAndSet(null);
+            if (task != null) {
+                logger.trace("timing out [{}][{}] after [{}]", source, task, timeout);
+                task.onFailure(e);
             }
         }
+
+        @Override
+        public String toString() {
+            return Strings.format("master service timeout handler for [%s][%s] after [%s]", source, taskHolder.get(), timeout);
+        }
     }
 
     /**
@@ -1514,11 +1519,11 @@ public class MasterService extends AbstractLifecycleComponent {
 
         @Override
         public void submitTask(String source, T task, @Nullable TimeValue timeout) {
-            final var executed = new AtomicBoolean(false);
+            final var taskHolder = new AtomicReference<>(task);
             final Scheduler.Cancellable timeoutCancellable;
             if (timeout != null && timeout.millis() > 0) {
                 timeoutCancellable = threadPool.schedule(
-                    new TaskTimeoutHandler(timeout, source, executed, task),
+                    new TaskTimeoutHandler<>(timeout, source, taskHolder),
                     timeout,
                     ThreadPool.Names.GENERIC
                 );
@@ -1529,10 +1534,9 @@ public class MasterService extends AbstractLifecycleComponent {
             queue.add(
                 new Entry<>(
                     source,
-                    task,
+                    taskHolder,
                     insertionIndexSupplier.getAsLong(),
                     threadPool.relativeTimeInMillis(),
-                    executed,
                     threadPool.getThreadContext().newRestorableContext(true),
                     timeoutCancellable
                 )
@@ -1550,26 +1554,23 @@ public class MasterService extends AbstractLifecycleComponent {
 
         private record Entry<T extends ClusterStateTaskListener>(
             String source,
-            T task,
+            AtomicReference<T> taskHolder,
             long insertionIndex,
             long insertionTimeMillis,
-            AtomicBoolean executed,
             Supplier<ThreadContext.StoredContext> storedContextSupplier,
             @Nullable Scheduler.Cancellable timeoutCancellable
         ) {
-            boolean acquireForExecution() {
-                if (executed.compareAndSet(false, true) == false) {
-                    return false;
-                }
-
-                if (timeoutCancellable != null) {
+            T acquireForExecution() {
+                final var task = taskHolder.getAndSet(null);
+                if (task != null && timeoutCancellable != null) {
                     timeoutCancellable.cancel();
                 }
-                return true;
+                return task;
             }
 
             void onRejection(FailedToCommitClusterStateException e) {
-                if (acquireForExecution()) {
+                final var task = acquireForExecution();
+                if (task != null) {
                     try (var ignored = storedContextSupplier.get()) {
                         task.onFailure(e);
                     } catch (Exception e2) {
@@ -1579,6 +1580,10 @@ public class MasterService extends AbstractLifecycleComponent {
                     }
                 }
             }
+
+            boolean isPending() {
+                return taskHolder().get() != null;
+            }
         }
 
         private class Processor implements Batch {
@@ -1597,12 +1602,17 @@ public class MasterService extends AbstractLifecycleComponent {
                 assert executing.isEmpty() : executing;
                 final var entryCount = queueSize.getAndSet(0);
                 var taskCount = 0;
+                final var tasks = new ArrayList<ExecutionResult<T>>(entryCount);
                 for (int i = 0; i < entryCount; i++) {
                     final var entry = queue.poll();
                     assert entry != null;
-                    if (entry.acquireForExecution()) {
+                    final var task = entry.acquireForExecution();
+                    if (task != null) {
                         taskCount += 1;
                         executing.add(entry);
+                        tasks.add(
+                            new ExecutionResult<>(entry.source(), task, threadPool.getThreadContext(), entry.storedContextSupplier())
+                        );
                     }
                 }
                 if (taskCount == 0) {
@@ -1610,12 +1620,6 @@ public class MasterService extends AbstractLifecycleComponent {
                     return;
                 }
                 final var finalTaskCount = taskCount;
-                final var tasks = new ArrayList<ExecutionResult<T>>(finalTaskCount);
-                for (final var entry : executing) {
-                    tasks.add(
-                        new ExecutionResult<>(entry.source(), entry.task(), threadPool.getThreadContext(), entry.storedContextSupplier())
-                    );
-                }
                 ActionListener.run(ActionListener.runBefore(listener, () -> {
                     assert executing.size() == finalTaskCount;
                     executing.clear();
@@ -1643,9 +1647,7 @@ public class MasterService extends AbstractLifecycleComponent {
             public Stream<PendingClusterTask> getPending(long currentTimeMillis) {
                 return Stream.concat(
                     executing.stream().map(entry -> makePendingTask(entry, currentTimeMillis, true)),
-                    queue.stream()
-                        .filter(entry -> entry.executed().get() == false)
-                        .map(entry -> makePendingTask(entry, currentTimeMillis, false))
+                    queue.stream().filter(Entry::isPending).map(entry -> makePendingTask(entry, currentTimeMillis, false))
                 );
             }
 
@@ -1664,7 +1666,7 @@ public class MasterService extends AbstractLifecycleComponent {
             public int getPendingCount() {
                 int count = executing.size();
                 for (final var entry : queue) {
-                    if (entry.executed().get() == false) {
+                    if (entry.isPending()) {
                         count += 1;
                     }
                 }
@@ -1673,7 +1675,7 @@ public class MasterService extends AbstractLifecycleComponent {
 
             @Override
             public long getCreationTimeMillis() {
-                return Stream.concat(executing.stream(), queue.stream().filter(entry -> entry.executed().get() == false))
+                return Stream.concat(executing.stream(), queue.stream().filter(Entry::isPending))
                     .mapToLong(Entry::insertionTimeMillis)
                     .min()
                     .orElse(Long.MAX_VALUE);

+ 73 - 0
server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java

@@ -53,6 +53,7 @@ import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ClusterServiceUtils;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.MockLogAppender;
+import org.elasticsearch.test.ReachabilityChecker;
 import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.elasticsearch.test.tasks.MockTaskManager;
 import org.elasticsearch.test.tasks.MockTaskManagerListener;
@@ -2379,6 +2380,78 @@ public class MasterServiceTests extends ESTestCase {
         }
     }
 
+    public void testReleaseOnTimeout() {
+
+        final var deterministicTaskQueue = new DeterministicTaskQueue();
+
+        final var threadPool = deterministicTaskQueue.getThreadPool();
+        try (var masterService = createMasterService(true, null, threadPool, new StoppableExecutorServiceWrapper(threadPool.generic()))) {
+
+            final var actionCount = new AtomicInteger();
+
+            class BlockingTask extends ClusterStateUpdateTask {
+                BlockingTask() {
+                    super(Priority.IMMEDIATE);
+                }
+
+                @Override
+                public ClusterState execute(ClusterState currentState) {
+                    var targetTime = deterministicTaskQueue.getCurrentTimeMillis() + between(1, 1000);
+                    deterministicTaskQueue.scheduleAt(targetTime, () -> {});
+
+                    while (deterministicTaskQueue.getCurrentTimeMillis() < targetTime) {
+                        deterministicTaskQueue.advanceTime();
+                    }
+
+                    return currentState;
+                }
+
+                @Override
+                public void clusterStateProcessed(ClusterState initialState, ClusterState newState) {
+                    if (actionCount.get() < 1) {
+                        masterService.submitUnbatchedStateUpdateTask("blocker", BlockingTask.this);
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    throw new AssertionError("unexpected", e);
+                }
+            }
+
+            masterService.submitUnbatchedStateUpdateTask("blocker", new BlockingTask());
+
+            final var queue = masterService.createTaskQueue("queue", Priority.NORMAL, batchExecutionContext -> {
+                assertEquals(1, batchExecutionContext.taskContexts().size());
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
+                    taskContext.success(actionCount::incrementAndGet);
+                }
+                return batchExecutionContext.initialState();
+            });
+
+            final var reachabilityChecker = new ReachabilityChecker();
+
+            class TestTask implements ClusterStateTaskListener {
+                @Override
+                public void onFailure(Exception e) {
+                    assertThat(e, instanceOf(ProcessClusterEventTimeoutException.class));
+                    deterministicTaskQueue.scheduleNow(() -> {
+                        reachabilityChecker.ensureUnreachable();
+                        actionCount.incrementAndGet();
+                    });
+                }
+            }
+
+            final var timeout = TimeValue.timeValueMillis(between(1, 30000));
+            queue.submitTask("will timeout", reachabilityChecker.register(new TestTask()), timeout);
+            queue.submitTask("no timeout", new TestTask(), null);
+
+            threadPool.getThreadContext().markAsSystemContext();
+            deterministicTaskQueue.runAllTasks();
+            assertEquals(2, actionCount.get());
+        }
+    }
+
     public void testPrioritization() {
         final var deterministicTaskQueue = new DeterministicTaskQueue();
         final var threadPool = deterministicTaskQueue.getThreadPool();