浏览代码

Long balance computation should not delay new index primary assignment (#115511) (#116316)

A long desired balance computation could delay a newly created index shard from being assigned since first the computation has to finish for the assignments to be published and the shards getting assigned. With this change we add a new setting which allows setting a maximum time for a computation in case there are unassigned primary shards. Note that this is similar to how a new cluster state causes early publishing of the desired balance.

Closes ES-9616

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Pooya Salehi 11 月之前
父节点
当前提交
69df7fbfe1

+ 11 - 1
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/ContinuousComputation.java

@@ -49,6 +49,16 @@ public abstract class ContinuousComputation<T> {
         }
     }
 
+    /**
+     * enqueues {@code input} if {@code expectedLatestKnownInput} is the latest known input.
+     * Neither of the parameters can be null.
+     */
+    protected boolean compareAndEnqueue(T expectedLatestKnownInput, T input) {
+        assert expectedLatestKnownInput != null;
+        assert input != null;
+        return enqueuedInput.compareAndSet(Objects.requireNonNull(expectedLatestKnownInput), Objects.requireNonNull(input));
+    }
+
     /**
      * @return {@code false} iff there are no active/enqueued computations
      */
@@ -67,7 +77,7 @@ public abstract class ContinuousComputation<T> {
     /**
      * Process the given input.
      *
-     * @param input the value that was last received by {@link #onNewInput} before invocation.
+     * @param input the value that was last received by {@link #onNewInput} or {@link #compareAndEnqueue} before invocation.
      */
     protected abstract void processInput(T input);
 

+ 11 - 1
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalance.java

@@ -20,7 +20,17 @@ import java.util.Objects;
  *
  * @param assignments a set of the (persistent) node IDs to which each {@link ShardId} should be allocated
  */
-public record DesiredBalance(long lastConvergedIndex, Map<ShardId, ShardAssignment> assignments) {
+public record DesiredBalance(long lastConvergedIndex, Map<ShardId, ShardAssignment> assignments, ComputationFinishReason finishReason) {
+
+    enum ComputationFinishReason {
+        CONVERGED,
+        YIELD_TO_NEW_INPUT,
+        STOP_EARLY
+    }
+
+    public DesiredBalance(long lastConvergedIndex, Map<ShardId, ShardAssignment> assignments) {
+        this(lastConvergedIndex, assignments, ComputationFinishReason.CONVERGED);
+    }
 
     public static final DesiredBalance INITIAL = new DesiredBalance(-1, Map.of());
 

+ 57 - 9
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputer.java

@@ -38,6 +38,7 @@ import java.util.Queue;
 import java.util.Set;
 import java.util.TreeMap;
 import java.util.TreeSet;
+import java.util.function.LongSupplier;
 import java.util.function.Predicate;
 
 import static java.util.stream.Collectors.toUnmodifiableSet;
@@ -49,8 +50,8 @@ public class DesiredBalanceComputer {
 
     private static final Logger logger = LogManager.getLogger(DesiredBalanceComputer.class);
 
-    private final ThreadPool threadPool;
     private final ShardsAllocator delegateAllocator;
+    private final LongSupplier timeSupplierMillis;
 
     // stats
     protected final MeanMetric iterations = new MeanMetric();
@@ -63,12 +64,28 @@ public class DesiredBalanceComputer {
         Setting.Property.NodeScope
     );
 
+    public static final Setting<TimeValue> MAX_BALANCE_COMPUTATION_TIME_DURING_INDEX_CREATION_SETTING = Setting.timeSetting(
+        "cluster.routing.allocation.desired_balance.max_balance_computation_time_during_index_creation",
+        TimeValue.timeValueSeconds(1),
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope
+    );
+
     private TimeValue progressLogInterval;
+    private long maxBalanceComputationTimeDuringIndexCreationMillis;
 
     public DesiredBalanceComputer(ClusterSettings clusterSettings, ThreadPool threadPool, ShardsAllocator delegateAllocator) {
-        this.threadPool = threadPool;
+        this(clusterSettings, delegateAllocator, threadPool::relativeTimeInMillis);
+    }
+
+    DesiredBalanceComputer(ClusterSettings clusterSettings, ShardsAllocator delegateAllocator, LongSupplier timeSupplierMillis) {
         this.delegateAllocator = delegateAllocator;
+        this.timeSupplierMillis = timeSupplierMillis;
         clusterSettings.initializeAndWatch(PROGRESS_LOG_INTERVAL_SETTING, value -> this.progressLogInterval = value);
+        clusterSettings.initializeAndWatch(
+            MAX_BALANCE_COMPUTATION_TIME_DURING_INDEX_CREATION_SETTING,
+            value -> this.maxBalanceComputationTimeDuringIndexCreationMillis = value.millis()
+        );
     }
 
     public DesiredBalance compute(
@@ -77,7 +94,6 @@ public class DesiredBalanceComputer {
         Queue<List<MoveAllocationCommand>> pendingDesiredBalanceMoves,
         Predicate<DesiredBalanceInput> isFresh
     ) {
-
         if (logger.isTraceEnabled()) {
             logger.trace(
                 "Recomputing desired balance for [{}]: {}, {}, {}, {}",
@@ -97,9 +113,10 @@ public class DesiredBalanceComputer {
         final var changes = routingAllocation.changes();
         final var ignoredShards = getIgnoredShardsWithDiscardedAllocationStatus(desiredBalanceInput.ignoredShards());
         final var clusterInfoSimulator = new ClusterInfoSimulator(routingAllocation);
+        DesiredBalance.ComputationFinishReason finishReason = DesiredBalance.ComputationFinishReason.CONVERGED;
 
         if (routingNodes.size() == 0) {
-            return new DesiredBalance(desiredBalanceInput.index(), Map.of());
+            return new DesiredBalance(desiredBalanceInput.index(), Map.of(), finishReason);
         }
 
         // we assume that all ongoing recoveries will complete
@@ -263,11 +280,12 @@ public class DesiredBalanceComputer {
 
         final int iterationCountReportInterval = computeIterationCountReportInterval(routingAllocation);
         final long timeWarningInterval = progressLogInterval.millis();
-        final long computationStartedTime = threadPool.relativeTimeInMillis();
+        final long computationStartedTime = timeSupplierMillis.getAsLong();
         long nextReportTime = computationStartedTime + timeWarningInterval;
 
         int i = 0;
         boolean hasChanges = false;
+        boolean assignedNewlyCreatedPrimaryShards = false;
         while (true) {
             if (hasChanges) {
                 // Not the first iteration, so every remaining unassigned shard has been ignored, perhaps due to throttling. We must bring
@@ -293,6 +311,15 @@ public class DesiredBalanceComputer {
                 for (final var shardRouting : routingNode) {
                     if (shardRouting.initializing()) {
                         hasChanges = true;
+                        if (shardRouting.primary()
+                            && shardRouting.unassignedInfo() != null
+                            && shardRouting.unassignedInfo().reason() == UnassignedInfo.Reason.INDEX_CREATED) {
+                            // TODO: we could include more cases that would cause early publishing of desired balance in case of a long
+                            // computation. e.g.:
+                            // - unassigned search replicas in case the shard has no assigned shard replicas
+                            // - other reasons for an unassigned shard such as NEW_INDEX_RESTORED
+                            assignedNewlyCreatedPrimaryShards = true;
+                        }
                         clusterInfoSimulator.simulateShardStarted(shardRouting);
                         routingNodes.startShard(shardRouting, changes, 0L);
                     }
@@ -301,14 +328,14 @@ public class DesiredBalanceComputer {
 
             i++;
             final int iterations = i;
-            final long currentTime = threadPool.relativeTimeInMillis();
+            final long currentTime = timeSupplierMillis.getAsLong();
             final boolean reportByTime = nextReportTime <= currentTime;
             final boolean reportByIterationCount = i % iterationCountReportInterval == 0;
             if (reportByTime || reportByIterationCount) {
                 nextReportTime = currentTime + timeWarningInterval;
             }
 
-            if (hasChanges == false) {
+            if (hasComputationConverged(hasChanges, i)) {
                 logger.debug(
                     "Desired balance computation for [{}] converged after [{}] and [{}] iterations",
                     desiredBalanceInput.index(),
@@ -324,9 +351,25 @@ public class DesiredBalanceComputer {
                     "Desired balance computation for [{}] interrupted after [{}] and [{}] iterations as newer cluster state received. "
                         + "Publishing intermediate desired balance and restarting computation",
                     desiredBalanceInput.index(),
+                    TimeValue.timeValueMillis(currentTime - computationStartedTime).toString(),
+                    i
+                );
+                finishReason = DesiredBalance.ComputationFinishReason.YIELD_TO_NEW_INPUT;
+                break;
+            }
+
+            if (assignedNewlyCreatedPrimaryShards
+                && currentTime - computationStartedTime >= maxBalanceComputationTimeDuringIndexCreationMillis) {
+                logger.info(
+                    "Desired balance computation for [{}] interrupted after [{}] and [{}] iterations "
+                        + "in order to not delay assignment of newly created index shards for more than [{}]. "
+                        + "Publishing intermediate desired balance and restarting computation",
+                    desiredBalanceInput.index(),
+                    TimeValue.timeValueMillis(currentTime - computationStartedTime).toString(),
                     i,
-                    TimeValue.timeValueMillis(currentTime - computationStartedTime).toString()
+                    TimeValue.timeValueMillis(maxBalanceComputationTimeDuringIndexCreationMillis).toString()
                 );
+                finishReason = DesiredBalance.ComputationFinishReason.STOP_EARLY;
                 break;
             }
 
@@ -368,7 +411,12 @@ public class DesiredBalanceComputer {
         }
 
         long lastConvergedIndex = hasChanges ? previousDesiredBalance.lastConvergedIndex() : desiredBalanceInput.index();
-        return new DesiredBalance(lastConvergedIndex, assignments);
+        return new DesiredBalance(lastConvergedIndex, assignments, finishReason);
+    }
+
+    // visible for testing
+    boolean hasComputationConverged(boolean hasRoutingChanges, int currentIteration) {
+        return hasRoutingChanges == false;
     }
 
     private static Map<ShardId, ShardAssignment> collectShardAssignments(RoutingNodes routingNodes) {

+ 10 - 1
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocator.java

@@ -134,7 +134,16 @@ public class DesiredBalanceShardsAllocator implements ShardsAllocator {
                     )
                 );
                 computationsExecuted.inc();
-                if (isFresh(desiredBalanceInput)) {
+
+                if (currentDesiredBalance.finishReason() == DesiredBalance.ComputationFinishReason.STOP_EARLY) {
+                    logger.debug(
+                        "Desired balance computation for [{}] terminated early with partial result, scheduling reconciliation",
+                        index
+                    );
+                    submitReconcileTask(currentDesiredBalance);
+                    var newInput = DesiredBalanceInput.create(indexGenerator.incrementAndGet(), desiredBalanceInput.routingAllocation());
+                    desiredBalanceComputation.compareAndEnqueue(desiredBalanceInput, newInput);
+                } else if (isFresh(desiredBalanceInput)) {
                     logger.debug("Desired balance computation for [{}] is completed, scheduling reconciliation", index);
                     computationsConverged.inc();
                     submitReconcileTask(currentDesiredBalance);

+ 1 - 0
server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java

@@ -219,6 +219,7 @@ public final class ClusterSettings extends AbstractScopedSettings {
         DataStreamAutoShardingService.CLUSTER_AUTO_SHARDING_MAX_WRITE_THREADS,
         DataStreamAutoShardingService.CLUSTER_AUTO_SHARDING_MIN_WRITE_THREADS,
         DesiredBalanceComputer.PROGRESS_LOG_INTERVAL_SETTING,
+        DesiredBalanceComputer.MAX_BALANCE_COMPUTATION_TIME_DURING_INDEX_CREATION_SETTING,
         DesiredBalanceReconciler.UNDESIRED_ALLOCATIONS_LOG_INTERVAL_SETTING,
         DesiredBalanceReconciler.UNDESIRED_ALLOCATIONS_LOG_THRESHOLD_SETTING,
         BreakerSettings.CIRCUIT_BREAKER_LIMIT_SETTING,

+ 63 - 0
server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ContinuousComputationTests.java

@@ -21,6 +21,7 @@ import java.util.Arrays;
 import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -73,6 +74,68 @@ public class ContinuousComputationTests extends ESTestCase {
         assertTrue(Arrays.toString(valuePerThread) + " vs " + result.get(), Arrays.stream(valuePerThread).anyMatch(i -> i == result.get()));
     }
 
+    public void testCompareAndEnqueue() throws Exception {
+        final var initialInput = new Object();
+        final var compareAndEnqueueCount = between(1, 10);
+        final var remaining = new AtomicInteger(compareAndEnqueueCount);
+        final var computationsExecuted = new AtomicInteger();
+        final var result = new AtomicReference<>();
+        final var computation = new ContinuousComputation<>(threadPool.generic()) {
+            @Override
+            protected void processInput(Object input) {
+                result.set(input);
+                if (remaining.decrementAndGet() >= 0) {
+                    compareAndEnqueue(input, new Object());
+                }
+                computationsExecuted.incrementAndGet();
+            }
+        };
+        computation.onNewInput(initialInput);
+        assertBusy(() -> assertFalse(computation.isActive()));
+        assertNotEquals(result.get(), initialInput);
+        assertEquals(computationsExecuted.get(), 1 + compareAndEnqueueCount);
+    }
+
+    public void testCompareAndEnqueueSkipped() throws Exception {
+        final var barrier = new CyclicBarrier(2);
+        final var computationsExecuted = new AtomicInteger();
+        final var initialInput = new Object();
+        final var conditionalInput = new Object();
+        final var newInput = new Object();
+        final var submitConditional = new AtomicBoolean(true);
+        final var result = new AtomicReference<>();
+
+        final var computation = new ContinuousComputation<>(threadPool.generic()) {
+            @Override
+            protected void processInput(Object input) {
+                assertNotEquals(input, conditionalInput);
+                safeAwait(barrier);  // start
+                safeAwait(barrier);  // continue
+                if (submitConditional.getAndSet(false)) {
+                    compareAndEnqueue(input, conditionalInput);
+                }
+                result.set(input);
+                safeAwait(barrier);  // finished
+                computationsExecuted.incrementAndGet();
+            }
+        };
+        computation.onNewInput(initialInput);
+
+        safeAwait(barrier);  // start
+        computation.onNewInput(newInput);
+        safeAwait(barrier);  // continue
+        safeAwait(barrier);  // finished
+        assertEquals(result.get(), initialInput);
+
+        safeAwait(barrier);  // start
+        safeAwait(barrier);  // continue
+        safeAwait(barrier);  // finished
+
+        assertBusy(() -> assertFalse(computation.isActive()));
+        assertEquals(result.get(), newInput);
+        assertEquals(computationsExecuted.get(), 2);
+    }
+
     public void testSkipsObsoleteValues() throws Exception {
         final var barrier = new CyclicBarrier(2);
         final Runnable await = () -> safeAwait(barrier);

+ 6 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java

@@ -1210,7 +1210,12 @@ public class DesiredBalanceComputerTests extends ESAllocationTestCase {
         var currentTime = new AtomicLong(0L);
         when(mockThreadPool.relativeTimeInMillis()).thenAnswer(invocation -> currentTime.addAndGet(eachIterationDuration));
 
-        var desiredBalanceComputer = new DesiredBalanceComputer(createBuiltInClusterSettings(), mockThreadPool, new ShardsAllocator() {
+        // Some runs of this test try to simulate a long desired balance computation. Setting a high value on the following setting
+        // prevents interrupting a long computation.
+        var clusterSettings = createBuiltInClusterSettings(
+            Settings.builder().put(DesiredBalanceComputer.MAX_BALANCE_COMPUTATION_TIME_DURING_INDEX_CREATION_SETTING.getKey(), "2m").build()
+        );
+        var desiredBalanceComputer = new DesiredBalanceComputer(clusterSettings, mockThreadPool, new ShardsAllocator() {
             @Override
             public void allocate(RoutingAllocation allocation) {
                 final var unassignedIterator = allocation.routingNodes().unassigned().iterator();

+ 187 - 14
server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceShardsAllocatorTests.java

@@ -9,6 +9,7 @@
 
 package org.elasticsearch.cluster.routing.allocation.allocator;
 
+import org.apache.logging.log4j.Level;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionTestUtils;
@@ -52,6 +53,7 @@ import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.elasticsearch.telemetry.TelemetryProvider;
 import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.MockLog;
 import org.elasticsearch.threadpool.TestThreadPool;
 
 import java.util.List;
@@ -59,11 +61,12 @@ import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 
@@ -85,14 +88,19 @@ public class DesiredBalanceShardsAllocatorTests extends ESAllocationTestCase {
     public void testGatewayAllocatorPreemptsAllocation() {
         final var nodeId = randomFrom(LOCAL_NODE_ID, OTHER_NODE_ID);
         testAllocate(
-            (allocation, unassignedAllocationHandler) -> unassignedAllocationHandler.initialize(nodeId, null, 0L, allocation.changes()),
+            (shardRouting, allocation, unassignedAllocationHandler) -> unassignedAllocationHandler.initialize(
+                nodeId,
+                null,
+                0L,
+                allocation.changes()
+            ),
             routingTable -> assertEquals(nodeId, routingTable.index("test-index").shard(0).primaryShard().currentNodeId())
         );
     }
 
     public void testGatewayAllocatorStillFetching() {
         testAllocate(
-            (allocation, unassignedAllocationHandler) -> unassignedAllocationHandler.removeAndIgnore(
+            (shardRouting, allocation, unassignedAllocationHandler) -> unassignedAllocationHandler.removeAndIgnore(
                 UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA,
                 allocation.changes()
             ),
@@ -108,17 +116,14 @@ public class DesiredBalanceShardsAllocatorTests extends ESAllocationTestCase {
     }
 
     public void testGatewayAllocatorDoesNothing() {
-        testAllocate((allocation, unassignedAllocationHandler) -> {}, routingTable -> {
+        testAllocate((shardRouting, allocation, unassignedAllocationHandler) -> {}, routingTable -> {
             var shardRouting = routingTable.shardRoutingTable("test-index", 0).primaryShard();
             assertTrue(shardRouting.assignedToNode());// assigned by a followup reconciliation
             assertThat(shardRouting.unassignedInfo().lastAllocationStatus(), equalTo(UnassignedInfo.AllocationStatus.NO_ATTEMPT));
         });
     }
 
-    public void testAllocate(
-        BiConsumer<RoutingAllocation, ExistingShardsAllocator.UnassignedAllocationHandler> allocateUnassigned,
-        Consumer<RoutingTable> verifier
-    ) {
+    public void testAllocate(AllocateUnassignedHandler allocateUnassigned, Consumer<RoutingTable> verifier) {
         var deterministicTaskQueue = new DeterministicTaskQueue();
         var threadPool = deterministicTaskQueue.getThreadPool();
 
@@ -295,7 +300,7 @@ public class DesiredBalanceShardsAllocatorTests extends ESAllocationTestCase {
         var allocationService = new AllocationService(
             new AllocationDeciders(List.of()),
             createGatewayAllocator(
-                (allocation, unassignedAllocationHandler) -> unassignedAllocationHandler.removeAndIgnore(
+                (shardRouting, allocation, unassignedAllocationHandler) -> unassignedAllocationHandler.removeAndIgnore(
                     UnassignedInfo.AllocationStatus.NO_ATTEMPT,
                     allocation.changes()
                 )
@@ -336,6 +341,157 @@ public class DesiredBalanceShardsAllocatorTests extends ESAllocationTestCase {
         }
     }
 
+    public void testIndexCreationInterruptsLongDesiredBalanceComputation() throws Exception {
+        var discoveryNode = newNode("node-0");
+        var initialState = ClusterState.builder(ClusterName.DEFAULT)
+            .nodes(DiscoveryNodes.builder().add(discoveryNode).localNodeId(discoveryNode.getId()).masterNodeId(discoveryNode.getId()))
+            .build();
+        final var ignoredIndexName = "index-ignored";
+
+        var threadPool = new TestThreadPool(getTestName());
+        var time = new AtomicLong(threadPool.relativeTimeInMillis());
+        var clusterService = ClusterServiceUtils.createClusterService(initialState, threadPool);
+        var allocationServiceRef = new SetOnce<AllocationService>();
+        var reconcileAction = new DesiredBalanceReconcilerAction() {
+            @Override
+            public ClusterState apply(ClusterState clusterState, RerouteStrategy routingAllocationAction) {
+                return allocationServiceRef.get().executeWithRoutingAllocation(clusterState, "reconcile", routingAllocationAction);
+            }
+        };
+
+        var gatewayAllocator = createGatewayAllocator((shardRouting, allocation, unassignedAllocationHandler) -> {
+            if (shardRouting.getIndexName().equals(ignoredIndexName)) {
+                unassignedAllocationHandler.removeAndIgnore(UnassignedInfo.AllocationStatus.NO_ATTEMPT, allocation.changes());
+            }
+        });
+        var shardsAllocator = new ShardsAllocator() {
+            @Override
+            public void allocate(RoutingAllocation allocation) {
+                // simulate long computation
+                time.addAndGet(1_000);
+                var dataNodeId = allocation.nodes().getDataNodes().values().iterator().next().getId();
+                var unassignedIterator = allocation.routingNodes().unassigned().iterator();
+                while (unassignedIterator.hasNext()) {
+                    unassignedIterator.next();
+                    unassignedIterator.initialize(dataNodeId, null, 0L, allocation.changes());
+                }
+            }
+
+            @Override
+            public ShardAllocationDecision decideShardAllocation(ShardRouting shard, RoutingAllocation allocation) {
+                throw new AssertionError("only used for allocation explain");
+            }
+        };
+
+        // Make sure the computation takes at least a few iterations, where each iteration takes 1s (see {@code #shardsAllocator.allocate}).
+        // By setting the following setting we ensure the desired balance computation will be interrupted early to not delay assigning
+        // newly created primary shards. This ensures that we hit a desired balance computation (3s) which is longer than the configured
+        // setting below.
+        var clusterSettings = createBuiltInClusterSettings(
+            Settings.builder().put(DesiredBalanceComputer.MAX_BALANCE_COMPUTATION_TIME_DURING_INDEX_CREATION_SETTING.getKey(), "2s").build()
+        );
+        final int minIterations = between(3, 10);
+        var desiredBalanceShardsAllocator = new DesiredBalanceShardsAllocator(
+            shardsAllocator,
+            threadPool,
+            clusterService,
+            new DesiredBalanceComputer(clusterSettings, shardsAllocator, time::get) {
+                @Override
+                public DesiredBalance compute(
+                    DesiredBalance previousDesiredBalance,
+                    DesiredBalanceInput desiredBalanceInput,
+                    Queue<List<MoveAllocationCommand>> pendingDesiredBalanceMoves,
+                    Predicate<DesiredBalanceInput> isFresh
+                ) {
+                    return super.compute(previousDesiredBalance, desiredBalanceInput, pendingDesiredBalanceMoves, isFresh);
+                }
+
+                @Override
+                boolean hasComputationConverged(boolean hasRoutingChanges, int currentIteration) {
+                    return super.hasComputationConverged(hasRoutingChanges, currentIteration) && currentIteration >= minIterations;
+                }
+            },
+            reconcileAction,
+            TelemetryProvider.NOOP
+        );
+        var allocationService = createAllocationService(desiredBalanceShardsAllocator, gatewayAllocator);
+        allocationServiceRef.set(allocationService);
+
+        var rerouteFinished = new CyclicBarrier(2);
+        // A mock cluster state update task for creating an index
+        class CreateIndexTask extends ClusterStateUpdateTask {
+            private final String indexName;
+
+            private CreateIndexTask(String indexName) {
+                this.indexName = indexName;
+            }
+
+            @Override
+            public ClusterState execute(ClusterState currentState) throws Exception {
+                var indexMetadata = createIndex(indexName);
+                var newState = ClusterState.builder(currentState)
+                    .metadata(Metadata.builder(currentState.metadata()).put(indexMetadata, true))
+                    .routingTable(
+                        RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY, currentState.routingTable())
+                            .addAsNew(indexMetadata)
+                    )
+                    .build();
+                return allocationService.reroute(
+                    newState,
+                    "test",
+                    ActionTestUtils.assertNoFailureListener(response -> safeAwait(rerouteFinished))
+                );
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                throw new AssertionError(e);
+            }
+        }
+
+        final var computationInterruptedMessage =
+            "Desired balance computation for * interrupted * in order to not delay assignment of newly created index shards *";
+        try {
+            // Create a new index which is not ignored and therefore must be considered when a desired balance
+            // computation takes longer than 2s.
+            assertThat(desiredBalanceShardsAllocator.getStats().computationExecuted(), equalTo(0L));
+            MockLog.assertThatLogger(() -> {
+                clusterService.submitUnbatchedStateUpdateTask("test", new CreateIndexTask("index-1"));
+                safeAwait(rerouteFinished);
+                assertThat(clusterService.state().getRoutingTable().index("index-1").primaryShardsUnassigned(), equalTo(0));
+            },
+                DesiredBalanceComputer.class,
+                new MockLog.SeenEventExpectation(
+                    "Should log interrupted computation",
+                    DesiredBalanceComputer.class.getCanonicalName(),
+                    Level.INFO,
+                    computationInterruptedMessage
+                )
+            );
+            assertBusy(() -> assertFalse(desiredBalanceShardsAllocator.getStats().computationActive()));
+            assertThat(desiredBalanceShardsAllocator.getStats().computationExecuted(), equalTo(2L));
+            // The computation should not get interrupted when the newly created index shard stays unassigned.
+            MockLog.assertThatLogger(() -> {
+                clusterService.submitUnbatchedStateUpdateTask("test", new CreateIndexTask(ignoredIndexName));
+                safeAwait(rerouteFinished);
+                assertThat(clusterService.state().getRoutingTable().index(ignoredIndexName).primaryShardsUnassigned(), equalTo(1));
+            },
+                DesiredBalanceComputer.class,
+                new MockLog.UnseenEventExpectation(
+                    "Should log interrupted computation",
+                    DesiredBalanceComputer.class.getCanonicalName(),
+                    Level.INFO,
+                    computationInterruptedMessage
+                )
+            );
+            assertBusy(() -> assertFalse(desiredBalanceShardsAllocator.getStats().computationActive()));
+            assertThat(desiredBalanceShardsAllocator.getStats().computationExecuted(), equalTo(3L));
+        } finally {
+            clusterService.close();
+            terminate(threadPool);
+        }
+    }
+
     public void testCallListenersOnlyAfterProducingFreshInput() throws InterruptedException {
 
         var reconciliations = new AtomicInteger(0);
@@ -772,13 +928,30 @@ public class DesiredBalanceShardsAllocatorTests extends ESAllocationTestCase {
         return createGatewayAllocator(DesiredBalanceShardsAllocatorTests::initialize);
     }
 
-    private static void initialize(RoutingAllocation allocation, ExistingShardsAllocator.UnassignedAllocationHandler handler) {
+    private static void initialize(
+        ShardRouting shardRouting,
+        RoutingAllocation allocation,
+        ExistingShardsAllocator.UnassignedAllocationHandler handler
+    ) {
         handler.initialize(allocation.nodes().getLocalNodeId(), null, 0L, allocation.changes());
     }
 
-    private static GatewayAllocator createGatewayAllocator(
-        BiConsumer<RoutingAllocation, ExistingShardsAllocator.UnassignedAllocationHandler> allocateUnassigned
-    ) {
+    /**
+     * A helper interface to simplify creating a GatewayAllocator in the tests by only requiring
+     * an implementation for {@link org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator#allocateUnassigned}.
+     */
+    interface AllocateUnassignedHandler {
+        void handle(
+            ShardRouting shardRouting,
+            RoutingAllocation allocation,
+            ExistingShardsAllocator.UnassignedAllocationHandler unassignedAllocationHandler
+        );
+    }
+
+    /**
+     * Creates an implementation of GatewayAllocator that delegates its logic for allocating unassigned shards to the provided handler.
+     */
+    private static GatewayAllocator createGatewayAllocator(AllocateUnassignedHandler allocateUnassigned) {
         return new GatewayAllocator() {
 
             @Override
@@ -790,7 +963,7 @@ public class DesiredBalanceShardsAllocatorTests extends ESAllocationTestCase {
                 RoutingAllocation allocation,
                 UnassignedAllocationHandler unassignedAllocationHandler
             ) {
-                allocateUnassigned.accept(allocation, unassignedAllocationHandler);
+                allocateUnassigned.handle(shardRouting, allocation, unassignedAllocationHandler);
             }
 
             @Override