1
0
Эх сурвалжийг харах

Make TaskBatcher Less Lock-Heavy (#82227)

In many shards benchmarks we see a lot of contention when submitting tasks.
This is obvious when working with lots of large task batches and doing long iterations.

We don't need to lock in `runIfNotProcessed` past removing the task set for a key
and can be a little more efficient when it comes to creating the new task set in
`submitTasks` as well.
Also, we don't need to use a fully locking map as we often have operations for different
batching keys interleaved, so this moves to `ConcurrentHashMap` (CHM) as well.

This change is particularly relevant for stability because we often submit tasks from
network threads directly, where grinding through e.g. a bunch of shard state updates and
having to lock on the map over and over while e.g. a huge batch of index creates or similar
was being iterated over in `runIfNotProcessed` caused very visible latency.
Armin Braun 3 жил өмнө
parent
commit
dc27d1d37b

+ 44 - 38
server/src/main/java/org/elasticsearch/cluster/service/TaskBatcher.java

@@ -23,6 +23,8 @@ import java.util.IdentityHashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -36,7 +38,7 @@ public abstract class TaskBatcher {
     private final Logger logger;
     private final PrioritizedEsThreadPoolExecutor threadExecutor;
     // package visible for tests
-    final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new HashMap<>();
+    final Map<Object, Set<BatchedTask>> tasksPerBatchingKey = new ConcurrentHashMap<>();
 
     public TaskBatcher(Logger logger, PrioritizedEsThreadPoolExecutor threadExecutor) {
         this.logger = logger;
@@ -51,42 +53,47 @@ public abstract class TaskBatcher {
         assert tasks.stream().allMatch(t -> t.batchingKey == firstTask.batchingKey)
             : "tasks submitted in a batch should share the same batching key: " + tasks;
         // convert to an identity map to check for dups based on task identity
+
+        tasksPerBatchingKey.compute(firstTask.batchingKey, (k, existingTasks) -> {
+            assert assertNoDuplicateTasks(tasks, existingTasks);
+            if (existingTasks == null) {
+                return Collections.synchronizedSet(new LinkedHashSet<>(tasks));
+            }
+            existingTasks.addAll(tasks);
+            return existingTasks;
+        });
+
+        if (timeout != null) {
+            threadExecutor.execute(firstTask, timeout, () -> onTimeoutInternal(tasks, timeout));
+        } else {
+            threadExecutor.execute(firstTask);
+        }
+    }
+
+    private static boolean assertNoDuplicateTasks(List<? extends BatchedTask> tasks, Set<BatchedTask> existingTasks) {
         final Map<Object, BatchedTask> tasksIdentity = tasks.stream()
             .collect(
                 Collectors.toMap(
                     BatchedTask::getTask,
                     Function.identity(),
-                    (a, b) -> { throw new IllegalStateException("cannot add duplicate task: " + a); },
+                    (a, b) -> { throw new AssertionError("cannot add duplicate task: " + a); },
                     IdentityHashMap::new
                 )
             );
-
-        synchronized (tasksPerBatchingKey) {
-            LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.computeIfAbsent(
-                firstTask.batchingKey,
-                k -> new LinkedHashSet<>(tasks.size())
-            );
-            for (BatchedTask existing : existingTasks) {
-                // check that there won't be two tasks with the same identity for the same batching key
-                BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
-                if (duplicateTask != null) {
-                    throw new IllegalStateException(
-                        "task ["
-                            + duplicateTask.describeTasks(Collections.singletonList(existing))
-                            + "] with source ["
-                            + duplicateTask.source
-                            + "] is already queued"
-                    );
-                }
-            }
-            existingTasks.addAll(tasks);
+        if (existingTasks == null) {
+            return true;
         }
-
-        if (timeout != null) {
-            threadExecutor.execute(firstTask, timeout, () -> onTimeoutInternal(tasks, timeout));
-        } else {
-            threadExecutor.execute(firstTask);
+        for (BatchedTask existing : existingTasks) {
+            // check that there won't be two tasks with the same identity for the same batching key
+            BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
+            assert duplicateTask == null
+                : "task ["
+                    + duplicateTask.describeTasks(Collections.singletonList(existing))
+                    + "] with source ["
+                    + duplicateTask.source
+                    + "] is already queued";
         }
+        return true;
     }
 
     private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue timeout) {
@@ -102,15 +109,13 @@ public abstract class TaskBatcher {
             Object batchingKey = firstTask.batchingKey;
             assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey)
                 : "tasks submitted in a batch should share the same batching key: " + tasks;
-            synchronized (tasksPerBatchingKey) {
-                LinkedHashSet<BatchedTask> existingTasks = tasksPerBatchingKey.get(batchingKey);
-                if (existingTasks != null) {
-                    existingTasks.removeAll(toRemove);
-                    if (existingTasks.isEmpty()) {
-                        tasksPerBatchingKey.remove(batchingKey);
-                    }
+            tasksPerBatchingKey.computeIfPresent(batchingKey, (key, existingTasks) -> {
+                toRemove.forEach(existingTasks::remove);
+                if (existingTasks.isEmpty()) {
+                    return null;
                 }
-            }
+                return existingTasks;
+            });
             onTimeout(toRemove, timeout);
         }
     }
@@ -127,9 +132,10 @@ public abstract class TaskBatcher {
         if (updateTask.processed.get() == false) {
             final List<BatchedTask> toExecute = new ArrayList<>();
             final Map<String, List<BatchedTask>> processTasksBySource = new HashMap<>();
-            synchronized (tasksPerBatchingKey) {
-                LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
-                if (pending != null) {
+            final Set<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
+            if (pending != null) {
+                // pending is a java.util.Collections.SynchronizedSet so we can safely iterate holding its mutex
+                synchronized (pending) {
                     for (BatchedTask task : pending) {
                         if (task.processed.getAndSet(true) == false) {
                             logger.trace("will process {}", task);

+ 3 - 5
server/src/test/java/org/elasticsearch/cluster/service/TaskBatcherTests.java

@@ -118,9 +118,7 @@ public class TaskBatcherTests extends TaskExecutorTests {
     @Override
     public void testTimedOutTaskCleanedUp() throws Exception {
         super.testTimedOutTaskCleanedUp();
-        synchronized (taskBatcher.tasksPerBatchingKey) {
-            assertTrue("expected empty map but was " + taskBatcher.tasksPerBatchingKey, taskBatcher.tasksPerBatchingKey.isEmpty());
-        }
+        assertTrue("expected empty map but was " + taskBatcher.tasksPerBatchingKey, taskBatcher.tasksPerBatchingKey.isEmpty());
     }
 
     public void testOneExecutorDoesntStarveAnother() throws InterruptedException {
@@ -309,8 +307,8 @@ public class TaskBatcherTests extends TaskExecutorTests {
 
             submitTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
 
-            final IllegalStateException e = expectThrows(
-                IllegalStateException.class,
+            final AssertionError e = expectThrows(
+                AssertionError.class,
                 () -> submitTask("second time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener)
             );
             assertThat(e, hasToString(containsString("task [1] with source [second time] is already queued")));