Browse Source

Rework listeners in MasterService (#83301)

Tasks submitted to the master service may optionally implement the
`AckedClusterStateTaskListener` interface to listen for acks. Today this
interface extends `ClusterStateTaskListener` but the two concerns are
pretty much orthogonal so this commit removes the inheritance between
the listeners.

The master service listens for various events in the publication process
using these general-purpose listener interfaces, even though in the
implementation we know exactly what class we're using throughout. The
extra generality gets in the way of some planned changes, so with this
commit we reorganise things to use the concrete classes where known and
inline some unnecessary abstractions.
David Turner 3 years ago
parent
commit
5f21df692f

+ 1 - 1
server/src/main/java/org/elasticsearch/cluster/AckedClusterStateUpdateTask.java

@@ -19,7 +19,7 @@ import org.elasticsearch.core.TimeValue;
  * An extension interface to {@link ClusterStateUpdateTask} that allows to be notified when
  * all the nodes have acknowledged a cluster state update request
  */
-public abstract class AckedClusterStateUpdateTask extends ClusterStateUpdateTask implements AckedClusterStateTaskListener {
+public abstract class AckedClusterStateUpdateTask extends ClusterStateUpdateTask implements ClusterStateAckListener {
 
     private final ActionListener<AcknowledgedResponse> listener;
     private final AckedRequest request;

+ 7 - 1
server/src/main/java/org/elasticsearch/cluster/AckedClusterStateTaskListener.java → server/src/main/java/org/elasticsearch/cluster/ClusterStateAckListener.java

@@ -11,7 +11,13 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 
-public interface AckedClusterStateTaskListener extends ClusterStateTaskListener {
+/**
+ * Interface that a cluster state update task can implement to indicate that it wishes to be notified when the update has been acked by
+ * (some subset of) the nodes in the cluster. Nodes ack a cluster state update after successfully applying the resulting state. Note that
+ * updates which do not change the cluster state are automatically reported as acked by all nodes without checking to see whether there are
+ * any nodes that have not already applied this state.
+ */
+public interface ClusterStateAckListener {
 
     /**
      * Called to determine which nodes the acknowledgement is expected from.

+ 3 - 2
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataMappingService.java

@@ -13,10 +13,11 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingClusterStateUpdateRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
-import org.elasticsearch.cluster.AckedClusterStateTaskListener;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateAckListener;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor;
+import org.elasticsearch.cluster.ClusterStateTaskListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Priority;
@@ -56,7 +57,7 @@ public class MetadataMappingService {
         this.indicesService = indicesService;
     }
 
-    static class PutMappingClusterStateUpdateTask implements AckedClusterStateTaskListener {
+    static class PutMappingClusterStateUpdateTask implements ClusterStateTaskListener, ClusterStateAckListener {
 
         private final PutMappingClusterStateUpdateRequest request;
         private final ActionListener<AcknowledgedResponse> listener;

+ 130 - 151
server/src/main/java/org/elasticsearch/cluster/service/MasterService.java

@@ -13,9 +13,9 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.Assertions;
 import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.cluster.AckedClusterStateTaskListener;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterState.Builder;
+import org.elasticsearch.cluster.ClusterStateAckListener;
 import org.elasticsearch.cluster.ClusterStatePublicationEvent;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor;
@@ -152,7 +152,7 @@ public class MasterService extends AbstractLifecycleComponent {
             threadPool.generic()
                 .execute(
                     () -> tasks.forEach(
-                        task -> ((UpdateTask) task).listener.onFailure(new ProcessClusterEventTimeoutException(timeout, task.source))
+                        task -> ((UpdateTask) task).onFailure(new ProcessClusterEventTimeoutException(timeout, task.source))
                     )
                 );
         }
@@ -165,17 +165,27 @@ public class MasterService extends AbstractLifecycleComponent {
         }
 
         class UpdateTask extends BatchedTask {
-            final ClusterStateTaskListener listener;
+            private final ClusterStateTaskListener listener;
+            private final Supplier<ThreadContext.StoredContext> threadContextSupplier;
+
+            @Nullable
+            private final ContextPreservingAckListener contextPreservingAckListener;
 
             UpdateTask(
                 Priority priority,
                 String source,
-                Object task,
-                ClusterStateTaskListener listener,
+                ClusterStateTaskListener task,
+                Supplier<ThreadContext.StoredContext> threadContextSupplier,
                 ClusterStateTaskExecutor<?> executor
             ) {
                 super(priority, source, executor, task);
-                this.listener = listener;
+                this.threadContextSupplier = threadContextSupplier;
+                this.listener = task;
+                if (task instanceof ClusterStateAckListener clusterStateAckListener) {
+                    this.contextPreservingAckListener = new ContextPreservingAckListener(clusterStateAckListener, threadContextSupplier);
+                } else {
+                    this.contextPreservingAckListener = null;
+                }
             }
 
             @Override
@@ -185,6 +195,50 @@ public class MasterService extends AbstractLifecycleComponent {
                 );
             }
 
+            public void onFailure(Exception e) {
+                try (ThreadContext.StoredContext ignore = threadContextSupplier.get()) {
+                    listener.onFailure(e);
+                } catch (Exception inner) {
+                    inner.addSuppressed(e);
+                    logger.error("exception thrown by listener notifying of failure", inner);
+                }
+            }
+
+            public void onNoLongerMaster() {
+                try (ThreadContext.StoredContext ignore = threadContextSupplier.get()) {
+                    listener.onNoLongerMaster();
+                } catch (Exception e) {
+                    logger.error("exception thrown by listener while notifying no longer master", e);
+                }
+            }
+
+            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
+                try (ThreadContext.StoredContext ignore = threadContextSupplier.get()) {
+                    listener.clusterStateProcessed(oldState, newState);
+                } catch (Exception e) {
+                    logger.error(() -> new ParameterizedMessage("""
+                        exception thrown by listener while notifying of cluster state, old cluster state:
+                        {}
+                        new cluster state:
+                        {}""", oldState, newState), e);
+                }
+            }
+
+            @Nullable
+            public TaskAckListener createTaskAckListener(long clusterStateVersion, DiscoveryNodes nodes) {
+                return contextPreservingAckListener == null
+                    ? null
+                    : new TaskAckListener(contextPreservingAckListener, clusterStateVersion, nodes, threadPool);
+            }
+
+            public void clusterStateUnchanged(ClusterState clusterState) {
+                if (contextPreservingAckListener != null) {
+                    // no need to wait for ack if nothing changed, the update can be counted as acknowledged
+                    contextPreservingAckListener.onAllNodesAcked(null);
+                }
+                clusterStateProcessed(clusterState, clusterState);
+            }
+
             @Override
             public ClusterStateTaskListener getTask() {
                 return (ClusterStateTaskListener) task;
@@ -299,7 +353,7 @@ public class MasterService extends AbstractLifecycleComponent {
         clusterStatePublisher.publish(
             clusterStatePublicationEvent,
             fut,
-            taskOutputs.createAckListener(threadPool, clusterStatePublicationEvent.getNewState())
+            taskOutputs.createAckListener(clusterStatePublicationEvent.getNewState())
         );
 
         // indefinitely wait for publication to complete
@@ -433,8 +487,9 @@ public class MasterService extends AbstractLifecycleComponent {
     /**
      * Submits a cluster state update task
      * @param source     the source of the cluster state update task
-     * @param updateTask the full context for the cluster state update
-     * @param executor
+     * @param updateTask the full context for the cluster state update, which implements {@link ClusterStateTaskListener} so that it is
+     *                   notified when it is executed; tasks that also implement {@link ClusterStateAckListener} are notified on acks too.
+     * @param executor   the executor for the task; tasks that share the same executor instance may be batched together
      *
      */
     public <T extends ClusterStateTaskConfig & ClusterStateTaskListener> void submitStateUpdateTask(
@@ -454,7 +509,8 @@ public class MasterService extends AbstractLifecycleComponent {
      * tasks will all be executed on the executor in a single batch
      *
      * @param source   the source of the cluster state update task
-     * @param task     the state and the callback needed for the cluster state update task
+     * @param task     the state needed for the cluster state update task, which implements {@link ClusterStateTaskListener} so that it is
+     *                 notified when it is executed; tasks that also implement {@link ClusterStateAckListener} are notified on acks too.
      * @param config   the cluster state update task configuration
      * @param executor the cluster state update task executor; tasks
      *                 that share the same executor will be executed
@@ -496,29 +552,22 @@ public class MasterService extends AbstractLifecycleComponent {
         }
 
         void publishingFailed(FailedToCommitClusterStateException t) {
-            nonFailedTasks.forEach(task -> task.listener.onFailure(t));
+            nonFailedTasks.forEach(task -> task.onFailure(t));
         }
 
         void processedDifferentClusterState(ClusterState previousClusterState, ClusterState newClusterState) {
-            nonFailedTasks.forEach(task -> task.listener.clusterStateProcessed(previousClusterState, newClusterState));
+            nonFailedTasks.forEach(task -> task.clusterStateProcessed(previousClusterState, newClusterState));
         }
 
         void clusterStatePublished(ClusterStatePublicationEvent clusterStatePublicationEvent) {
             taskInputs.executor.clusterStatePublished(clusterStatePublicationEvent);
         }
 
-        ClusterStatePublisher.AckListener createAckListener(ThreadPool threadPool, ClusterState newClusterState) {
-            return new DelegatingAckListener(
+        ClusterStatePublisher.AckListener createAckListener(ClusterState newClusterState) {
+            return new CompositeTaskAckListener(
                 nonFailedTasks.stream()
-                    .filter(task -> task.listener instanceof AckedClusterStateTaskListener)
-                    .map(
-                        task -> new AckCountDownListener(
-                            (AckedClusterStateTaskListener) task.listener,
-                            newClusterState.version(),
-                            newClusterState.nodes(),
-                            threadPool
-                        )
-                    )
+                    .map(task -> task.createTaskAckListener(newClusterState.version(), newClusterState.nodes()))
+                    .filter(Objects::nonNull)
                     .collect(Collectors.toList())
             );
         }
@@ -533,19 +582,13 @@ public class MasterService extends AbstractLifecycleComponent {
                 assert executionResults.containsKey(updateTask.task) : "missing " + updateTask;
                 final ClusterStateTaskExecutor.TaskResult taskResult = executionResults.get(updateTask.task);
                 if (taskResult.isSuccess() == false) {
-                    updateTask.listener.onFailure(taskResult.getFailure());
+                    updateTask.onFailure(taskResult.getFailure());
                 }
             }
         }
 
         void notifySuccessfulTasksOnUnchangedClusterState() {
-            nonFailedTasks.forEach(task -> {
-                if (task.listener instanceof AckedClusterStateTaskListener) {
-                    // no need to wait for ack if nothing changed, the update can be counted as acknowledged
-                    ((AckedClusterStateTaskListener) task.listener).onAllNodesAcked(null);
-                }
-                task.listener.clusterStateProcessed(newClusterState, newClusterState);
-            });
+            nonFailedTasks.forEach(task -> task.clusterStateUnchanged(newClusterState));
         }
     }
 
@@ -584,78 +627,32 @@ public class MasterService extends AbstractLifecycleComponent {
         return threadPoolExecutor.getMaxTaskWaitTime();
     }
 
-    private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> contextSupplier) {
-        if (listener instanceof AckedClusterStateTaskListener) {
-            return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, contextSupplier, logger);
+    private void logExecutionTime(TimeValue executionTime, String activity, String summary) {
+        if (executionTime.getMillis() > slowTaskLoggingThreshold.getMillis()) {
+            logger.warn(
+                "took [{}/{}ms] to {} for [{}], which exceeds the warn threshold of [{}]",
+                executionTime,
+                executionTime.getMillis(),
+                activity,
+                summary,
+                slowTaskLoggingThreshold
+            );
         } else {
-            return new SafeClusterStateTaskListener(listener, contextSupplier, logger);
-        }
-    }
-
-    private static class SafeClusterStateTaskListener implements ClusterStateTaskListener {
-        private final ClusterStateTaskListener listener;
-        protected final Supplier<ThreadContext.StoredContext> context;
-        private final Logger logger;
-
-        SafeClusterStateTaskListener(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context, Logger logger) {
-            this.listener = listener;
-            this.context = context;
-            this.logger = logger;
-        }
-
-        @Override
-        public void onFailure(Exception e) {
-            try (ThreadContext.StoredContext ignore = context.get()) {
-                listener.onFailure(e);
-            } catch (Exception inner) {
-                inner.addSuppressed(e);
-                logger.error("exception thrown by listener notifying of failure", inner);
-            }
-        }
-
-        @Override
-        public void onNoLongerMaster() {
-            try (ThreadContext.StoredContext ignore = context.get()) {
-                listener.onNoLongerMaster();
-            } catch (Exception e) {
-                logger.error("exception thrown by listener while notifying no longer master", e);
-            }
-        }
-
-        @Override
-        public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
-            try (ThreadContext.StoredContext ignore = context.get()) {
-                listener.clusterStateProcessed(oldState, newState);
-            } catch (Exception e) {
-                logger.error(() -> new ParameterizedMessage("""
-                    exception thrown by listener while notifying of cluster state, old cluster state:
-                    {}
-                    new cluster state:
-                    {}""", oldState, newState), e);
-            }
+            logger.debug("took [{}] to {} for [{}]", executionTime, activity, summary);
         }
     }
 
-    private static class SafeAckedClusterStateTaskListener extends SafeClusterStateTaskListener implements AckedClusterStateTaskListener {
-        private final AckedClusterStateTaskListener listener;
-        private final Logger logger;
-
-        SafeAckedClusterStateTaskListener(
-            AckedClusterStateTaskListener listener,
-            Supplier<ThreadContext.StoredContext> context,
-            Logger logger
-        ) {
-            super(listener, context, logger);
-            this.listener = listener;
-            this.logger = logger;
-        }
+    /**
+     * A wrapper around a {@link ClusterStateAckListener} which restores the given thread context before delegating to the inner listener's
+     * callbacks, and also logs and swallows any exceptions thrown. One of these is created for each task in the batch that implements
+     * {@link ClusterStateAckListener}.
+     */
+    private record ContextPreservingAckListener(ClusterStateAckListener listener, Supplier<ThreadContext.StoredContext> context) {
 
-        @Override
         public boolean mustAck(DiscoveryNode discoveryNode) {
             return listener.mustAck(discoveryNode);
         }
 
-        @Override
         public void onAllNodesAcked(@Nullable Exception e) {
             try (ThreadContext.StoredContext ignore = context.get()) {
                 listener.onAllNodesAcked(e);
@@ -665,7 +662,6 @@ public class MasterService extends AbstractLifecycleComponent {
             }
         }
 
-        @Override
         public void onAckTimeout() {
             try (ThreadContext.StoredContext ignore = context.get()) {
                 listener.onAckTimeout();
@@ -674,55 +670,19 @@ public class MasterService extends AbstractLifecycleComponent {
             }
         }
 
-        @Override
         public TimeValue ackTimeout() {
             return listener.ackTimeout();
         }
     }
 
-    private void logExecutionTime(TimeValue executionTime, String activity, String summary) {
-        if (executionTime.getMillis() > slowTaskLoggingThreshold.getMillis()) {
-            logger.warn(
-                "took [{}/{}ms] to {} for [{}], which exceeds the warn threshold of [{}]",
-                executionTime,
-                executionTime.getMillis(),
-                activity,
-                summary,
-                slowTaskLoggingThreshold
-            );
-        } else {
-            logger.debug("took [{}] to {} for [{}]", executionTime, activity, summary);
-        }
-    }
-
-    private static class DelegatingAckListener implements ClusterStatePublisher.AckListener {
-
-        private final List<ClusterStatePublisher.AckListener> listeners;
-
-        private DelegatingAckListener(List<ClusterStatePublisher.AckListener> listeners) {
-            this.listeners = listeners;
-        }
-
-        @Override
-        public void onCommit(TimeValue commitTime) {
-            for (ClusterStatePublisher.AckListener listener : listeners) {
-                listener.onCommit(commitTime);
-            }
-        }
-
-        @Override
-        public void onNodeAck(DiscoveryNode node, @Nullable Exception e) {
-            for (ClusterStatePublisher.AckListener listener : listeners) {
-                listener.onNodeAck(node, e);
-            }
-        }
-    }
-
-    private static class AckCountDownListener implements ClusterStatePublisher.AckListener {
-
-        private static final Logger logger = LogManager.getLogger(AckCountDownListener.class);
+    /**
+     * A wrapper around a {@link ContextPreservingAckListener} which keeps track of acks received during publication and notifies the inner
+     * listener when sufficiently many have been received. One of these is created for each {@link ContextPreservingAckListener} once the
+     * state for publication has been computed.
+     */
+    private static class TaskAckListener {
 
-        private final AckedClusterStateTaskListener ackedTaskListener;
+        private final ContextPreservingAckListener contextPreservingAckListener;
         private final CountDown countDown;
         private final DiscoveryNode masterNode;
         private final ThreadPool threadPool;
@@ -730,20 +690,20 @@ public class MasterService extends AbstractLifecycleComponent {
         private volatile Scheduler.Cancellable ackTimeoutCallback;
         private Exception lastFailure;
 
-        AckCountDownListener(
-            AckedClusterStateTaskListener ackedTaskListener,
+        TaskAckListener(
+            ContextPreservingAckListener contextPreservingAckListener,
             long clusterStateVersion,
             DiscoveryNodes nodes,
             ThreadPool threadPool
         ) {
-            this.ackedTaskListener = ackedTaskListener;
+            this.contextPreservingAckListener = contextPreservingAckListener;
             this.clusterStateVersion = clusterStateVersion;
             this.threadPool = threadPool;
             this.masterNode = nodes.getMasterNode();
             int countDown = 0;
             for (DiscoveryNode node : nodes) {
                 // we always wait for at least the master node
-                if (node.equals(masterNode) || ackedTaskListener.mustAck(node)) {
+                if (node.equals(masterNode) || contextPreservingAckListener.mustAck(node)) {
                     countDown++;
                 }
             }
@@ -751,9 +711,8 @@ public class MasterService extends AbstractLifecycleComponent {
             this.countDown = new CountDown(countDown + 1); // we also wait for onCommit to be called
         }
 
-        @Override
         public void onCommit(TimeValue commitTime) {
-            TimeValue ackTimeout = ackedTaskListener.ackTimeout();
+            TimeValue ackTimeout = contextPreservingAckListener.ackTimeout();
             if (ackTimeout == null) {
                 ackTimeout = TimeValue.ZERO;
             }
@@ -771,9 +730,8 @@ public class MasterService extends AbstractLifecycleComponent {
             }
         }
 
-        @Override
         public void onNodeAck(DiscoveryNode node, @Nullable Exception e) {
-            if (node.equals(masterNode) == false && ackedTaskListener.mustAck(node) == false) {
+            if (node.equals(masterNode) == false && contextPreservingAckListener.mustAck(node) == false) {
                 return;
             }
             if (e == null) {
@@ -800,13 +758,33 @@ public class MasterService extends AbstractLifecycleComponent {
             if (ackTimeoutCallback != null) {
                 ackTimeoutCallback.cancel();
             }
-            ackedTaskListener.onAllNodesAcked(lastFailure);
+            contextPreservingAckListener.onAllNodesAcked(lastFailure);
         }
 
         public void onTimeout() {
             if (countDown.fastForward()) {
                 logger.trace("timeout waiting for acknowledgement for cluster_state update (version: {})", clusterStateVersion);
-                ackedTaskListener.onAckTimeout();
+                contextPreservingAckListener.onAckTimeout();
+            }
+        }
+    }
+
+    /**
+     * A wrapper around the collection of {@link TaskAckListener}s for a publication.
+     */
+    private record CompositeTaskAckListener(List<TaskAckListener> listeners) implements ClusterStatePublisher.AckListener {
+
+        @Override
+        public void onCommit(TimeValue commitTime) {
+            for (TaskAckListener listener : listeners) {
+                listener.onCommit(commitTime);
+            }
+        }
+
+        @Override
+        public void onNodeAck(DiscoveryNode node, @Nullable Exception e) {
+            for (TaskAckListener listener : listeners) {
+                listener.onNodeAck(node, e);
             }
         }
     }
@@ -890,7 +868,7 @@ public class MasterService extends AbstractLifecycleComponent {
         }
 
         void onNoLongerMaster() {
-            updateTasks.forEach(task -> task.listener.onNoLongerMaster());
+            updateTasks.forEach(task -> task.onNoLongerMaster());
         }
     }
 
@@ -899,7 +877,8 @@ public class MasterService extends AbstractLifecycleComponent {
      * potentially with more tasks of the same executor.
      *
      * @param source   the source of the cluster state update task
-     * @param tasks    a collection of update tasks and their corresponding listeners
+     * @param tasks    a collection of update tasks, which implement {@link ClusterStateTaskListener} so that they are notified when they
+     *                 are executed; tasks that also implement {@link ClusterStateAckListener} are notified on acks too.
      * @param config   the cluster state update task configuration
      * @param executor the cluster state update task executor; tasks
      *                 that share the same executor will be executed
@@ -922,7 +901,7 @@ public class MasterService extends AbstractLifecycleComponent {
             threadContext.markAsSystemContext();
 
             List<Batcher.UpdateTask> safeTasks = tasks.stream()
-                .map(e -> taskBatcher.new UpdateTask(config.priority(), source, e, safe(e, supplier), executor))
+                .map(task -> taskBatcher.new UpdateTask(config.priority(), source, task, supplier, executor))
                 .toList();
             taskBatcher.submitTasks(safeTasks, config.timeout());
         } catch (EsRejectedExecutionException e) {

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/cluster/service/FakeThreadPoolMasterService.java

@@ -130,7 +130,7 @@ public class FakeThreadPoolMasterService extends MasterService {
     protected void publish(ClusterStatePublicationEvent clusterStatePublicationEvent, TaskOutputs taskOutputs) {
         assert waitForPublish == false;
         waitForPublish = true;
-        final AckListener ackListener = taskOutputs.createAckListener(threadPool, clusterStatePublicationEvent.getNewState());
+        final AckListener ackListener = taskOutputs.createAckListener(clusterStatePublicationEvent.getNewState());
         final ActionListener<Void> publishListener = new ActionListener<>() {
 
             private boolean listenerCalled = false;