
Further unblocking of recovery (#95270)

Follow-on from #95115 using the new async `runUnderPrimaryPermit` to
remove almost all the remaining blocking in `RecoverySourceHandler`.
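For orientation, a minimal sketch of the call-site pattern this commit applies throughout (not part of the commit itself). It assumes the fields available inside `RecoverySourceHandler` (`shard`, `cancellableThreads`, `request`, `onFailure`) and a hypothetical continuation `continueRecovery`; the three-argument blocking variant is the one this commit deletes, and the listener-taking variant is the one it uses instead:

    // Old, blocking: the calling thread waits until the permit is granted and
    // the action has run, so the continuation is simply the next statement.
    runUnderPrimaryPermit(
        () -> shard.initiateTracking(request.targetAllocationId()),
        shard,
        cancellableThreads
    );
    continueRecovery(); // hypothetical next step

    // New, async: the continuation moves into an ActionListener callback and
    // no thread blocks while the permit is pending.
    runUnderPrimaryPermit(
        () -> shard.initiateTracking(request.targetAllocationId()),
        shard,
        cancellableThreads,
        ActionListener.wrap(ignored -> continueRecovery(), onFailure)
    );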
David Turner 2 years ago
parent
commit
da1404ea8c

+ 212 - 200
server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java

@@ -17,13 +17,11 @@ import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.RateLimiter;
 import org.apache.lucene.util.ArrayUtil;
-import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.StepListener;
 import org.elasticsearch.action.support.ListenableActionFuture;
-import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.action.support.replication.ReplicationResponse;
@@ -37,9 +35,9 @@ import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.CancellableThreads;
 import org.elasticsearch.common.util.concurrent.CountDown;
-import org.elasticsearch.common.util.concurrent.FutureUtils;
 import org.elasticsearch.common.util.concurrent.ListenableFuture;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.CheckedRunnable;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
@@ -183,9 +181,7 @@ public class RecoverySourceHandler {
                 IOUtils.closeWhileHandlingException(releaseResources, () -> future.onFailure(e));
             };
 
-            final SetOnce<RetentionLease> retentionLeaseRef = new SetOnce<>();
-
-            runUnderPrimaryPermit(() -> {
+            runUnderPrimaryPermit(retentionLeaseListener -> {
                 final IndexShardRoutingTable routingTable = shard.getReplicationGroup().getRoutingTable();
                 ShardRouting targetShardRouting = routingTable.getByAllocationId(request.targetAllocationId());
                 if (targetShardRouting == null) {
@@ -197,174 +193,185 @@ public class RecoverySourceHandler {
                     throw new DelayRecoveryException("source node does not have the shard listed in its state as allocated on the node");
                 }
                 assert targetShardRouting.initializing() : "expected recovery target to be initializing but was " + targetShardRouting;
-                retentionLeaseRef.set(
+                retentionLeaseListener.onResponse(
                     shard.getRetentionLeases().get(ReplicationTracker.getPeerRecoveryRetentionLeaseId(targetShardRouting))
                 );
-            }, shard, cancellableThreads);
-            final Closeable retentionLock = shard.acquireHistoryRetentionLock();
-            resources.add(retentionLock);
-            final long startingSeqNo;
-            final boolean isSequenceNumberBasedRecovery = request.startingSeqNo() != SequenceNumbers.UNASSIGNED_SEQ_NO
-                && isTargetSameHistory()
-                && shard.hasCompleteHistoryOperations("peer-recovery", request.startingSeqNo())
-                && ((retentionLeaseRef.get() == null && shard.useRetentionLeasesInPeerRecovery() == false)
-                    || (retentionLeaseRef.get() != null && retentionLeaseRef.get().retainingSequenceNumber() <= request.startingSeqNo()));
-            // NB check hasCompleteHistoryOperations when computing isSequenceNumberBasedRecovery, even if there is a retention lease,
-            // because when doing a rolling upgrade from earlier than 7.4 we may create some leases that are initially unsatisfied. It's
-            // possible there are other cases where we cannot satisfy all leases, because that's not a property we currently expect to hold.
-            // Also it's pretty cheap when soft deletes are enabled, and it'd be a disaster if we tried a sequence-number-based recovery
-            // without having a complete history.
-
-            if (isSequenceNumberBasedRecovery && retentionLeaseRef.get() != null) {
-                // all the history we need is retained by an existing retention lease, so we do not need a separate retention lock
-                retentionLock.close();
-                logger.trace("history is retained by {}", retentionLeaseRef.get());
-            } else {
-                // all the history we need is retained by the retention lock, obtained before calling shard.hasCompleteHistoryOperations()
-                // and before acquiring the safe commit we'll be using, so we can be certain that all operations after the safe commit's
-                // local checkpoint will be retained for the duration of this recovery.
-                logger.trace("history is retained by retention lock");
-            }
-
-            final StepListener<SendFileResult> sendFileStep = new StepListener<>();
-            final StepListener<TimeValue> prepareEngineStep = new StepListener<>();
-            final StepListener<SendSnapshotResult> sendSnapshotStep = new StepListener<>();
-            final StepListener<Void> finalizeStep = new StepListener<>();
-
-            if (isSequenceNumberBasedRecovery) {
-                logger.trace("performing sequence numbers based recovery. starting at [{}]", request.startingSeqNo());
-                startingSeqNo = request.startingSeqNo();
-                if (retentionLeaseRef.get() == null) {
-                    createRetentionLease(startingSeqNo, sendFileStep.map(ignored -> SendFileResult.EMPTY));
-                } else {
-                    sendFileStep.onResponse(SendFileResult.EMPTY);
-                }
-            } else {
-                final Engine.IndexCommitRef safeCommitRef;
-                try {
-                    safeCommitRef = acquireSafeCommit(shard);
-                    resources.add(safeCommitRef);
-                } catch (final Exception e) {
-                    throw new RecoveryEngineException(shard.shardId(), 1, "snapshot failed", e);
-                }
-
-                // Try and copy enough operations to the recovering peer so that if it is promoted to primary then it has a chance of being
-                // able to recover other replicas using operations-based recoveries. If we are not using retention leases then we
-                // conservatively copy all available operations. If we are using retention leases then "enough operations" is just the
-                // operations from the local checkpoint of the safe commit onwards, because when using soft deletes the safe commit retains
-                // at least as much history as anything else. The safe commit will often contain all the history retained by the current set
-                // of retention leases, but this is not guaranteed: an earlier peer recovery from a different primary might have created a
-                // retention lease for some history that this primary already discarded, since we discard history when the global checkpoint
-                // advances and not when creating a new safe commit. In any case this is a best-effort thing since future recoveries can
-                // always fall back to file-based ones, and only really presents a problem if this primary fails before things have settled
-                // down.
-                startingSeqNo = Long.parseLong(safeCommitRef.getIndexCommit().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)) + 1L;
-                logger.trace("performing file-based recovery followed by history replay starting at [{}]", startingSeqNo);
+            },
+                shard,
+                cancellableThreads,
+                ActionListener.wrap((RetentionLease retentionLease) -> recoverToTarget(retentionLease, onFailure), onFailure)
+            );
+        } catch (Exception e) {
+            IOUtils.closeWhileHandlingException(releaseResources, () -> future.onFailure(e));
+        }
+    }
 
-                try {
-                    final int estimateNumOps = estimateNumberOfHistoryOperations(startingSeqNo);
-                    final Releasable releaseStore = acquireStore(shard.store());
-                    resources.add(releaseStore);
-                    sendFileStep.whenComplete(r -> IOUtils.close(safeCommitRef, releaseStore), e -> {
-                        try {
-                            IOUtils.close(safeCommitRef, releaseStore);
-                        } catch (Exception ex) {
-                            logger.warn("releasing snapshot caused exception", ex);
-                        }
-                    });
+    private void recoverToTarget(RetentionLease retentionLease, Consumer<Exception> onFailure) throws IOException {
+        final Closeable retentionLock = shard.acquireHistoryRetentionLock();
+        resources.add(retentionLock);
+        final long startingSeqNo;
+        final boolean isSequenceNumberBasedRecovery = request.startingSeqNo() != SequenceNumbers.UNASSIGNED_SEQ_NO
+            && isTargetSameHistory()
+            && shard.hasCompleteHistoryOperations("peer-recovery", request.startingSeqNo())
+            && ((retentionLease == null && shard.useRetentionLeasesInPeerRecovery() == false)
+                || (retentionLease != null && retentionLease.retainingSequenceNumber() <= request.startingSeqNo()));
+        // NB check hasCompleteHistoryOperations when computing isSequenceNumberBasedRecovery, even if there is a retention lease,
+        // because when doing a rolling upgrade from earlier than 7.4 we may create some leases that are initially unsatisfied. It's
+        // possible there are other cases where we cannot satisfy all leases, because that's not a property we currently expect to hold.
+        // Also it's pretty cheap when soft deletes are enabled, and it'd be a disaster if we tried a sequence-number-based recovery
+        // without having a complete history.
+
+        if (isSequenceNumberBasedRecovery && retentionLease != null) {
+            // all the history we need is retained by an existing retention lease, so we do not need a separate retention lock
+            retentionLock.close();
+            logger.trace("history is retained by {}", retentionLease);
+        } else {
+            // all the history we need is retained by the retention lock, obtained before calling shard.hasCompleteHistoryOperations()
+            // and before acquiring the safe commit we'll be using, so we can be certain that all operations after the safe commit's
+            // local checkpoint will be retained for the duration of this recovery.
+            logger.trace("history is retained by retention lock");
+        }
 
-                    // If the target previously had a copy of this shard then a file-based recovery might move its global checkpoint
-                    // backwards. We must therefore remove any existing retention lease so that we can create a new one later on in the
-                    // recovery.
-                    deleteRetentionLease(ActionListener.wrap(ignored -> {
-                        assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[phase1]");
-                        phase1(safeCommitRef.getIndexCommit(), startingSeqNo, () -> estimateNumOps, sendFileStep);
-                    }, onFailure));
+        final StepListener<SendFileResult> sendFileStep = new StepListener<>();
+        final StepListener<TimeValue> prepareEngineStep = new StepListener<>();
+        final StepListener<SendSnapshotResult> sendSnapshotStep = new StepListener<>();
+        final StepListener<Void> finalizeStep = new StepListener<>();
 
-                } catch (final Exception e) {
-                    throw new RecoveryEngineException(shard.shardId(), 1, "sendFileStep failed", e);
-                }
+        if (isSequenceNumberBasedRecovery) {
+            logger.trace("performing sequence numbers based recovery. starting at [{}]", request.startingSeqNo());
+            startingSeqNo = request.startingSeqNo();
+            if (retentionLease == null) {
+                createRetentionLease(startingSeqNo, sendFileStep.map(ignored -> SendFileResult.EMPTY));
+            } else {
+                sendFileStep.onResponse(SendFileResult.EMPTY);
+            }
+        } else {
+            final Engine.IndexCommitRef safeCommitRef;
+            try {
+                safeCommitRef = acquireSafeCommit(shard);
+                resources.add(safeCommitRef);
+            } catch (final Exception e) {
+                throw new RecoveryEngineException(shard.shardId(), 1, "snapshot failed", e);
             }
-            assert startingSeqNo >= 0 : "startingSeqNo must be non negative. got: " + startingSeqNo;
 
-            sendFileStep.whenComplete(r -> {
-                assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[prepareTargetForTranslog]");
-                // For a sequence based recovery, the target can keep its local translog
-                prepareTargetForTranslog(estimateNumberOfHistoryOperations(startingSeqNo), prepareEngineStep);
-            }, onFailure);
+            // Try and copy enough operations to the recovering peer so that if it is promoted to primary then it has a chance of being
+            // able to recover other replicas using operations-based recoveries. If we are not using retention leases then we
+            // conservatively copy all available operations. If we are using retention leases then "enough operations" is just the
+            // operations from the local checkpoint of the safe commit onwards, because when using soft deletes the safe commit retains
+            // at least as much history as anything else. The safe commit will often contain all the history retained by the current set
+            // of retention leases, but this is not guaranteed: an earlier peer recovery from a different primary might have created a
+            // retention lease for some history that this primary already discarded, since we discard history when the global checkpoint
+            // advances and not when creating a new safe commit. In any case this is a best-effort thing since future recoveries can
+            // always fall back to file-based ones, and only really presents a problem if this primary fails before things have settled
+            // down.
+            startingSeqNo = Long.parseLong(safeCommitRef.getIndexCommit().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)) + 1L;
+            logger.trace("performing file-based recovery followed by history replay starting at [{}]", startingSeqNo);
 
-            prepareEngineStep.whenComplete(prepareEngineTime -> {
-                assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[phase2]");
-                /*
-                 * add shard to replication group (shard will receive replication requests from this point on) now that engine is open.
-                 * This means that any document indexed into the primary after this will be replicated to this replica as well
-                 * make sure to do this before sampling the max sequence number in the next step, to ensure that we send
-                 * all documents up to maxSeqNo in phase2.
-                 */
-                runUnderPrimaryPermit(() -> shard.initiateTracking(request.targetAllocationId()), shard, cancellableThreads);
+            try {
+                final int estimateNumOps = estimateNumberOfHistoryOperations(startingSeqNo);
+                final Releasable releaseStore = acquireStore(shard.store());
+                resources.add(releaseStore);
+                sendFileStep.whenComplete(r -> IOUtils.close(safeCommitRef, releaseStore), e -> {
+                    try {
+                        IOUtils.close(safeCommitRef, releaseStore);
+                    } catch (Exception ex) {
+                        logger.warn("releasing snapshot caused exception", ex);
+                    }
+                });
 
-                final long endingSeqNo = shard.seqNoStats().getMaxSeqNo();
-                logger.trace("snapshot for recovery; current size is [{}]", estimateNumberOfHistoryOperations(startingSeqNo));
-                final Translog.Snapshot phase2Snapshot = shard.newChangesSnapshot(
-                    "peer-recovery",
-                    startingSeqNo,
-                    Long.MAX_VALUE,
-                    false,
-                    false,
-                    true
-                );
-                resources.add(phase2Snapshot);
-                retentionLock.close();
-
-                // we have to capture the max_seen_auto_id_timestamp and the max_seq_no_of_updates to make sure that these values
-                // are at least as high as the corresponding values on the primary when any of these operations were executed on it.
-                final long maxSeenAutoIdTimestamp = shard.getMaxSeenAutoIdTimestamp();
-                final long maxSeqNoOfUpdatesOrDeletes = shard.getMaxSeqNoOfUpdatesOrDeletes();
-                final RetentionLeases retentionLeases = shard.getRetentionLeases();
-                final long mappingVersionOnPrimary = shard.indexSettings().getIndexMetadata().getMappingVersion();
-                phase2(
-                    startingSeqNo,
-                    endingSeqNo,
-                    phase2Snapshot,
-                    maxSeenAutoIdTimestamp,
-                    maxSeqNoOfUpdatesOrDeletes,
-                    retentionLeases,
-                    mappingVersionOnPrimary,
-                    sendSnapshotStep
-                );
+                // If the target previously had a copy of this shard then a file-based recovery might move its global checkpoint
+                // backwards. We must therefore remove any existing retention lease so that we can create a new one later on in the
+                // recovery.
+                deleteRetentionLease(ActionListener.wrap(ignored -> {
+                    assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[phase1]");
+                    phase1(safeCommitRef.getIndexCommit(), startingSeqNo, () -> estimateNumOps, sendFileStep);
+                }, onFailure));
 
-            }, onFailure);
-
-            // Recovery target can trim all operations >= startingSeqNo as we have sent all these operations in the phase 2
-            final long trimAboveSeqNo = startingSeqNo - 1;
-            sendSnapshotStep.whenComplete(r -> finalizeRecovery(r.targetLocalCheckpoint, trimAboveSeqNo, finalizeStep), onFailure);
-
-            finalizeStep.whenComplete(r -> {
-                final long phase1ThrottlingWaitTime = 0L; // TODO: return the actual throttle time
-                final SendSnapshotResult sendSnapshotResult = sendSnapshotStep.result();
-                final SendFileResult sendFileResult = sendFileStep.result();
-                final RecoveryResponse response = new RecoveryResponse(
-                    sendFileResult.phase1FileNames,
-                    sendFileResult.phase1FileSizes,
-                    sendFileResult.phase1ExistingFileNames,
-                    sendFileResult.phase1ExistingFileSizes,
-                    sendFileResult.totalSize,
-                    sendFileResult.existingTotalSize,
-                    sendFileResult.took.millis(),
-                    phase1ThrottlingWaitTime,
-                    prepareEngineStep.result().millis(),
-                    sendSnapshotResult.sentOperations,
-                    sendSnapshotResult.tookTime.millis()
-                );
-                try {
-                    future.onResponse(response);
-                } finally {
-                    IOUtils.close(resources);
-                }
-            }, onFailure);
-        } catch (Exception e) {
-            IOUtils.closeWhileHandlingException(releaseResources, () -> future.onFailure(e));
+            } catch (final Exception e) {
+                throw new RecoveryEngineException(shard.shardId(), 1, "sendFileStep failed", e);
+            }
         }
+        assert startingSeqNo >= 0 : "startingSeqNo must be non negative. got: " + startingSeqNo;
+
+        sendFileStep.whenComplete(r -> {
+            assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[prepareTargetForTranslog]");
+            // For a sequence based recovery, the target can keep its local translog
+            prepareTargetForTranslog(estimateNumberOfHistoryOperations(startingSeqNo), prepareEngineStep);
+        }, onFailure);
+
+        prepareEngineStep.whenComplete(prepareEngineTime -> {
+            assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[phase2]");
+            /*
+             * add shard to replication group (shard will receive replication requests from this point on) now that engine is open.
+             * This means that any document indexed into the primary after this will be replicated to this replica as well
+             * make sure to do this before sampling the max sequence number in the next step, to ensure that we send
+             * all documents up to maxSeqNo in phase2.
+             */
+            runUnderPrimaryPermit(
+                () -> shard.initiateTracking(request.targetAllocationId()),
+                shard,
+                cancellableThreads,
+                ActionListener.wrap(ignored -> {
+                    final long endingSeqNo = shard.seqNoStats().getMaxSeqNo();
+                    logger.trace("snapshot for recovery; current size is [{}]", estimateNumberOfHistoryOperations(startingSeqNo));
+                    final Translog.Snapshot phase2Snapshot = shard.newChangesSnapshot(
+                        "peer-recovery",
+                        startingSeqNo,
+                        Long.MAX_VALUE,
+                        false,
+                        false,
+                        true
+                    );
+                    resources.add(phase2Snapshot);
+                    retentionLock.close();
+
+                    // we have to capture the max_seen_auto_id_timestamp and the max_seq_no_of_updates to make sure that these values
+                    // are at least as high as the corresponding values on the primary when any of these operations were executed on it.
+                    final long maxSeenAutoIdTimestamp = shard.getMaxSeenAutoIdTimestamp();
+                    final long maxSeqNoOfUpdatesOrDeletes = shard.getMaxSeqNoOfUpdatesOrDeletes();
+                    final RetentionLeases retentionLeases = shard.getRetentionLeases();
+                    final long mappingVersionOnPrimary = shard.indexSettings().getIndexMetadata().getMappingVersion();
+                    phase2(
+                        startingSeqNo,
+                        endingSeqNo,
+                        phase2Snapshot,
+                        maxSeenAutoIdTimestamp,
+                        maxSeqNoOfUpdatesOrDeletes,
+                        retentionLeases,
+                        mappingVersionOnPrimary,
+                        sendSnapshotStep
+                    );
+                }, onFailure)
+            );
+        }, onFailure);
+
+        // Recovery target can trim all operations >= startingSeqNo as we have sent all these operations in the phase 2
+        final long trimAboveSeqNo = startingSeqNo - 1;
+        sendSnapshotStep.whenComplete(r -> finalizeRecovery(r.targetLocalCheckpoint, trimAboveSeqNo, finalizeStep), onFailure);
+
+        finalizeStep.whenComplete(r -> {
+            final long phase1ThrottlingWaitTime = 0L; // TODO: return the actual throttle time
+            final SendSnapshotResult sendSnapshotResult = sendSnapshotStep.result();
+            final SendFileResult sendFileResult = sendFileStep.result();
+            final RecoveryResponse response = new RecoveryResponse(
+                sendFileResult.phase1FileNames,
+                sendFileResult.phase1FileSizes,
+                sendFileResult.phase1ExistingFileNames,
+                sendFileResult.phase1ExistingFileSizes,
+                sendFileResult.totalSize,
+                sendFileResult.existingTotalSize,
+                sendFileResult.took.millis(),
+                phase1ThrottlingWaitTime,
+                prepareEngineStep.result().millis(),
+                sendSnapshotResult.sentOperations,
+                sendSnapshotResult.tookTime.millis()
+            );
+            try {
+                future.onResponse(response);
+            } finally {
+                IOUtils.close(resources);
+            }
+        }, onFailure);
     }
 
     private boolean isTargetSameHistory() {
@@ -377,27 +384,6 @@ public class RecoverySourceHandler {
         return shard.countChanges("peer-recovery", startingSeqNo, Long.MAX_VALUE);
     }
 
-    static void runUnderPrimaryPermit(
-        CancellableThreads.Interruptible runnable,
-        IndexShard primary,
-        CancellableThreads cancellableThreads
-    ) {
-        cancellableThreads.execute(() -> {
-            final var listener = new ListenableFuture<Releasable>();
-            final var future = new PlainActionFuture<Releasable>();
-            listener.addListener(future);
-
-            primary.acquirePrimaryOperationPermit(listener, ThreadPool.Names.SAME);
-            try (var ignored = FutureUtils.get(future)) {
-                ensureNotRelocatedPrimary(primary);
-                runnable.run();
-            } finally {
-                // add a listener to release the permit because we might have been interrupted while waiting (double-releasing is ok)
-                listener.addListener(ActionListener.wrap(Releasable::close, e -> {}));
-            }
-        });
-    }
-
     /**
      * Run {@code action} while holding a primary permit, checking for cancellation both before and after. Completing the listener passed to
      * {@code action} releases the permit before passing the result through to {@code outerListener}.
@@ -434,6 +420,18 @@ public class RecoverySourceHandler {
         })), ThreadPool.Names.GENERIC);
     }
 
+    static void runUnderPrimaryPermit(
+        CheckedRunnable<Exception> action,
+        IndexShard primary,
+        CancellableThreads cancellableThreads,
+        ActionListener<Void> listener
+    ) {
+        runUnderPrimaryPermit(l -> ActionListener.completeWith(l, () -> {
+            action.run();
+            return null;
+        }), primary, cancellableThreads, listener);
+    }
+
     private static void ensureNotRelocatedPrimary(IndexShard indexShard) {
         // check that the IndexShard still has the primary authority. This needs to be checked under operation permit to prevent
         // races, as IndexShard will switch its authority only when it holds all operation permits, see IndexShard.relocated()
@@ -1250,38 +1248,52 @@ public class RecoverySourceHandler {
          * marking the shard as in-sync. If the relocation handoff holds all the permits then after the handoff completes and we acquire
          * the permit then the state of the shard will be relocated and this recovery will fail.
          */
+        final StepListener<Void> markInSyncStep = new StepListener<>();
         runUnderPrimaryPermit(
             () -> shard.markAllocationIdAsInSync(request.targetAllocationId(), targetLocalCheckpoint),
             shard,
-            cancellableThreads
+            cancellableThreads,
+            markInSyncStep
         );
-        final long globalCheckpoint = shard.getLastKnownGlobalCheckpoint(); // this global checkpoint is persisted in finalizeRecovery
-        final StepListener<Void> finalizeListener = new StepListener<>();
-        cancellableThreads.checkForCancel();
-        recoveryTarget.finalizeRecovery(globalCheckpoint, trimAboveSeqNo, finalizeListener);
-        finalizeListener.whenComplete(r -> {
+
+        final StepListener<Long> finalizeListener = new StepListener<>();
+        markInSyncStep.whenComplete(ignored -> {
+            final long globalCheckpoint = shard.getLastKnownGlobalCheckpoint(); // this global checkpoint is persisted in finalizeRecovery
+            cancellableThreads.checkForCancel();
+            recoveryTarget.finalizeRecovery(globalCheckpoint, trimAboveSeqNo, finalizeListener.map(ignored2 -> globalCheckpoint));
+        }, listener::onFailure);
+
+        final StepListener<Void> updateGlobalCheckpointStep = new StepListener<>();
+        finalizeListener.whenComplete(globalCheckpoint -> {
             runUnderPrimaryPermit(
                 () -> shard.updateGlobalCheckpointForShard(request.targetAllocationId(), globalCheckpoint),
                 shard,
-                cancellableThreads
+                cancellableThreads,
+                updateGlobalCheckpointStep
             );
+        }, listener::onFailure);
 
-            if (request.isPrimaryRelocation()) {
+        final StepListener<Void> finalStep;
+        if (request.isPrimaryRelocation()) {
+            finalStep = new StepListener<>();
+            updateGlobalCheckpointStep.whenComplete(ignored -> {
                 logger.trace("performing relocation hand-off");
-                // this acquires all IndexShard operation permits and will thus delay new recoveries until it is done
                 cancellableThreads.execute(
-                    () -> shard.relocated(request.targetAllocationId(), recoveryTarget::handoffPrimaryContext, ActionListener.wrap(v -> {
-                        cancellableThreads.checkForCancel();
-                        completeFinalizationListener(listener, stopWatch);
-                    }, listener::onFailure))
+                    // this acquires all IndexShard operation permits and will thus delay new recoveries until it is done
+                    () -> shard.relocated(request.targetAllocationId(), recoveryTarget::handoffPrimaryContext, finalStep)
                 );
                 /*
                  * if the recovery process fails after disabling primary mode on the source shard, both relocation source and
                  * target are failed (see {@link IndexShard#updateRoutingEntry}).
                  */
-            } else {
-                completeFinalizationListener(listener, stopWatch);
-            }
+            }, listener::onFailure);
+        } else {
+            finalStep = updateGlobalCheckpointStep;
+        }
+
+        finalStep.whenComplete(ignored -> {
+            cancellableThreads.checkForCancel();
+            completeFinalizationListener(listener, stopWatch);
         }, listener::onFailure);
     }
 

+ 11 - 24
server/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java

@@ -106,7 +106,6 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.BiConsumer;
 import java.util.function.Function;
 import java.util.function.IntSupplier;
 import java.util.stream.Collectors;
@@ -785,29 +784,8 @@ public class RecoverySourceHandlerTests extends MapperServiceTestCase {
         assertFalse(phase2Called.get());
     }
 
-    public void testCancellationsDoesNotLeakPrimaryPermits() throws Exception {
-        runPrimaryPermitsLeakTest((shard, cancellableThreads) -> {
-            RecoverySourceHandler.runUnderPrimaryPermit(() -> {}, shard, cancellableThreads);
-        });
-    }
-
-    public void testCancellationsDoesNotLeakPrimaryPermitsAsync() throws Exception {
-        runPrimaryPermitsLeakTest((shard, cancellableThreads) -> {
-            PlainActionFuture.<Void, RuntimeException>get(
-                future -> RecoverySourceHandler.runUnderPrimaryPermit(
-                    listener -> listener.onResponse(null),
-                    shard,
-                    cancellableThreads,
-                    future
-                ),
-                10,
-                TimeUnit.SECONDS
-            );
-        });
-    }
-
     @SuppressWarnings("unchecked")
-    private static void runPrimaryPermitsLeakTest(BiConsumer<IndexShard, CancellableThreads> acquireAndReleasePermit) throws Exception {
+    public void testCancellationsDoesNotLeakPrimaryPermits() throws Exception {
         final CancellableThreads cancellableThreads = new CancellableThreads();
         final IndexShard shard = mock(IndexShard.class);
         final AtomicBoolean freed = new AtomicBoolean(true);
@@ -821,7 +799,16 @@ public class RecoverySourceHandlerTests extends MapperServiceTestCase {
         Thread cancelingThread = new Thread(() -> cancellableThreads.cancel("test"));
         cancelingThread.start();
         try {
-            acquireAndReleasePermit.accept(shard, cancellableThreads);
+            PlainActionFuture.<Void, RuntimeException>get(
+                future -> RecoverySourceHandler.runUnderPrimaryPermit(
+                    listener -> listener.onResponse(null),
+                    shard,
+                    cancellableThreads,
+                    future
+                ),
+                10,
+                TimeUnit.SECONDS
+            );
         } catch (CancellableThreads.ExecutionCancelledException e) {
             // expected.
         }