1
0
Эх сурвалжийг харах

Accept only single tasks at master service (#83829)

Today `MasterService` (and `TaskBatcher`) allow callers to submit a
collection of tasks that will be executed all at once. Support for
batches of tasks makes things more complicated than they need to be,
noting that (since #83803) in production code we only ever submit single
tasks. This commit specializes things to accept only single tasks.
David Turner 3 жил өмнө
parent
commit
dd4d442b05

+ 1 - 3
server/src/main/java/org/elasticsearch/cluster/service/ClusterService.java

@@ -28,8 +28,6 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.node.Node;
 import org.elasticsearch.threadpool.ThreadPool;
 
-import java.util.List;
-
 public class ClusterService extends AbstractLifecycleComponent {
     private final MasterService masterService;
 
@@ -259,7 +257,7 @@ public class ClusterService extends AbstractLifecycleComponent {
         ClusterStateTaskConfig config,
         ClusterStateTaskExecutor<T> executor
     ) {
-        masterService.submitStateUpdateTasks(source, List.of(task), config, executor);
+        masterService.submitStateUpdateTask(source, task, config, executor);
     }
 
 }

+ 17 - 49
server/src/main/java/org/elasticsearch/cluster/service/MasterService.java

@@ -46,7 +46,6 @@ import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -150,13 +149,9 @@ public class MasterService extends AbstractLifecycleComponent {
         }
 
         @Override
-        protected void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout) {
+        protected void onTimeout(BatchedTask task, TimeValue timeout) {
             threadPool.generic()
-                .execute(
-                    () -> tasks.forEach(
-                        task -> ((UpdateTask) task).onFailure(new ProcessClusterEventTimeoutException(timeout, task.source))
-                    )
-                );
+                .execute(() -> ((UpdateTask) task).onFailure(new ProcessClusterEventTimeoutException(timeout, task.source)));
         }
 
         @Override
@@ -506,7 +501,21 @@ public class MasterService extends AbstractLifecycleComponent {
         ClusterStateTaskConfig config,
         ClusterStateTaskExecutor<T> executor
     ) {
-        submitStateUpdateTasks(source, List.of(task), config, executor);
+        if (lifecycle.started() == false) {
+            return;
+        }
+        final ThreadContext threadContext = threadPool.getThreadContext();
+        final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
+        try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
+            threadContext.markAsSystemContext();
+            taskBatcher.submitTask(taskBatcher.new UpdateTask(config.priority(), source, task, supplier, executor), config.timeout());
+        } catch (EsRejectedExecutionException e) {
+            // ignore cases where we are shutting down..., there is really nothing interesting
+            // to be done here...
+            if (lifecycle.stoppedOrClosed() == false) {
+                throw e;
+            }
+        }
     }
 
     /**
@@ -903,47 +912,6 @@ public class MasterService extends AbstractLifecycleComponent {
         }
     }
 
-    /**
-     * Submits a batch of cluster state update tasks; submitted updates are guaranteed to be processed together,
-     * potentially with more tasks of the same executor.
-     *
-     * @param source   the source of the cluster state update task
-     * @param tasks    a collection of update tasks, which implement {@link ClusterStateTaskListener} so that they are notified when they
-     *                 are executed; tasks that also implement {@link ClusterStateAckListener} are notified on acks too.
-     * @param config   the cluster state update task configuration
-     * @param executor the cluster state update task executor; tasks
-     *                 that share the same executor will be executed
-     *                 batches on this executor
-     * @param <T>      the type of the cluster state update task state
-     *
-     */
-    public <T extends ClusterStateTaskListener> void submitStateUpdateTasks(
-        final String source,
-        final Collection<T> tasks,
-        final ClusterStateTaskConfig config,
-        final ClusterStateTaskExecutor<T> executor
-    ) {
-        if (lifecycle.started() == false) {
-            return;
-        }
-        final ThreadContext threadContext = threadPool.getThreadContext();
-        final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
-        try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
-            threadContext.markAsSystemContext();
-
-            List<Batcher.UpdateTask> safeTasks = tasks.stream()
-                .map(task -> taskBatcher.new UpdateTask(config.priority(), source, task, supplier, executor))
-                .toList();
-            taskBatcher.submitTasks(safeTasks, config.timeout());
-        } catch (EsRejectedExecutionException e) {
-            // ignore cases where we are shutting down..., there is really nothing interesting
-            // to be done here...
-            if (lifecycle.stoppedOrClosed() == false) {
-                throw e;
-            }
-        }
-    }
-
     private static class MasterServiceStarvationWatcher implements PrioritizedEsThreadPoolExecutor.StarvationWatcher {
 
         private final long warnThreshold;

+ 24 - 62
server/src/main/java/org/elasticsearch/cluster/service/TaskBatcher.java

@@ -19,15 +19,12 @@ import org.elasticsearch.core.TimeValue;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.IdentityHashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Function;
-import java.util.stream.Collectors;
 
 /**
  * Batching support for {@link PrioritizedEsThreadPoolExecutor}
@@ -45,86 +42,50 @@ public abstract class TaskBatcher {
         this.threadExecutor = threadExecutor;
     }
 
-    public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue timeout) throws EsRejectedExecutionException {
-        if (tasks.isEmpty()) {
-            return;
-        }
-        final BatchedTask firstTask = tasks.get(0);
-        assert tasks.stream().allMatch(t -> t.batchingKey == firstTask.batchingKey)
-            : "tasks submitted in a batch should share the same batching key: " + tasks;
-        // convert to an identity map to check for dups based on task identity
-
-        tasksPerBatchingKey.compute(firstTask.batchingKey, (k, existingTasks) -> {
-            assert assertNoDuplicateTasks(tasks, existingTasks);
+    public void submitTask(BatchedTask task, @Nullable TimeValue timeout) throws EsRejectedExecutionException {
+        tasksPerBatchingKey.compute(task.batchingKey, (k, existingTasks) -> {
             if (existingTasks == null) {
-                return Collections.synchronizedSet(new LinkedHashSet<>(tasks));
+                existingTasks = Collections.synchronizedSet(new LinkedHashSet<>());
+            } else {
+                assert assertNoDuplicateTasks(task, existingTasks);
             }
-            existingTasks.addAll(tasks);
+            existingTasks.add(task);
             return existingTasks;
         });
 
         if (timeout != null) {
-            threadExecutor.execute(firstTask, timeout, () -> onTimeoutInternal(tasks, timeout));
+            threadExecutor.execute(task, timeout, () -> onTimeoutInternal(task, timeout));
         } else {
-            threadExecutor.execute(firstTask);
+            threadExecutor.execute(task);
         }
     }
 
-    private static boolean assertNoDuplicateTasks(List<? extends BatchedTask> tasks, Set<BatchedTask> existingTasks) {
-        final Map<Object, BatchedTask> tasksIdentity = tasks.stream()
-            .collect(
-                Collectors.toMap(
-                    BatchedTask::getTask,
-                    Function.identity(),
-                    (a, b) -> { throw new AssertionError("cannot add duplicate task: " + a); },
-                    IdentityHashMap::new
-                )
-            );
-        if (existingTasks == null) {
-            return true;
-        }
-        for (BatchedTask existing : existingTasks) {
-            // check that there won't be two tasks with the same identity for the same batching key
-            BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
-            assert duplicateTask == null
-                : "task ["
-                    + duplicateTask.describeTasks(Collections.singletonList(existing))
-                    + "] with source ["
-                    + duplicateTask.source
-                    + "] is already queued";
+    private static boolean assertNoDuplicateTasks(BatchedTask task, Set<BatchedTask> existingTasks) {
+        for (final var existingTask : existingTasks) {
+            assert existingTask.getTask() != task.getTask()
+                : "task [" + task.describeTasks(List.of(task)) + "] with source [" + task.source + "] is already queued";
         }
         return true;
     }
 
-    private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue timeout) {
-        final ArrayList<BatchedTask> toRemove = new ArrayList<>();
-        for (BatchedTask task : tasks) {
-            if (task.processed.getAndSet(true) == false) {
-                logger.debug("task [{}] timed out after [{}]", task.source, timeout);
-                toRemove.add(task);
-            }
-        }
-        if (toRemove.isEmpty() == false) {
-            BatchedTask firstTask = toRemove.get(0);
-            Object batchingKey = firstTask.batchingKey;
-            assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey)
-                : "tasks submitted in a batch should share the same batching key: " + tasks;
-            tasksPerBatchingKey.computeIfPresent(batchingKey, (key, existingTasks) -> {
-                toRemove.forEach(existingTasks::remove);
-                if (existingTasks.isEmpty()) {
-                    return null;
-                }
-                return existingTasks;
-            });
-            onTimeout(toRemove, timeout);
+    private void onTimeoutInternal(BatchedTask task, TimeValue timeout) {
+        if (task.processed.getAndSet(true)) {
+            return;
         }
+
+        logger.debug("task [{}] timed out after [{}]", task.source, timeout);
+        tasksPerBatchingKey.computeIfPresent(task.batchingKey, (key, existingTasks) -> {
+            existingTasks.remove(task);
+            return existingTasks.isEmpty() ? null : existingTasks;
+        });
+        onTimeout(task, timeout);
     }
 
     /**
      * Action to be implemented by the specific batching implementation.
      * All tasks have the same batching key.
      */
-    protected abstract void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout);
+    protected abstract void onTimeout(BatchedTask task, TimeValue timeout);
 
     void runIfNotProcessed(BatchedTask updateTask) {
         // if this task is already processed, it shouldn't execute other tasks with same batching key that arrived later,
@@ -135,6 +96,7 @@ public abstract class TaskBatcher {
             final Set<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
             if (pending != null) {
                 // pending is a java.util.Collections.SynchronizedSet so we can safely iterate holding its mutex
+                // noinspection SynchronizationOnLocalVariableOrMethodParameter
                 synchronized (pending) {
                     for (BatchedTask task : pending) {
                         if (task.processed.getAndSet(true) == false) {

+ 34 - 82
server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java

@@ -56,7 +56,6 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.BrokenBarrierException;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.CyclicBarrier;
@@ -67,13 +66,12 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptySet;
-import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 
@@ -534,14 +532,14 @@ public class MasterServiceTests extends ESTestCase {
             final var submitThreads = new Thread[between(1, 10)];
             for (int i = 0; i < submitThreads.length; i++) {
                 final var executor = randomFrom(executors);
-                final var tasks = randomList(1, 10, Task::new);
-                executor.addExpectedTaskCount(tasks.size());
+                final var task = new Task();
+                executor.addExpectedTaskCount(1);
                 submitThreads[i] = new Thread(() -> {
                     try {
                         assertTrue(submissionLatch.await(10, TimeUnit.SECONDS));
-                        masterService.submitStateUpdateTasks(
+                        masterService.submitStateUpdateTask(
                             Thread.currentThread().getName(),
-                            tasks,
+                            task,
                             ClusterStateTaskConfig.build(randomFrom(Priority.values())),
                             executor
                         );
@@ -656,21 +654,13 @@ public class MasterServiceTests extends ESTestCase {
             private final AtomicInteger assigned = new AtomicInteger();
             private final AtomicInteger batches = new AtomicInteger();
             private final AtomicInteger published = new AtomicInteger();
-            private final List<Set<Task>> assignments = new ArrayList<>();
+            private final List<Task> assignments = new ArrayList<>();
 
             @Override
             public ClusterTasksResult<Task> execute(ClusterState currentState, List<Task> tasks) throws Exception {
-                int totalCount = 0;
-                for (Set<Task> group : assignments) {
-                    long count = tasks.stream().filter(group::contains).count();
-                    assertThat(
-                        "batched set should be executed together or not at all. Expected " + group + "s. Executing " + tasks,
-                        count,
-                        anyOf(equalTo(0L), equalTo((long) group.size()))
-                    );
-                    totalCount += count;
+                for (Task task : tasks) {
+                    assertThat("All tasks should belong to this executor", assignments, hasItem(task));
                 }
-                assertThat("All tasks should belong to this executor", totalCount, equalTo(tasks.size()));
                 tasks.forEach(Task::execute);
                 executed.addAndGet(tasks.size());
                 ClusterState maybeUpdatedClusterState = currentState;
@@ -699,16 +689,16 @@ public class MasterServiceTests extends ESTestCase {
         }
 
         // randomly assign tasks to executors
-        List<Tuple<TaskExecutor, Set<Task>>> assignments = new ArrayList<>();
+        List<Tuple<TaskExecutor, Task>> assignments = new ArrayList<>();
         AtomicInteger totalTasks = new AtomicInteger();
         for (int i = 0; i < numberOfThreads; i++) {
             for (int j = 0; j < taskSubmissionsPerThread; j++) {
                 var executor = randomFrom(executors);
-                var tasks = Set.copyOf(randomList(1, 3, () -> new Task(totalTasks.getAndIncrement())));
+                var task = new Task(totalTasks.getAndIncrement());
 
-                assignments.add(Tuple.tuple(executor, tasks));
-                executor.assigned.addAndGet(tasks.size());
-                executor.assignments.add(tasks);
+                assignments.add(Tuple.tuple(executor, task));
+                executor.assigned.incrementAndGet();
+                executor.assignments.add(task);
             }
         }
         processedStatesLatch.set(new CountDownLatch(totalTasks.get()));
@@ -723,24 +713,15 @@ public class MasterServiceTests extends ESTestCase {
                         barrier.await();
                         for (int j = 0; j < taskSubmissionsPerThread; j++) {
                             var assignment = assignments.get(index * taskSubmissionsPerThread + j);
-                            var tasks = assignment.v2();
+                            var task = assignment.v2();
                             var executor = assignment.v1();
-                            submittedTasks.addAndGet(tasks.size());
-                            if (tasks.size() == 1) {
-                                masterService.submitStateUpdateTask(
-                                    threadName,
-                                    tasks.iterator().next(),
-                                    ClusterStateTaskConfig.build(randomFrom(Priority.values())),
-                                    executor
-                                );
-                            } else {
-                                masterService.submitStateUpdateTasks(
-                                    threadName,
-                                    tasks,
-                                    ClusterStateTaskConfig.build(randomFrom(Priority.values())),
-                                    executor
-                                );
-                            }
+                            submittedTasks.incrementAndGet();
+                            masterService.submitStateUpdateTask(
+                                threadName,
+                                task,
+                                ClusterStateTaskConfig.build(randomFrom(Priority.values())),
+                                executor
+                            );
                         }
                         barrier.await();
                     } catch (BrokenBarrierException | InterruptedException e) {
@@ -836,26 +817,13 @@ public class MasterServiceTests extends ESTestCase {
                 }
             );
 
-            int toSubmit = taskCount;
-
-            while (toSubmit > 0) {
-                final int batchSize = between(1, toSubmit);
-                toSubmit -= batchSize;
+            for (int i = 0; i < taskCount; i++) {
                 try (ThreadContext.StoredContext ignored = threadContext.newStoredContext(false)) {
                     final String testContextHeaderValue = randomAlphaOfLength(10);
                     threadContext.putHeader(testContextHeaderName, testContextHeaderValue);
-
-                    final List<Task> tasks = IntStream.range(0, batchSize)
-                        .mapToObj(i -> new Task(testContextHeaderValue))
-                        .collect(Collectors.toList());
-
-                    final ClusterStateTaskConfig clusterStateTaskConfig = ClusterStateTaskConfig.build(Priority.NORMAL);
-
-                    if (batchSize == 1 && randomBoolean()) {
-                        masterService.submitStateUpdateTask("test", tasks.get(0), clusterStateTaskConfig, executor);
-                    } else {
-                        masterService.submitStateUpdateTasks("test", tasks, clusterStateTaskConfig, executor);
-                    }
+                    final var task = new Task(testContextHeaderValue);
+                    final var clusterStateTaskConfig = ClusterStateTaskConfig.build(Priority.NORMAL);
+                    masterService.submitStateUpdateTask("test", task, clusterStateTaskConfig, executor);
                 }
             }
 
@@ -928,14 +896,11 @@ public class MasterServiceTests extends ESTestCase {
             int toSubmit = between(1, 10);
             final CountDownLatch publishSuccessCountdown = new CountDownLatch(toSubmit);
 
-            while (toSubmit > 0) {
-                final int batchSize = between(1, toSubmit);
-                toSubmit -= batchSize;
+            for (int i = 0; i < toSubmit; i++) {
                 try (ThreadContext.StoredContext ignored = threadContext.newStoredContext(false)) {
-                    final String testContextHeaderValue = randomAlphaOfLength(10);
+                    final var testContextHeaderValue = randomAlphaOfLength(10);
                     threadContext.putHeader(testContextHeaderName, testContextHeaderValue);
-
-                    final List<Task> tasks = IntStream.range(0, batchSize).mapToObj(i -> new Task(new ActionListener<>() {
+                    final var task = new Task(new ActionListener<>() {
                         @Override
                         public void onResponse(ClusterState clusterState) {
                             assertEquals(testContextHeaderValue, threadContext.getHeader(testContextHeaderName));
@@ -947,15 +912,10 @@ public class MasterServiceTests extends ESTestCase {
                         public void onFailure(Exception e) {
                             throw new AssertionError(e);
                         }
-                    })).collect(Collectors.toList());
+                    });
 
                     final ClusterStateTaskConfig clusterStateTaskConfig = ClusterStateTaskConfig.build(Priority.NORMAL);
-
-                    if (batchSize == 1 && randomBoolean()) {
-                        masterService.submitStateUpdateTask("test", tasks.get(0), clusterStateTaskConfig, executor);
-                    } else {
-                        masterService.submitStateUpdateTasks("test", tasks, clusterStateTaskConfig, executor);
-                    }
+                    masterService.submitStateUpdateTask("test", task, clusterStateTaskConfig, executor);
                 }
             }
 
@@ -976,14 +936,11 @@ public class MasterServiceTests extends ESTestCase {
             toSubmit = between(1, 10);
             final CountDownLatch publishFailureCountdown = new CountDownLatch(toSubmit);
 
-            while (toSubmit > 0) {
-                final int batchSize = between(1, toSubmit);
-                toSubmit -= batchSize;
+            for (int i = 0; i < toSubmit; i++) {
                 try (ThreadContext.StoredContext ignored = threadContext.newStoredContext(false)) {
                     final String testContextHeaderValue = randomAlphaOfLength(10);
                     threadContext.putHeader(testContextHeaderName, testContextHeaderValue);
-
-                    final List<Task> tasks = IntStream.range(0, batchSize).mapToObj(i -> new Task(new ActionListener<>() {
+                    final var task = new Task(new ActionListener<>() {
                         @Override
                         public void onResponse(ClusterState clusterState) {
                             throw new AssertionError("should not succeed");
@@ -996,15 +953,10 @@ public class MasterServiceTests extends ESTestCase {
                             assertThat(e.getMessage(), equalTo(exceptionMessage));
                             publishFailureCountdown.countDown();
                         }
-                    })).collect(Collectors.toList());
+                    });
 
                     final ClusterStateTaskConfig clusterStateTaskConfig = ClusterStateTaskConfig.build(Priority.NORMAL);
-
-                    if (batchSize == 1 && randomBoolean()) {
-                        masterService.submitStateUpdateTask("test", tasks.get(0), clusterStateTaskConfig, executor);
-                    } else {
-                        masterService.submitStateUpdateTasks("test", tasks, clusterStateTaskConfig, executor);
-                    }
+                    masterService.submitStateUpdateTask("test", task, clusterStateTaskConfig, executor);
                 }
             }
 

+ 28 - 70
server/src/test/java/org/elasticsearch/cluster/service/TaskBatcherTests.java

@@ -9,32 +9,25 @@
 package org.elasticsearch.cluster.service;
 
 import org.apache.logging.log4j.Logger;
-import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
 import org.elasticsearch.cluster.metadata.ProcessClusterEventTimeoutException;
 import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.core.Tuple;
 import org.junit.Before;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.BrokenBarrierException;
-import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.containsString;
-import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasToString;
 
@@ -47,7 +40,7 @@ public class TaskBatcherTests extends TaskExecutorTests {
         taskBatcher = new TestTaskBatcher(logger, threadExecutor);
     }
 
-    class TestTaskBatcher extends TaskBatcher {
+    static class TestTaskBatcher extends TaskBatcher {
 
         TestTaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) {
             super(logger, threadExecutor);
@@ -58,20 +51,13 @@ public class TaskBatcherTests extends TaskExecutorTests {
         protected void run(Object batchingKey, List<? extends BatchedTask> tasks, String tasksSummary) {
             List<UpdateTask> updateTasks = (List<UpdateTask>) tasks;
             ((TestExecutor<Object>) batchingKey).execute(updateTasks.stream().map(t -> t.task).collect(Collectors.toList()));
-            updateTasks.forEach(updateTask -> updateTask.listener.processed(updateTask.source));
+            updateTasks.forEach(updateTask -> updateTask.listener.processed());
         }
 
         @Override
-        protected void onTimeout(List<? extends BatchedTask> tasks, TimeValue timeout) {
+        protected void onTimeout(BatchedTask task, TimeValue timeout) {
             threadPool.generic()
-                .execute(
-                    () -> tasks.forEach(
-                        task -> ((UpdateTask) task).listener.onFailure(
-                            task.source,
-                            new ProcessClusterEventTimeoutException(timeout, task.source)
-                        )
-                    )
-                );
+                .execute(() -> ((UpdateTask) task).listener.onFailure(new ProcessClusterEventTimeoutException(timeout, task.source)));
         }
 
         class UpdateTask extends BatchedTask {
@@ -99,20 +85,7 @@ public class TaskBatcherTests extends TaskExecutorTests {
     }
 
     private <T> void submitTask(String source, T task, ClusterStateTaskConfig config, TestExecutor<T> executor, TestListener listener) {
-        submitTasks(source, Collections.singletonMap(task, listener), config, executor);
-    }
-
-    private <T> void submitTasks(
-        final String source,
-        final Map<T, TestListener> tasks,
-        final ClusterStateTaskConfig config,
-        final TestExecutor<T> executor
-    ) {
-        List<TestTaskBatcher.UpdateTask> safeTasks = tasks.entrySet()
-            .stream()
-            .map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), e.getValue(), executor))
-            .collect(Collectors.toList());
-        taskBatcher.submitTasks(safeTasks, config.timeout());
+        taskBatcher.submitTask(taskBatcher.new UpdateTask(config.priority(), source, task, listener, executor), config.timeout());
     }
 
     @Override
@@ -144,7 +117,7 @@ public class TaskBatcherTests extends TaskExecutorTests {
         TaskExecutor executorB = new TaskExecutor();
 
         final ClusterStateTaskConfig config = ClusterStateTaskConfig.build(Priority.NORMAL);
-        final TestListener noopListener = (source, e) -> { throw new AssertionError(e); };
+        final TestListener noopListener = e -> { throw new AssertionError(e); };
         // this blocks the cluster state queue, so we can set it up right
         submitTask("0", "A0", config, executorA, noopListener);
         // wait to be processed
@@ -196,19 +169,16 @@ public class TaskBatcherTests extends TaskExecutorTests {
 
         int tasksSubmittedPerThread = randomIntBetween(2, 1024);
 
-        CopyOnWriteArrayList<Tuple<String, Throwable>> failures = new CopyOnWriteArrayList<>();
         CountDownLatch updateLatch = new CountDownLatch(numberOfThreads * tasksSubmittedPerThread);
 
         final TestListener listener = new TestListener() {
             @Override
-            public void onFailure(String source, Exception e) {
-                logger.error(() -> new ParameterizedMessage("unexpected failure: [{}]", source), e);
-                failures.add(new Tuple<>(source, e));
-                updateLatch.countDown();
+            public void onFailure(Exception e) {
+                throw new AssertionError(e);
             }
 
             @Override
-            public void processed(String source) {
+            public void processed() {
                 updateLatch.countDown();
             }
         };
@@ -242,9 +212,7 @@ public class TaskBatcherTests extends TaskExecutorTests {
         // wait for all threads to finish
         barrier.await();
 
-        updateLatch.await();
-
-        assertThat(failures, empty());
+        assertTrue(updateLatch.await(10, TimeUnit.SECONDS));
 
         for (int i = 0; i < numberOfThreads; i++) {
             assertEquals(tasksSubmittedPerThread, executors[i].tasks.size());
@@ -255,34 +223,24 @@ public class TaskBatcherTests extends TaskExecutorTests {
         }
     }
 
-    public void testSingleBatchSubmission() throws InterruptedException {
-        Map<Integer, TestListener> tasks = new HashMap<>();
-        final int numOfTasks = randomInt(10);
-        final CountDownLatch latch = new CountDownLatch(numOfTasks);
-        Set<Integer> usedKeys = new HashSet<>(numOfTasks);
-        for (int i = 0; i < numOfTasks; i++) {
-            int key = randomValueOtherThanMany(k -> usedKeys.contains(k), () -> randomInt(1024));
-            tasks.put(key, new TestListener() {
-                @Override
-                public void processed(String source) {
-                    latch.countDown();
-                }
-
-                @Override
-                public void onFailure(String source, Exception e) {
-                    throw new AssertionError(e);
-                }
-            });
-            usedKeys.add(key);
-        }
-        assert usedKeys.size() == numOfTasks;
-
+    public void testSingleTaskSubmission() throws InterruptedException {
+        final CountDownLatch latch = new CountDownLatch(1);
+        final Integer task = randomInt(1024);
         TestExecutor<Integer> executor = taskList -> {
-            assertThat(taskList.size(), equalTo(tasks.size()));
-            assertThat(taskList.stream().collect(Collectors.toSet()), equalTo(tasks.keySet()));
+            assertThat(taskList.size(), equalTo(1));
+            assertThat(taskList.get(0), equalTo(task));
         };
-        submitTasks("test", tasks, ClusterStateTaskConfig.build(Priority.LANGUID), executor);
+        submitTask("test", task, ClusterStateTaskConfig.build(randomFrom(Priority.values())), executor, new TestListener() {
+            @Override
+            public void processed() {
+                latch.countDown();
+            }
 
+            @Override
+            public void onFailure(Exception e) {
+                throw new AssertionError(e);
+            }
+        });
         latch.await();
     }
 
@@ -295,12 +253,12 @@ public class TaskBatcherTests extends TaskExecutorTests {
             SimpleTask task = new SimpleTask(1);
             TestListener listener = new TestListener() {
                 @Override
-                public void processed(String source) {
+                public void processed() {
                     latch.countDown();
                 }
 
                 @Override
-                public void onFailure(String source, Exception e) {
+                public void onFailure(Exception e) {
                     throw new AssertionError(e);
                 }
             };

+ 11 - 11
server/src/test/java/org/elasticsearch/cluster/service/TaskExecutorTests.java

@@ -70,9 +70,9 @@ public class TaskExecutorTests extends ESTestCase {
     }
 
     protected interface TestListener {
-        void onFailure(String source, Exception e);
+        void onFailure(Exception e);
 
-        default void processed(String source) {
+        default void processed() {
             // do nothing by default
         }
     }
@@ -129,7 +129,7 @@ public class TaskExecutorTests extends ESTestCase {
         public void run() {
             logger.trace("will process {}", source);
             testTask.execute(Collections.singletonList(testTask));
-            testTask.processed(source);
+            testTask.processed();
         }
     }
 
@@ -140,7 +140,7 @@ public class TaskExecutorTests extends ESTestCase {
         if (timeout != null) {
             threadExecutor.execute(task, timeout, () -> threadPool.generic().execute(() -> {
                 logger.debug("task [{}] timed out after [{}]", task, timeout);
-                testTask.onFailure(source, new ProcessClusterEventTimeoutException(timeout, source));
+                testTask.onFailure(new ProcessClusterEventTimeoutException(timeout, source));
             }));
         } else {
             threadExecutor.execute(task);
@@ -163,7 +163,7 @@ public class TaskExecutorTests extends ESTestCase {
             }
 
             @Override
-            public void onFailure(String source, Exception e) {
+            public void onFailure(Exception e) {
                 throw new RuntimeException(e);
             }
         };
@@ -178,7 +178,7 @@ public class TaskExecutorTests extends ESTestCase {
             }
 
             @Override
-            public void onFailure(String source, Exception e) {
+            public void onFailure(Exception e) {
                 block2.countDown();
             }
 
@@ -207,7 +207,7 @@ public class TaskExecutorTests extends ESTestCase {
             }
 
             @Override
-            public void onFailure(String source, Exception e) {
+            public void onFailure(Exception e) {
                 throw new RuntimeException(e);
             }
         };
@@ -228,7 +228,7 @@ public class TaskExecutorTests extends ESTestCase {
             }
 
             @Override
-            public void onFailure(String source, Exception e) {
+            public void onFailure(Exception e) {
                 timedOut.countDown();
             }
         };
@@ -245,7 +245,7 @@ public class TaskExecutorTests extends ESTestCase {
             }
 
             @Override
-            public void onFailure(String source, Exception e) {
+            public void onFailure(Exception e) {
                 throw new RuntimeException(e);
             }
         };
@@ -312,7 +312,7 @@ public class TaskExecutorTests extends ESTestCase {
         }
 
         @Override
-        public void onFailure(String source, Exception e) {}
+        public void onFailure(Exception e) {}
 
         @Override
         public Priority priority() {
@@ -349,7 +349,7 @@ public class TaskExecutorTests extends ESTestCase {
         }
 
         @Override
-        public void onFailure(String source, Exception e) {
+        public void onFailure(Exception e) {
             latch.countDown();
         }
     }