Răsfoiți Sursa

ClusterStateTaskListener usage refactoring in MasterServiceTests (#82869)

Today node removal tasks executed by the master have a separate
ClusterStateTaskListener to feed back the result to the requester.
It'd be preferable to use the task itself as the listener.
Ievgen Degtiarenko 3 ani în urmă
părinte
comite
805cd39147

+ 108 - 123
server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java

@@ -11,6 +11,7 @@ package org.elasticsearch.cluster.service;
 import org.apache.logging.log4j.Level;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
@@ -20,6 +21,7 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStatePublicationEvent;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor.ClusterTasksResult;
 import org.elasticsearch.cluster.ClusterStateTaskListener;
 import org.elasticsearch.cluster.ClusterStateUpdateTask;
 import org.elasticsearch.cluster.LocalMasterServiceTask;
@@ -51,14 +53,10 @@ import org.junit.BeforeClass;
 
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.BrokenBarrierException;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.Semaphore;
@@ -67,13 +65,14 @@ import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptySet;
+import static java.util.stream.Collectors.toMap;
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.hasKey;
 
 public class MasterServiceTests extends ESTestCase {
 
@@ -263,9 +262,18 @@ public class MasterServiceTests extends ESTestCase {
         AtomicBoolean published = new AtomicBoolean();
 
         try (MasterService masterService = createMasterService(true)) {
+            ClusterStateTaskListener update = new ClusterStateTaskListener() {
+                @Override
+                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
+                    throw new RuntimeException("testing exception handling");
+                }
+
+                @Override
+                public void onFailure(Exception e) {}
+            };
             masterService.submitStateUpdateTask(
                 "testClusterStateTaskListenerThrowingExceptionIsOkay",
-                new Object(),
+                update,
                 ClusterStateTaskConfig.build(Priority.NORMAL),
                 new ClusterStateTaskExecutor<Object>() {
                     @Override
@@ -280,15 +288,7 @@ public class MasterServiceTests extends ESTestCase {
                         latch.countDown();
                     }
                 },
-                new ClusterStateTaskListener() {
-                    @Override
-                    public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
-                        throw new IllegalStateException();
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {}
-                }
+                update
             );
 
             latch.await();
@@ -464,9 +464,14 @@ public class MasterServiceTests extends ESTestCase {
     }
 
     public void testClusterStateBatchedUpdates() throws BrokenBarrierException, InterruptedException {
-        AtomicInteger counter = new AtomicInteger();
-        class Task {
-            private AtomicBoolean state = new AtomicBoolean();
+
+        AtomicInteger executedTasks = new AtomicInteger();
+        AtomicInteger submittedTasks = new AtomicInteger();
+        AtomicInteger processedStates = new AtomicInteger();
+        SetOnce<CountDownLatch> processedStatesLatch = new SetOnce<>();
+
+        class Task implements ClusterStateTaskListener {
+            private final AtomicBoolean executed = new AtomicBoolean();
             private final int id;
 
             Task(int id) {
@@ -474,13 +479,24 @@ public class MasterServiceTests extends ESTestCase {
             }
 
             public void execute() {
-                if (state.compareAndSet(false, true) == false) {
-                    throw new IllegalStateException();
+                if (executed.compareAndSet(false, true) == false) {
+                    throw new AssertionError("Task [" + id + "] should only be executed once");
                 } else {
-                    counter.incrementAndGet();
+                    executedTasks.incrementAndGet();
                 }
             }
 
+            @Override
+            public void onFailure(Exception e) {
+                throw new AssertionError(e);
+            }
+
+            @Override
+            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
+                processedStates.incrementAndGet();
+                processedStatesLatch.get().countDown();
+            }
+
             @Override
             public boolean equals(Object o) {
                 if (this == o) {
@@ -491,7 +507,6 @@ public class MasterServiceTests extends ESTestCase {
                 }
                 Task task = (Task) o;
                 return id == task.id;
-
             }
 
             @Override
@@ -505,38 +520,43 @@ public class MasterServiceTests extends ESTestCase {
             }
         }
 
-        int numberOfThreads = randomIntBetween(2, 8);
-        int taskSubmissionsPerThread = randomIntBetween(1, 64);
-        int numberOfExecutors = Math.max(1, numberOfThreads / 4);
-        final Semaphore semaphore = new Semaphore(numberOfExecutors);
+        final int numberOfThreads = randomIntBetween(2, 8);
+        final int taskSubmissionsPerThread = randomIntBetween(1, 64);
+        final int numberOfExecutors = Math.max(1, numberOfThreads / 4);
+        final Semaphore semaphore = new Semaphore(1);
 
         class TaskExecutor implements ClusterStateTaskExecutor<Task> {
-            private final List<Set<Task>> taskGroups;
-            private AtomicInteger counter = new AtomicInteger();
-            private AtomicInteger batches = new AtomicInteger();
-            private AtomicInteger published = new AtomicInteger();
 
-            TaskExecutor(List<Set<Task>> taskGroups) {
-                this.taskGroups = taskGroups;
-            }
+            private final AtomicInteger executed = new AtomicInteger();
+            private final AtomicInteger assigned = new AtomicInteger();
+            private final AtomicInteger batches = new AtomicInteger();
+            private final AtomicInteger published = new AtomicInteger();
+            private final List<Set<Task>> assignments = new ArrayList<>();
 
             @Override
             public ClusterTasksResult<Task> execute(ClusterState currentState, List<Task> tasks) throws Exception {
-                for (Set<Task> expectedSet : taskGroups) {
-                    long count = tasks.stream().filter(expectedSet::contains).count();
+                int totalCount = 0;
+                for (Set<Task> group : assignments) {
+                    long count = tasks.stream().filter(group::contains).count();
                     assertThat(
-                        "batched set should be executed together or not at all. Expected " + expectedSet + "s. Executing " + tasks,
+                        "batched set should be executed together or not at all. Expected " + group + "s. Executing " + tasks,
                         count,
-                        anyOf(equalTo(0L), equalTo((long) expectedSet.size()))
+                        anyOf(equalTo(0L), equalTo((long) group.size()))
                     );
+                    totalCount += count;
                 }
+                assertThat("All tasks should belong to this executor", totalCount, equalTo(tasks.size()));
                 tasks.forEach(Task::execute);
-                counter.addAndGet(tasks.size());
+                executed.addAndGet(tasks.size());
                 ClusterState maybeUpdatedClusterState = currentState;
                 if (randomBoolean()) {
                     maybeUpdatedClusterState = ClusterState.builder(currentState).build();
                     batches.incrementAndGet();
-                    semaphore.acquire();
+                    assertThat(
+                        "All cluster state modifications should be executed on a single thread",
+                        semaphore.tryAcquire(),
+                        equalTo(true)
+                    );
                 }
                 return ClusterTasksResult.<Task>builder().successes(tasks).build(maybeUpdatedClusterState);
             }
@@ -548,40 +568,27 @@ public class MasterServiceTests extends ESTestCase {
             }
         }
 
-        ConcurrentMap<String, AtomicInteger> processedStates = new ConcurrentHashMap<>();
-
-        List<Set<Task>> taskGroups = new ArrayList<>();
         List<TaskExecutor> executors = new ArrayList<>();
         for (int i = 0; i < numberOfExecutors; i++) {
-            executors.add(new TaskExecutor(taskGroups));
+            executors.add(new TaskExecutor());
         }
 
         // randomly assign tasks to executors
         List<Tuple<TaskExecutor, Set<Task>>> assignments = new ArrayList<>();
-        int taskId = 0;
+        AtomicInteger totalTasks = new AtomicInteger();
         for (int i = 0; i < numberOfThreads; i++) {
             for (int j = 0; j < taskSubmissionsPerThread; j++) {
-                TaskExecutor executor = randomFrom(executors);
-                Set<Task> tasks = new HashSet<>();
-                for (int t = randomInt(3); t >= 0; t--) {
-                    tasks.add(new Task(taskId++));
-                }
-                taskGroups.add(tasks);
+                var executor = randomFrom(executors);
+                var tasks = Set.copyOf(randomList(1, 3, () -> new Task(totalTasks.getAndIncrement())));
+
                 assignments.add(Tuple.tuple(executor, tasks));
+                executor.assigned.addAndGet(tasks.size());
+                executor.assignments.add(tasks);
             }
         }
-
-        Map<TaskExecutor, Integer> counts = new HashMap<>();
-        int totalTaskCount = 0;
-        for (Tuple<TaskExecutor, Set<Task>> assignment : assignments) {
-            final int taskCount = assignment.v2().size();
-            counts.merge(assignment.v1(), taskCount, (previous, count) -> previous + count);
-            totalTaskCount += taskCount;
-        }
-        final CountDownLatch updateLatch = new CountDownLatch(totalTaskCount);
+        processedStatesLatch.set(new CountDownLatch(totalTasks.get()));
 
         try (MasterService masterService = createMasterService(true)) {
-            final ConcurrentMap<String, AtomicInteger> submittedTasksPerThread = new ConcurrentHashMap<>();
             CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
             for (int i = 0; i < numberOfThreads; i++) {
                 final int index = i;
@@ -590,36 +597,23 @@ public class MasterServiceTests extends ESTestCase {
                     try {
                         barrier.await();
                         for (int j = 0; j < taskSubmissionsPerThread; j++) {
-                            Tuple<TaskExecutor, Set<Task>> assignment = assignments.get(index * taskSubmissionsPerThread + j);
-                            final Set<Task> tasks = assignment.v2();
-                            submittedTasksPerThread.computeIfAbsent(threadName, key -> new AtomicInteger()).addAndGet(tasks.size());
-                            final TaskExecutor executor = assignment.v1();
-                            final ClusterStateTaskListener listener = new ClusterStateTaskListener() {
-                                @Override
-                                public void onFailure(Exception e) {
-                                    throw new AssertionError(e);
-                                }
-
-                                @Override
-                                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
-                                    processedStates.computeIfAbsent(threadName, key -> new AtomicInteger()).incrementAndGet();
-                                    updateLatch.countDown();
-                                }
-                            };
+                            var assignment = assignments.get(index * taskSubmissionsPerThread + j);
+                            var tasks = assignment.v2();
+                            var executor = assignment.v1();
+                            submittedTasks.addAndGet(tasks.size());
                             if (tasks.size() == 1) {
+                                var update = tasks.iterator().next();
                                 masterService.submitStateUpdateTask(
                                     threadName,
-                                    tasks.stream().findFirst().get(),
+                                    update,
                                     ClusterStateTaskConfig.build(randomFrom(Priority.values())),
                                     executor,
-                                    listener
+                                    update
                                 );
                             } else {
-                                Map<Task, ClusterStateTaskListener> taskListeners = new HashMap<>();
-                                tasks.forEach(t -> taskListeners.put(t, listener));
                                 masterService.submitStateUpdateTasks(
                                     threadName,
-                                    taskListeners,
+                                    tasks.stream().collect(toMap(Function.<Task>identity(), Function.<ClusterStateTaskListener>identity())),
                                     ClusterStateTaskConfig.build(randomFrom(Priority.values())),
                                     executor
                                 );
@@ -639,29 +633,19 @@ public class MasterServiceTests extends ESTestCase {
             barrier.await();
 
             // wait until all the cluster state updates have been processed
-            updateLatch.await();
-            // and until all of the publication callbacks have completed
-            semaphore.acquire(numberOfExecutors);
+            processedStatesLatch.get().await();
+            // and until all the publication callbacks have completed
+            semaphore.acquire();
 
             // assert the number of executed tasks is correct
-            assertEquals(totalTaskCount, counter.get());
+            assertThat(submittedTasks.get(), equalTo(totalTasks.get()));
+            assertThat(executedTasks.get(), equalTo(totalTasks.get()));
+            assertThat(processedStates.get(), equalTo(totalTasks.get()));
 
             // assert each executor executed the correct number of tasks
             for (TaskExecutor executor : executors) {
-                if (counts.containsKey(executor)) {
-                    assertEquals((int) counts.get(executor), executor.counter.get());
-                    assertEquals(executor.batches.get(), executor.published.get());
-                }
-            }
-
-            // assert the correct number of clusterStateProcessed events were triggered
-            for (Map.Entry<String, AtomicInteger> entry : processedStates.entrySet()) {
-                assertThat(submittedTasksPerThread, hasKey(entry.getKey()));
-                assertEquals(
-                    "not all tasks submitted by " + entry.getKey() + " received a processed event",
-                    entry.getValue().get(),
-                    submittedTasksPerThread.get(entry.getKey()).get()
-                );
+                assertEquals(executor.assigned.get(), executor.executed.get());
+                assertEquals(executor.batches.get(), executor.published.get());
             }
         }
     }
@@ -672,36 +656,37 @@ public class MasterServiceTests extends ESTestCase {
         final AtomicReference<AssertionError> assertionRef = new AtomicReference<>();
 
         try (MasterService masterService = createMasterService(true)) {
+            ClusterStateTaskListener update = new ClusterStateTaskListener() {
+                @Override
+                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
+                    BaseFuture<Void> future = new BaseFuture<Void>() {
+                    };
+                    try {
+                        if (randomBoolean()) {
+                            future.get(1L, TimeUnit.SECONDS);
+                        } else {
+                            future.get();
+                        }
+                    } catch (Exception e) {
+                        throw new RuntimeException(e);
+                    } catch (AssertionError e) {
+                        assertionRef.set(e);
+                        latch.countDown();
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception e) {}
+            };
             masterService.submitStateUpdateTask(
                 "testBlockingCallInClusterStateTaskListenerFails",
-                new Object(),
+                update,
                 ClusterStateTaskConfig.build(Priority.NORMAL),
                 (currentState, tasks) -> {
                     ClusterState newClusterState = ClusterState.builder(currentState).build();
-                    return ClusterStateTaskExecutor.ClusterTasksResult.builder().successes(tasks).build(newClusterState);
+                    return ClusterTasksResult.<ClusterStateTaskListener>builder().successes(tasks).build(newClusterState);
                 },
-                new ClusterStateTaskListener() {
-                    @Override
-                    public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
-                        BaseFuture<Void> future = new BaseFuture<Void>() {
-                        };
-                        try {
-                            if (randomBoolean()) {
-                                future.get(1L, TimeUnit.SECONDS);
-                            } else {
-                                future.get();
-                            }
-                        } catch (Exception e) {
-                            throw new RuntimeException(e);
-                        } catch (AssertionError e) {
-                            assertionRef.set(e);
-                            latch.countDown();
-                        }
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {}
-                }
+                update
             );
 
             latch.await();