Browse Source

Release master service task on timeout (#97711)

Today, when a task in a master service queue times out, it remains in the
queue until its batch is eventually processed. This might mean we retain
a reference to the task (and therefore any dependent listeners and other
objects) for an extended period of time after it has failed. This commit
adjusts the timeout handling to drop the reference to the task as soon as
the task completes, whether by being executed or by timing out.
David Turner 2 years ago
parent
commit
23df96cdbe

+ 5 - 0
docs/changelog/97711.yaml

@@ -0,0 +1,5 @@
+pr: 97711
+summary: Release master service task on timeout
+area: Cluster Coordination
+type: bug
+issues: []

+ 37 - 35
server/src/main/java/org/elasticsearch/cluster/service/MasterService.java

@@ -64,9 +64,9 @@ import java.util.Objects;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.Consumer;
 import java.util.function.LongSupplier;
 import java.util.function.LongSupplier;
 import java.util.function.Supplier;
 import java.util.function.Supplier;
@@ -1425,18 +1425,16 @@ public class MasterService extends AbstractLifecycleComponent {
         );
         );
     }
     }
 
 
-    private static class TaskTimeoutHandler extends AbstractRunnable {
+    private static class TaskTimeoutHandler<T extends ClusterStateTaskListener> extends AbstractRunnable {
 
 
         private final TimeValue timeout;
         private final TimeValue timeout;
         private final String source;
         private final String source;
-        private final AtomicBoolean executed;
-        private final ClusterStateTaskListener listener;
+        private final AtomicReference<T> taskHolder; // atomically read and set to null by at most one of {execute, timeout}
 
 
-        private TaskTimeoutHandler(TimeValue timeout, String source, AtomicBoolean executed, ClusterStateTaskListener listener) {
+        private TaskTimeoutHandler(TimeValue timeout, String source, AtomicReference<T> taskHolder) {
             this.timeout = timeout;
             this.timeout = timeout;
             this.source = source;
             this.source = source;
-            this.executed = executed;
-            this.listener = listener;
+            this.taskHolder = taskHolder;
         }
         }
 
 
         @Override
         @Override
@@ -1463,10 +1461,17 @@ public class MasterService extends AbstractLifecycleComponent {
         }
         }
 
 
         private void completeTask(Exception e) {
         private void completeTask(Exception e) {
-            if (executed.compareAndSet(false, true)) {
-                listener.onFailure(e);
+            final var task = taskHolder.getAndSet(null);
+            if (task != null) {
+                logger.trace("timing out [{}][{}] after [{}]", source, task, timeout);
+                task.onFailure(e);
             }
             }
         }
         }
+
+        @Override
+        public String toString() {
+            return Strings.format("master service timeout handler for [%s][%s] after [%s]", source, taskHolder.get(), timeout);
+        }
     }
     }
 
 
     /**
     /**
@@ -1514,11 +1519,11 @@ public class MasterService extends AbstractLifecycleComponent {
 
 
         @Override
         @Override
         public void submitTask(String source, T task, @Nullable TimeValue timeout) {
         public void submitTask(String source, T task, @Nullable TimeValue timeout) {
-            final var executed = new AtomicBoolean(false);
+            final var taskHolder = new AtomicReference<>(task);
             final Scheduler.Cancellable timeoutCancellable;
             final Scheduler.Cancellable timeoutCancellable;
             if (timeout != null && timeout.millis() > 0) {
             if (timeout != null && timeout.millis() > 0) {
                 timeoutCancellable = threadPool.schedule(
                 timeoutCancellable = threadPool.schedule(
-                    new TaskTimeoutHandler(timeout, source, executed, task),
+                    new TaskTimeoutHandler<>(timeout, source, taskHolder),
                     timeout,
                     timeout,
                     ThreadPool.Names.GENERIC
                     ThreadPool.Names.GENERIC
                 );
                 );
@@ -1529,10 +1534,9 @@ public class MasterService extends AbstractLifecycleComponent {
             queue.add(
             queue.add(
                 new Entry<>(
                 new Entry<>(
                     source,
                     source,
-                    task,
+                    taskHolder,
                     insertionIndexSupplier.getAsLong(),
                     insertionIndexSupplier.getAsLong(),
                     threadPool.relativeTimeInMillis(),
                     threadPool.relativeTimeInMillis(),
-                    executed,
                     threadPool.getThreadContext().newRestorableContext(true),
                     threadPool.getThreadContext().newRestorableContext(true),
                     timeoutCancellable
                     timeoutCancellable
                 )
                 )
@@ -1550,26 +1554,23 @@ public class MasterService extends AbstractLifecycleComponent {
 
 
         private record Entry<T extends ClusterStateTaskListener>(
         private record Entry<T extends ClusterStateTaskListener>(
             String source,
             String source,
-            T task,
+            AtomicReference<T> taskHolder,
             long insertionIndex,
             long insertionIndex,
             long insertionTimeMillis,
             long insertionTimeMillis,
-            AtomicBoolean executed,
             Supplier<ThreadContext.StoredContext> storedContextSupplier,
             Supplier<ThreadContext.StoredContext> storedContextSupplier,
             @Nullable Scheduler.Cancellable timeoutCancellable
             @Nullable Scheduler.Cancellable timeoutCancellable
         ) {
         ) {
-            boolean acquireForExecution() {
-                if (executed.compareAndSet(false, true) == false) {
-                    return false;
-                }
-
-                if (timeoutCancellable != null) {
+            T acquireForExecution() {
+                final var task = taskHolder.getAndSet(null);
+                if (task != null && timeoutCancellable != null) {
                     timeoutCancellable.cancel();
                     timeoutCancellable.cancel();
                 }
                 }
-                return true;
+                return task;
             }
             }
 
 
             void onRejection(FailedToCommitClusterStateException e) {
             void onRejection(FailedToCommitClusterStateException e) {
-                if (acquireForExecution()) {
+                final var task = acquireForExecution();
+                if (task != null) {
                     try (var ignored = storedContextSupplier.get()) {
                     try (var ignored = storedContextSupplier.get()) {
                         task.onFailure(e);
                         task.onFailure(e);
                     } catch (Exception e2) {
                     } catch (Exception e2) {
@@ -1579,6 +1580,10 @@ public class MasterService extends AbstractLifecycleComponent {
                     }
                     }
                 }
                 }
             }
             }
+
+            boolean isPending() {
+                return taskHolder().get() != null;
+            }
         }
         }
 
 
         private class Processor implements Batch {
         private class Processor implements Batch {
@@ -1597,12 +1602,17 @@ public class MasterService extends AbstractLifecycleComponent {
                 assert executing.isEmpty() : executing;
                 assert executing.isEmpty() : executing;
                 final var entryCount = queueSize.getAndSet(0);
                 final var entryCount = queueSize.getAndSet(0);
                 var taskCount = 0;
                 var taskCount = 0;
+                final var tasks = new ArrayList<ExecutionResult<T>>(entryCount);
                 for (int i = 0; i < entryCount; i++) {
                 for (int i = 0; i < entryCount; i++) {
                     final var entry = queue.poll();
                     final var entry = queue.poll();
                     assert entry != null;
                     assert entry != null;
-                    if (entry.acquireForExecution()) {
+                    final var task = entry.acquireForExecution();
+                    if (task != null) {
                         taskCount += 1;
                         taskCount += 1;
                         executing.add(entry);
                         executing.add(entry);
+                        tasks.add(
+                            new ExecutionResult<>(entry.source(), task, threadPool.getThreadContext(), entry.storedContextSupplier())
+                        );
                     }
                     }
                 }
                 }
                 if (taskCount == 0) {
                 if (taskCount == 0) {
@@ -1610,12 +1620,6 @@ public class MasterService extends AbstractLifecycleComponent {
                     return;
                     return;
                 }
                 }
                 final var finalTaskCount = taskCount;
                 final var finalTaskCount = taskCount;
-                final var tasks = new ArrayList<ExecutionResult<T>>(finalTaskCount);
-                for (final var entry : executing) {
-                    tasks.add(
-                        new ExecutionResult<>(entry.source(), entry.task(), threadPool.getThreadContext(), entry.storedContextSupplier())
-                    );
-                }
                 ActionListener.run(ActionListener.runBefore(listener, () -> {
                 ActionListener.run(ActionListener.runBefore(listener, () -> {
                     assert executing.size() == finalTaskCount;
                     assert executing.size() == finalTaskCount;
                     executing.clear();
                     executing.clear();
@@ -1643,9 +1647,7 @@ public class MasterService extends AbstractLifecycleComponent {
             public Stream<PendingClusterTask> getPending(long currentTimeMillis) {
             public Stream<PendingClusterTask> getPending(long currentTimeMillis) {
                 return Stream.concat(
                 return Stream.concat(
                     executing.stream().map(entry -> makePendingTask(entry, currentTimeMillis, true)),
                     executing.stream().map(entry -> makePendingTask(entry, currentTimeMillis, true)),
-                    queue.stream()
-                        .filter(entry -> entry.executed().get() == false)
-                        .map(entry -> makePendingTask(entry, currentTimeMillis, false))
+                    queue.stream().filter(Entry::isPending).map(entry -> makePendingTask(entry, currentTimeMillis, false))
                 );
                 );
             }
             }
 
 
@@ -1664,7 +1666,7 @@ public class MasterService extends AbstractLifecycleComponent {
             public int getPendingCount() {
             public int getPendingCount() {
                 int count = executing.size();
                 int count = executing.size();
                 for (final var entry : queue) {
                 for (final var entry : queue) {
-                    if (entry.executed().get() == false) {
+                    if (entry.isPending()) {
                         count += 1;
                         count += 1;
                     }
                     }
                 }
                 }
@@ -1673,7 +1675,7 @@ public class MasterService extends AbstractLifecycleComponent {
 
 
             @Override
             @Override
             public long getCreationTimeMillis() {
             public long getCreationTimeMillis() {
-                return Stream.concat(executing.stream(), queue.stream().filter(entry -> entry.executed().get() == false))
+                return Stream.concat(executing.stream(), queue.stream().filter(Entry::isPending))
                     .mapToLong(Entry::insertionTimeMillis)
                     .mapToLong(Entry::insertionTimeMillis)
                     .min()
                     .min()
                     .orElse(Long.MAX_VALUE);
                     .orElse(Long.MAX_VALUE);

+ 73 - 0
server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java

@@ -53,6 +53,7 @@ import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ClusterServiceUtils;
 import org.elasticsearch.test.ClusterServiceUtils;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.MockLogAppender;
 import org.elasticsearch.test.MockLogAppender;
+import org.elasticsearch.test.ReachabilityChecker;
 import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.elasticsearch.test.tasks.MockTaskManager;
 import org.elasticsearch.test.tasks.MockTaskManager;
 import org.elasticsearch.test.tasks.MockTaskManagerListener;
 import org.elasticsearch.test.tasks.MockTaskManagerListener;
@@ -2379,6 +2380,78 @@ public class MasterServiceTests extends ESTestCase {
         }
         }
     }
     }
 
 
+    public void testReleaseOnTimeout() {
+
+        final var deterministicTaskQueue = new DeterministicTaskQueue();
+
+        final var threadPool = deterministicTaskQueue.getThreadPool();
+        try (var masterService = createMasterService(true, null, threadPool, new StoppableExecutorServiceWrapper(threadPool.generic()))) {
+
+            final var actionCount = new AtomicInteger();
+
+            class BlockingTask extends ClusterStateUpdateTask {
+                BlockingTask() {
+                    super(Priority.IMMEDIATE);
+                }
+
+                @Override
+                public ClusterState execute(ClusterState currentState) {
+                    var targetTime = deterministicTaskQueue.getCurrentTimeMillis() + between(1, 1000);
+                    deterministicTaskQueue.scheduleAt(targetTime, () -> {});
+
+                    while (deterministicTaskQueue.getCurrentTimeMillis() < targetTime) {
+                        deterministicTaskQueue.advanceTime();
+                    }
+
+                    return currentState;
+                }
+
+                @Override
+                public void clusterStateProcessed(ClusterState initialState, ClusterState newState) {
+                    if (actionCount.get() < 1) {
+                        masterService.submitUnbatchedStateUpdateTask("blocker", BlockingTask.this);
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    throw new AssertionError("unexpected", e);
+                }
+            }
+
+            masterService.submitUnbatchedStateUpdateTask("blocker", new BlockingTask());
+
+            final var queue = masterService.createTaskQueue("queue", Priority.NORMAL, batchExecutionContext -> {
+                assertEquals(1, batchExecutionContext.taskContexts().size());
+                for (final var taskContext : batchExecutionContext.taskContexts()) {
+                    taskContext.success(actionCount::incrementAndGet);
+                }
+                return batchExecutionContext.initialState();
+            });
+
+            final var reachabilityChecker = new ReachabilityChecker();
+
+            class TestTask implements ClusterStateTaskListener {
+                @Override
+                public void onFailure(Exception e) {
+                    assertThat(e, instanceOf(ProcessClusterEventTimeoutException.class));
+                    deterministicTaskQueue.scheduleNow(() -> {
+                        reachabilityChecker.ensureUnreachable();
+                        actionCount.incrementAndGet();
+                    });
+                }
+            }
+
+            final var timeout = TimeValue.timeValueMillis(between(1, 30000));
+            queue.submitTask("will timeout", reachabilityChecker.register(new TestTask()), timeout);
+            queue.submitTask("no timeout", new TestTask(), null);
+
+            threadPool.getThreadContext().markAsSystemContext();
+            deterministicTaskQueue.runAllTasks();
+            assertEquals(2, actionCount.get());
+        }
+    }
+
     public void testPrioritization() {
     public void testPrioritization() {
         final var deterministicTaskQueue = new DeterministicTaskQueue();
         final var deterministicTaskQueue = new DeterministicTaskQueue();
         final var threadPool = deterministicTaskQueue.getThreadPool();
         final var threadPool = deterministicTaskQueue.getThreadPool();