Harden global checkpoint tracker

This commit refactors the global checkpoint tracker to make it more
resilient. The main idea is to make it explicit what state is actually
captured and how that state is updated, e.g. through replication or
cluster state updates. It also fixes the issue where the local checkpoint
information is not updated when a shard becomes primary. The primary
relocation handoff also becomes very simple: we can just copy the
internal state over verbatim.

Relates #25468
Yannick Welsch · 8 years ago
parent · commit · baa87db5d1
23 changed files with 1150 additions and 988 deletions
  1. +3 -1  core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java
  2. +512 -384  core/src/main/java/org/elasticsearch/index/seqno/GlobalCheckpointTracker.java
  3. +54 -14  core/src/main/java/org/elasticsearch/index/seqno/SequenceNumbersService.java
  4. +147 -127  core/src/main/java/org/elasticsearch/index/shard/IndexShard.java
  5. +0 -105  core/src/main/java/org/elasticsearch/index/shard/PrimaryContext.java
  6. +21 -25  core/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java
  7. +6 -5  core/src/main/java/org/elasticsearch/indices/recovery/RecoveryHandoffPrimaryContextRequest.java
  8. +14 -12  core/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java
  9. +3 -3  core/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java
  10. +2 -2  core/src/main/java/org/elasticsearch/indices/recovery/RecoveryTargetHandler.java
  11. +2 -2  core/src/main/java/org/elasticsearch/indices/recovery/RemoteRecoveryTargetHandler.java
  12. +0 -1  core/src/test/java/org/elasticsearch/cluster/MinimumMasterNodesIT.java
  13. +3 -6  core/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java
  14. +7 -18  core/src/test/java/org/elasticsearch/index/replication/ESIndexLevelReplicationTestCase.java
  15. +13 -8  core/src/test/java/org/elasticsearch/index/replication/IndexLevelReplicationTests.java
  16. +7 -0  core/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java
  17. +256 -248  core/src/test/java/org/elasticsearch/index/seqno/GlobalCheckpointTrackerTests.java
  18. +39 -10  core/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java
  19. +1 -1  core/src/test/java/org/elasticsearch/index/shard/PrimaryReplicaSyncerTests.java
  20. +9 -9  core/src/test/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java
  21. +6 -1  core/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java
  22. +0 -1  core/src/test/java/org/elasticsearch/recovery/FullRollingRestartIT.java
  23. +45 -5  test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java
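
To make the lifecycle described in the commit message easier to follow before diving into the diff, here is a hedged usage sketch of the refactored tracker. It uses only methods that appear in the GlobalCheckpointTracker diff below, but it is not code from this commit: the allocation IDs, checkpoint values, cluster state version and the surrounding helper method are invented, java.util.Collections is assumed to be imported, and the snippet would have to live in the org.elasticsearch.index.seqno package because the tracker's constructor is package-private.

    // Hedged sketch, not part of this commit. Assumes shardId, indexSettings and a second tracker
    // (targetTracker, a freshly initialized tracker playing the relocation target) are in scope.
    static void lifecycleSketch(ShardId shardId, IndexSettings indexSettings,
                                GlobalCheckpointTracker targetTracker) throws InterruptedException {
        GlobalCheckpointTracker tracker =
            new GlobalCheckpointTracker(shardId, indexSettings, SequenceNumbersService.UNASSIGNED_SEQ_NO);

        // every tracker starts in replica mode; the master supplies the in-sync and initializing allocation IDs
        tracker.updateFromMaster(1L, Collections.singleton("primary-aid"),
            Collections.singleton("replica-aid"), Collections.emptySet());

        // an activating primary (or a replica promoted to primary) switches the tracker to primary mode
        tracker.activatePrimaryMode("primary-aid", 10L);

        // recovery starts tracking the new copy and marks it in-sync; replication reports local checkpoints
        tracker.initiateTracking("replica-aid");
        tracker.markAllocationIdAsInSync("replica-aid", 10L);
        tracker.updateLocalCheckpoint("replica-aid", 12L);

        // primary relocation handoff: the internal state is copied over verbatim to the target
        GlobalCheckpointTracker.PrimaryContext context = tracker.startRelocationHandoff();
        targetTracker.activateWithPrimaryContext(context);
        tracker.completeRelocationHandoff();
    }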

+ 3 - 1
core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java

@@ -1027,7 +1027,7 @@ public abstract class TransportReplicationAction<
                 localCheckpoint = in.readZLong();
             } else {
                 // 5.x used to read empty responses, which don't really read anything off the stream, so just do nothing.
-                localCheckpoint = SequenceNumbersService.UNASSIGNED_SEQ_NO;
+                localCheckpoint = SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT;
             }
         }
 
@@ -1202,6 +1202,8 @@ public abstract class TransportReplicationAction<
             super.readFrom(in);
             if (in.getVersion().onOrAfter(Version.V_6_0_0_alpha1)) {
                 globalCheckpoint = in.readZLong();
+            } else {
+                globalCheckpoint = SequenceNumbersService.UNASSIGNED_SEQ_NO;
             }
         }
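
Both hunks above follow the same wire-compatibility pattern: a sequence-number field that only exists on the wire from 6.0 onwards is read behind a version check, and a well-known sentinel is substituted when the remote node is older. The following is a hedged, generic sketch of both sides of that pattern, not code from this commit; the checkpoint field and the enclosing class are invented, and only the stream and version calls already visible in the hunks are used.

    // Illustrative only: the generic shape of version-gated serialization with a sentinel fallback.
    private long checkpoint = SequenceNumbersService.UNASSIGNED_SEQ_NO;

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        if (out.getVersion().onOrAfter(Version.V_6_0_0_alpha1)) {
            out.writeZLong(checkpoint);   // only 6.0+ receivers know how to read this field
        }
    }

    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        if (in.getVersion().onOrAfter(Version.V_6_0_0_alpha1)) {
            checkpoint = in.readZLong();
        } else {
            // a pre-6.0 sender never wrote the field, so fall back to a sentinel value,
            // as the second hunk above now does for the global checkpoint
            checkpoint = SequenceNumbersService.UNASSIGNED_SEQ_NO;
        }
    }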
 

+ 512 - 384
core/src/main/java/org/elasticsearch/index/seqno/GlobalCheckpointTracker.java

@@ -19,23 +19,20 @@
 
 package org.elasticsearch.index.seqno;
 
-import com.carrotsearch.hppc.ObjectLongHashMap;
-import com.carrotsearch.hppc.ObjectLongMap;
-import com.carrotsearch.hppc.cursors.ObjectLongCursor;
 import org.elasticsearch.common.SuppressForbidden;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.shard.AbstractIndexShardComponent;
-import org.elasticsearch.index.shard.PrimaryContext;
 import org.elasticsearch.index.shard.ShardId;
 
-import java.util.Arrays;
-import java.util.Comparator;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
-import java.util.Locale;
+import java.util.Map;
 import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.StreamSupport;
 
 /**
  * This class is responsible of tracking the global checkpoint. The global checkpoint is the highest sequence number for which all lower (or
@@ -48,154 +45,191 @@ import java.util.stream.StreamSupport;
  */
 public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
 
-    long appliedClusterStateVersion;
+    /**
+     * The global checkpoint tracker can operate in two modes:
+     * - primary: this shard is in charge of collecting local checkpoint information from all shard copies and computing the global
+     *            checkpoint based on the local checkpoints of all in-sync shard copies.
+     * - replica: this shard receives global checkpoint information from the primary (see {@link #updateGlobalCheckpointOnReplica}).
+     *
+     * When a shard is initialized (be it a primary or replica), it initially operates in replica mode. The global checkpoint tracker is
+     * then switched to primary mode in the following three scenarios:
+     *
+     * - An initializing primary shard that is not a relocation target is moved to primary mode (using {@link #activatePrimaryMode}) once
+     *   the shard becomes active.
+     * - An active replica shard is moved to primary mode (using {@link #activatePrimaryMode}) once it is promoted to primary.
+     * - A primary relocation target is moved to primary mode (using {@link #activateWithPrimaryContext}) during the primary relocation
+     *   handoff. If the target shard is successfully initialized in primary mode, the source shard of a primary relocation is then moved
+     *   to replica mode (using {@link #completeRelocationHandoff}), as the relocation target will be in charge of the global checkpoint
+     *   computation from that point on.
+     */
+    boolean primaryMode;
+    /**
+     * Boolean flag that indicates if a relocation handoff is in progress. A handoff is started by calling {@link #startRelocationHandoff}
+     * and is finished by either calling {@link #completeRelocationHandoff} or {@link #abortRelocationHandoff}, depending on whether the
+     * handoff was successful or not. During the handoff, whose main objective is to transfer the internal state of the global
+     * checkpoint tracker from the relocation source to the target, the list of in-sync shard copies cannot grow; otherwise the relocation
+     * target might miss this information and increase the global checkpoint too eagerly. As a consequence, some of the methods in this class
+     * are not allowed to be called while a handoff is in progress, in particular {@link #markAllocationIdAsInSync}.
+     *
+     * A notable exception to this is the method {@link #updateFromMaster}, which is still allowed to be called during a relocation handoff.
+     * The reason for this is that the handoff might fail and can be aborted (using {@link #abortRelocationHandoff}), in which case
+     * it is important that the global checkpoint tracker does not miss any state updates that might have happened during the handoff attempt.
+     * This means, however, that the global checkpoint can still advance after the primary relocation handoff has been initiated, but only
+     * because the master could have failed some of the in-sync shard copies and marked them as stale. That is ok though, as this
+     * information is conveyed through cluster state updates, and the new primary relocation target will also eventually learn about those.
+     */
+    boolean handoffInProgress;
 
-    /*
-     * This map holds the last known local checkpoint for every active shard and initializing shard copies that has been brought up to speed
-     * through recovery. These shards are treated as valid copies and participate in determining the global checkpoint. This map is keyed by
-     * allocation IDs. All accesses to this set are guarded by a lock on this.
+    /**
+     * The global checkpoint tracker relies on the property that cluster state updates are applied in-order. After transferring a primary
+     * context from the primary relocation source to the target and initializing the target, it is possible for the target to apply a
+     * cluster state that is older than the one upon which the primary context was based. If we allowed this old cluster state
+     * to influence the list of in-sync shard copies here, this could possibly remove such an in-sync copy from the internal structures
+     * until the newer cluster state were to be applied, which would unsafely advance the global checkpoint. This field thus captures
+     * the version of the last applied cluster state to ensure in-order updates.
      */
-    final ObjectLongMap<String> inSyncLocalCheckpoints;
+    long appliedClusterStateVersion;
 
-    /*
-     * This map holds the last known local checkpoint for initializing shards that are undergoing recovery. Such shards do not participate
-     * in determining the global checkpoint. We must track these local checkpoints so that when a shard is activated we use the highest
-     * known checkpoint.
+    /**
+     * Local checkpoint information for all shard copies that are tracked. Has an entry for all shard copies that are initializing
+     * and/or in-sync, possibly also containing information about unassigned in-sync shard copies. The information that is tracked for
+     * each shard copy is explained in the docs for the {@link LocalCheckpointState} class.
      */
-    final ObjectLongMap<String> trackingLocalCheckpoints;
+    final Map<String, LocalCheckpointState> localCheckpoints;
 
-    /*
+    /**
      * This set contains allocation IDs for which there is a thread actively waiting for the local checkpoint to advance to at least the
      * current global checkpoint.
      */
     final Set<String> pendingInSync;
 
-    /*
-     * The current global checkpoint for this shard. Note that this field is guarded by a lock on this and thus this field does not need to
-     * be volatile.
+    /**
+     * The global checkpoint:
+     * - computed based on local checkpoints, if the tracker is in primary mode
+     * - received from the primary, if the tracker is in replica mode
      */
-    private long globalCheckpoint;
+    long globalCheckpoint;
 
-    /*
-     * During relocation handoff, the state of the global checkpoint tracker is sampled. After sampling, there should be no additional
-     * mutations to this tracker until the handoff has completed.
-     */
-    private boolean sealed = false;
+    public static class LocalCheckpointState implements Writeable {
 
-    /**
-     * Initialize the global checkpoint service. The specified global checkpoint should be set to the last known global checkpoint, or
-     * {@link SequenceNumbersService#UNASSIGNED_SEQ_NO}.
-     *
-     * @param shardId          the shard ID
-     * @param indexSettings    the index settings
-     * @param globalCheckpoint the last known global checkpoint for this shard, or {@link SequenceNumbersService#UNASSIGNED_SEQ_NO}
-     */
-    GlobalCheckpointTracker(final ShardId shardId, final IndexSettings indexSettings, final long globalCheckpoint) {
-        super(shardId, indexSettings);
-        assert globalCheckpoint >= SequenceNumbersService.UNASSIGNED_SEQ_NO : "illegal initial global checkpoint: " + globalCheckpoint;
-        this.inSyncLocalCheckpoints = new ObjectLongHashMap<>(1 + indexSettings.getNumberOfReplicas());
-        this.trackingLocalCheckpoints = new ObjectLongHashMap<>(indexSettings.getNumberOfReplicas());
-        this.globalCheckpoint = globalCheckpoint;
-        this.pendingInSync = new HashSet<>();
-    }
+        /**
+         * the last local checkpoint information that we have for this shard
+         */
+        long localCheckpoint;
+        /**
+         * whether this shard is treated as in-sync and thus contributes to the global checkpoint calculation
+         */
+        boolean inSync;
 
-    /**
-     * Notifies the service to update the local checkpoint for the shard with the provided allocation ID. If the checkpoint is lower than
-     * the currently known one, this is a no-op. If the allocation ID is not tracked, it is ignored. This is to prevent late arrivals from
-     * shards that are removed to be re-added.
-     *
-     * @param allocationId    the allocation ID of the shard to update the local checkpoint for
-     * @param localCheckpoint the local checkpoint for the shard
-     */
-    public synchronized void updateLocalCheckpoint(final String allocationId, final long localCheckpoint) {
-        if (sealed) {
-            throw new IllegalStateException("global checkpoint tracker is sealed");
+        public LocalCheckpointState(long localCheckpoint, boolean inSync) {
+            this.localCheckpoint = localCheckpoint;
+            this.inSync = inSync;
         }
-        final boolean updated;
-        if (updateLocalCheckpoint(allocationId, localCheckpoint, inSyncLocalCheckpoints, "in-sync")) {
-            updated = true;
-            updateGlobalCheckpointOnPrimary();
-        } else if (updateLocalCheckpoint(allocationId, localCheckpoint, trackingLocalCheckpoints, "tracking")) {
-            updated = true;
-        } else {
-            logger.trace("ignored local checkpoint [{}] of [{}], allocation ID is not tracked", localCheckpoint, allocationId);
-            updated = false;
+
+        public LocalCheckpointState(StreamInput in) throws IOException {
+            this.localCheckpoint = in.readZLong();
+            this.inSync = in.readBoolean();
         }
-        if (updated) {
-            notifyAllWaiters();
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeZLong(localCheckpoint);
+            out.writeBoolean(inSync);
         }
-    }
 
-    /**
-     * Notify all threads waiting on the monitor on this tracker. These threads should be waiting for the local checkpoint on a specific
-     * allocation ID to catch up to the global checkpoint.
-     */
-    @SuppressForbidden(reason = "Object#notifyAll waiters for local checkpoint advancement")
-    private synchronized void notifyAllWaiters() {
-        this.notifyAll();
+        /**
+         * Returns a full copy of this object
+         */
+        public LocalCheckpointState copy() {
+            return new LocalCheckpointState(localCheckpoint, inSync);
+        }
+
+        public long getLocalCheckpoint() {
+            return localCheckpoint;
+        }
+
+        @Override
+        public String toString() {
+            return "LocalCheckpointState{" +
+                "localCheckpoint=" + localCheckpoint +
+                ", inSync=" + inSync +
+                '}';
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            LocalCheckpointState that = (LocalCheckpointState) o;
+
+            if (localCheckpoint != that.localCheckpoint) return false;
+            return inSync == that.inSync;
+        }
+
+        @Override
+        public int hashCode() {
+            int result = (int) (localCheckpoint ^ (localCheckpoint >>> 32));
+            result = 31 * result + (inSync ? 1 : 0);
+            return result;
+        }
     }
 
     /**
-     * Update the local checkpoint for the specified allocation ID in the specified tracking map. If the checkpoint is lower than the
-     * currently known one, this is a no-op. If the allocation ID is not tracked, it is ignored.
-     *
-     * @param allocationId the allocation ID of the shard to update the local checkpoint for
-     * @param localCheckpoint the local checkpoint for the shard
-     * @param map the tracking map
-     * @param reason the reason for the update (used for logging)
-     * @return {@code true} if the local checkpoint was updated, otherwise {@code false} if this was a no-op
+     * Class invariant that should hold before and after every invocation of public methods on this class. As Java lacks implication
+     * as a logical operator, many of the invariants are written in the form (!A || B); they should be read as (A implies B).
      */
-    private boolean updateLocalCheckpoint(
-            final String allocationId, final long localCheckpoint, ObjectLongMap<String> map, final String reason) {
-        final int index = map.indexOf(allocationId);
-        if (index >= 0) {
-            final long current = map.indexGet(index);
-            if (current < localCheckpoint) {
-                map.indexReplace(index, localCheckpoint);
-                logger.trace("updated local checkpoint of [{}] in [{}] from [{}] to [{}]", allocationId, reason, current, localCheckpoint);
-            } else {
-                logger.trace(
-                        "skipped updating local checkpoint of [{}] in [{}] from [{}] to [{}], current checkpoint is higher",
-                        allocationId,
-                        reason,
-                        current,
-                        localCheckpoint);
-            }
-            return true;
-        } else {
-            return false;
+    private boolean invariant() {
+        // local checkpoints only set during primary mode
+        assert primaryMode || localCheckpoints.values().stream()
+            .allMatch(lcps -> lcps.localCheckpoint == SequenceNumbersService.UNASSIGNED_SEQ_NO ||
+                lcps.localCheckpoint == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT);
+
+        // relocation handoff can only occur in primary mode
+        assert !handoffInProgress || primaryMode;
+
+        // there is at least one in-sync shard copy when the global checkpoint tracker operates in primary mode (i.e. the shard itself)
+        assert !primaryMode || localCheckpoints.values().stream().anyMatch(lcps -> lcps.inSync);
+
+        // during relocation handoff there are no entries blocking global checkpoint advancement
+        assert !handoffInProgress || pendingInSync.isEmpty() :
+            "entries blocking global checkpoint advancement during relocation handoff: " + pendingInSync;
+
+        // entries blocking global checkpoint advancement can only exist in primary mode and when not having a relocation handoff
+        assert pendingInSync.isEmpty() || (primaryMode && !handoffInProgress);
+
+        // the computed global checkpoint is always up-to-date
+        assert !primaryMode || globalCheckpoint == computeGlobalCheckpoint(pendingInSync, localCheckpoints.values(), globalCheckpoint) :
+            "global checkpoint is not up-to-date, expected: " +
+                computeGlobalCheckpoint(pendingInSync, localCheckpoints.values(), globalCheckpoint) + " but was: " + globalCheckpoint;
+
+        for (Map.Entry<String, LocalCheckpointState> entry : localCheckpoints.entrySet()) {
+            // blocking global checkpoint advancement only happens for shards that are not in-sync
+            assert !pendingInSync.contains(entry.getKey()) || !entry.getValue().inSync :
+                "shard copy " + entry.getKey() + " blocks global checkpoint advancement but is in-sync";
         }
+
+        return true;
     }
 
     /**
-     * Scans through the currently known local checkpoint and updates the global checkpoint accordingly.
+     * Initialize the global checkpoint service. The specified global checkpoint should be set to the last known global checkpoint, or
+     * {@link SequenceNumbersService#UNASSIGNED_SEQ_NO}.
+     *
+     * @param shardId          the shard ID
+     * @param indexSettings    the index settings
+     * @param globalCheckpoint the last known global checkpoint for this shard, or {@link SequenceNumbersService#UNASSIGNED_SEQ_NO}
      */
-    private synchronized void updateGlobalCheckpointOnPrimary() {
-        long minLocalCheckpoint = Long.MAX_VALUE;
-        if (inSyncLocalCheckpoints.isEmpty() || !pendingInSync.isEmpty()) {
-            return;
-        }
-        for (final ObjectLongCursor<String> localCheckpoint : inSyncLocalCheckpoints) {
-            if (localCheckpoint.value == SequenceNumbersService.UNASSIGNED_SEQ_NO) {
-                logger.trace("unknown local checkpoint for active allocation ID [{}], requesting a sync", localCheckpoint.key);
-                return;
-            }
-            minLocalCheckpoint = Math.min(localCheckpoint.value, minLocalCheckpoint);
-        }
-        assert minLocalCheckpoint != SequenceNumbersService.UNASSIGNED_SEQ_NO : "new global checkpoint must be assigned";
-        if (minLocalCheckpoint < globalCheckpoint) {
-            final String message =
-                    String.format(
-                            Locale.ROOT,
-                            "new global checkpoint [%d] is lower than previous one [%d]",
-                            minLocalCheckpoint,
-                            globalCheckpoint);
-            throw new IllegalStateException(message);
-        }
-        if (globalCheckpoint != minLocalCheckpoint) {
-            logger.trace("global checkpoint updated to [{}]", minLocalCheckpoint);
-            globalCheckpoint = minLocalCheckpoint;
-        }
+    GlobalCheckpointTracker(final ShardId shardId, final IndexSettings indexSettings, final long globalCheckpoint) {
+        super(shardId, indexSettings);
+        assert globalCheckpoint >= SequenceNumbersService.UNASSIGNED_SEQ_NO : "illegal initial global checkpoint: " + globalCheckpoint;
+        this.primaryMode = false;
+        this.handoffInProgress = false;
+        this.appliedClusterStateVersion = -1L;
+        this.globalCheckpoint = globalCheckpoint;
+        this.localCheckpoints = new HashMap<>(1 + indexSettings.getNumberOfReplicas());
+        this.pendingInSync = new HashSet<>();
+        assert invariant();
     }
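
As a complement to the invariants above, here is a small standalone model of the global checkpoint computation that the new computeGlobalCheckpoint method performs further down in this file: the minimum local checkpoint over all in-sync copies, falling back to the previous value while any copy is still pending in-sync or has an unknown checkpoint. This is an editorial sketch, not code from this commit, and the sentinel values are illustrative rather than the real constants in SequenceNumbersService.

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    import java.util.Set;

    // Simplified, self-contained model of the min-over-in-sync-copies computation (illustrative values).
    public class GlobalCheckpointSketch {

        static final long UNASSIGNED_SEQ_NO = -2;             // in-sync copy whose checkpoint is not yet known
        static final long PRE_60_NODE_LOCAL_CHECKPOINT = -3;  // copy allocated to a pre-6.0 node

        static final class Copy {
            final long localCheckpoint;
            final boolean inSync;
            Copy(long localCheckpoint, boolean inSync) {
                this.localCheckpoint = localCheckpoint;
                this.inSync = inSync;
            }
        }

        static long compute(Set<String> pendingInSync, List<Copy> copies, long fallback) {
            if (pendingInSync.isEmpty() == false) {
                return fallback;                              // a recovering copy is still catching up
            }
            long min = Long.MAX_VALUE;
            for (Copy copy : copies) {
                if (copy.inSync) {
                    if (copy.localCheckpoint == UNASSIGNED_SEQ_NO) {
                        return fallback;                      // unknown checkpoint for an in-sync copy
                    } else if (copy.localCheckpoint == PRE_60_NODE_LOCAL_CHECKPOINT) {
                        // pre-6.0 copies are ignored for the computation
                    } else {
                        min = Math.min(min, copy.localCheckpoint);
                    }
                }
            }
            return min;
        }

        public static void main(String[] args) {
            List<Copy> copies = Arrays.asList(
                new Copy(12, true),    // primary
                new Copy(9, true),     // in-sync replica: bounds the global checkpoint
                new Copy(5, false));   // initializing copy, not yet in-sync: ignored
            System.out.println(compute(Collections.emptySet(), copies, UNASSIGNED_SEQ_NO)); // prints 9
        }
    }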
 
     /**
@@ -212,7 +246,9 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
      *
      * @param globalCheckpoint the global checkpoint
      */
-    synchronized void updateGlobalCheckpointOnReplica(final long globalCheckpoint) {
+    public synchronized void updateGlobalCheckpointOnReplica(final long globalCheckpoint) {
+        assert invariant();
+        assert primaryMode == false;
         /*
          * The global checkpoint here is a local knowledge which is updated under the mandate of the primary. It can happen that the primary
          * information is lagging compared to a replica (e.g., if a replica is promoted to primary but has stale info relative to other
@@ -222,325 +258,417 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
             this.globalCheckpoint = globalCheckpoint;
             logger.trace("global checkpoint updated from primary to [{}]", globalCheckpoint);
         }
+        assert invariant();
+    }
+
+    /**
+     * Initializes the global checkpoint tracker in primary mode (see {@link #primaryMode}). Called on primary activation or promotion.
+     */
+    public synchronized void activatePrimaryMode(final String allocationId, final long localCheckpoint) {
+        assert invariant();
+        assert primaryMode == false;
+        assert localCheckpoints.get(allocationId) != null && localCheckpoints.get(allocationId).inSync &&
+            localCheckpoints.get(allocationId).localCheckpoint == SequenceNumbersService.UNASSIGNED_SEQ_NO :
+            "expected " + allocationId + " to have initialized entry in " + localCheckpoints + " when activating primary";
+        assert localCheckpoint >= SequenceNumbersService.NO_OPS_PERFORMED;
+        primaryMode = true;
+        updateLocalCheckpoint(allocationId, localCheckpoints.get(allocationId), localCheckpoint);
+        updateGlobalCheckpointOnPrimary();
+        assert invariant();
     }
 
     /**
-     * Notifies the service of the current allocation ids in the cluster state. This method trims any shards that have been removed.
+     * Notifies the tracker of the current allocation IDs in the cluster state.
      *
      * @param applyingClusterStateVersion the cluster state version being applied when updating the allocation IDs from the master
-     * @param activeAllocationIds         the allocation IDs of the currently active shard copies
+     * @param inSyncAllocationIds         the allocation IDs of the currently in-sync shard copies
      * @param initializingAllocationIds   the allocation IDs of the currently initializing shard copies
+     * @param pre60AllocationIds          the allocation IDs of shards that are allocated to pre-6.0 nodes
      */
-    public synchronized void updateAllocationIdsFromMaster(
-            final long applyingClusterStateVersion, final Set<String> activeAllocationIds, final Set<String> initializingAllocationIds) {
-        if (applyingClusterStateVersion < appliedClusterStateVersion) {
-            return;
-        }
-
-        appliedClusterStateVersion = applyingClusterStateVersion;
-
-        // remove shards whose allocation ID no longer exists
-        inSyncLocalCheckpoints.removeAll(a -> !activeAllocationIds.contains(a) && !initializingAllocationIds.contains(a));
-
-        // add any new active allocation IDs
-        for (final String a : activeAllocationIds) {
-            if (!inSyncLocalCheckpoints.containsKey(a)) {
-                final long localCheckpoint = trackingLocalCheckpoints.getOrDefault(a, SequenceNumbersService.UNASSIGNED_SEQ_NO);
-                inSyncLocalCheckpoints.put(a, localCheckpoint);
-                logger.trace("marked [{}] as in-sync with local checkpoint [{}] via cluster state update from master", a, localCheckpoint);
-            }
-        }
-
-        trackingLocalCheckpoints.removeAll(a -> !initializingAllocationIds.contains(a));
-        for (final String a : initializingAllocationIds) {
-            if (inSyncLocalCheckpoints.containsKey(a)) {
-                /*
-                 * This can happen if we mark the allocation ID as in sync at the end of recovery before seeing a cluster state update from
-                 * marking the shard as active.
-                 */
-                continue;
+    public synchronized void updateFromMaster(final long applyingClusterStateVersion, final Set<String> inSyncAllocationIds,
+                                              final Set<String> initializingAllocationIds, final Set<String> pre60AllocationIds) {
+        assert invariant();
+        if (applyingClusterStateVersion > appliedClusterStateVersion) {
+            // check that the master does not fabricate new in-sync entries out of thin air once we are in primary mode
+            assert !primaryMode || inSyncAllocationIds.stream().allMatch(
+                inSyncId -> localCheckpoints.containsKey(inSyncId) && localCheckpoints.get(inSyncId).inSync) :
+                "update from master in primary mode contains in-sync ids " + inSyncAllocationIds +
+                    " that have no matching entries in " + localCheckpoints;
+            // remove entries which don't exist on master
+            boolean removedEntries = localCheckpoints.keySet().removeIf(
+                aid -> !inSyncAllocationIds.contains(aid) && !initializingAllocationIds.contains(aid));
+
+            if (primaryMode) {
+                // add new initializingIds that are missing locally. These are fresh shard copies - and not in-sync
+                for (String initializingId : initializingAllocationIds) {
+                    if (localCheckpoints.containsKey(initializingId) == false) {
+                        final boolean inSync = inSyncAllocationIds.contains(initializingId);
+                        assert inSync == false : "update from master in primary mode has " + initializingId +
+                            " as in-sync but it does not exist locally";
+                        final long localCheckpoint = pre60AllocationIds.contains(initializingId) ?
+                            SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT : SequenceNumbersService.UNASSIGNED_SEQ_NO;
+                        localCheckpoints.put(initializingId, new LocalCheckpointState(localCheckpoint, inSync));
+                    }
+                }
+            } else {
+                for (String initializingId : initializingAllocationIds) {
+                    final long localCheckpoint = pre60AllocationIds.contains(initializingId) ?
+                        SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT : SequenceNumbersService.UNASSIGNED_SEQ_NO;
+                    localCheckpoints.put(initializingId, new LocalCheckpointState(localCheckpoint, false));
+                }
+                for (String inSyncId : inSyncAllocationIds) {
+                    final long localCheckpoint = pre60AllocationIds.contains(inSyncId) ?
+                        SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT : SequenceNumbersService.UNASSIGNED_SEQ_NO;
+                    localCheckpoints.put(inSyncId, new LocalCheckpointState(localCheckpoint, true));
+                }
             }
-            if (trackingLocalCheckpoints.containsKey(a)) {
-                // we are already tracking this allocation ID
-                continue;
+            appliedClusterStateVersion = applyingClusterStateVersion;
+            if (primaryMode && removedEntries) {
+                updateGlobalCheckpointOnPrimary();
             }
-            // this is a new allocation ID
-            trackingLocalCheckpoints.put(a, SequenceNumbersService.UNASSIGNED_SEQ_NO);
-            logger.trace("tracking [{}] via cluster state update from master", a);
         }
-
-        updateGlobalCheckpointOnPrimary();
+        assert invariant();
     }
 
     /**
-     * Get the primary context for the shard. This includes the state of the global checkpoint tracker.
+     * Called when the recovery process for a shard is ready to open the engine on the target shard. Ensures that the right data structures
+     * have been set up locally to track local checkpoint information for the shard.
      *
-     * @return the primary context
+     * @param allocationId  the allocation ID of the shard for which recovery was initiated
      */
-    synchronized PrimaryContext primaryContext() {
-        if (sealed) {
-            throw new IllegalStateException("global checkpoint tracker is sealed");
+    public synchronized void initiateTracking(final String allocationId) {
+        assert invariant();
+        assert primaryMode;
+        LocalCheckpointState lcps = localCheckpoints.get(allocationId);
+        if (lcps == null) {
+            // can happen if replica was removed from cluster but recovery process is unaware of it yet
+            throw new IllegalStateException("no local checkpoint tracking information available");
         }
-        sealed = true;
-        final ObjectLongMap<String> inSyncLocalCheckpoints = new ObjectLongHashMap<>(this.inSyncLocalCheckpoints);
-        final ObjectLongMap<String> trackingLocalCheckpoints = new ObjectLongHashMap<>(this.trackingLocalCheckpoints);
-        return new PrimaryContext(appliedClusterStateVersion, inSyncLocalCheckpoints, trackingLocalCheckpoints);
+        assert invariant();
     }
 
     /**
-     * Releases a previously acquired primary context.
+     * Marks the shard with the provided allocation ID as in-sync with the primary shard. This method will block until the local checkpoint
+     * on the specified shard advances above the current global checkpoint.
+     *
+     * @param allocationId    the allocation ID of the shard to mark as in-sync
+     * @param localCheckpoint the current local checkpoint on the shard
      */
-    synchronized void releasePrimaryContext() {
-        assert sealed;
-        sealed = false;
+    public synchronized void markAllocationIdAsInSync(final String allocationId, final long localCheckpoint) throws InterruptedException {
+        assert invariant();
+        assert primaryMode;
+        assert handoffInProgress == false;
+        LocalCheckpointState lcps = localCheckpoints.get(allocationId);
+        if (lcps == null) {
+            // can happen if replica was removed from cluster but recovery process is unaware of it yet
+            throw new IllegalStateException("no local checkpoint tracking information available for " + allocationId);
+        }
+        assert localCheckpoint >= SequenceNumbersService.NO_OPS_PERFORMED :
+            "expected known local checkpoint for " + allocationId + " but was " + localCheckpoint;
+        assert pendingInSync.contains(allocationId) == false : "shard copy " + allocationId + " is already marked as pending in-sync";
+        updateLocalCheckpoint(allocationId, lcps, localCheckpoint);
+        // if the shard was already in-sync (because of a previously failed recovery attempt), the global checkpoint must have been
+        // prevented from advancing
+        assert !lcps.inSync || (lcps.localCheckpoint >= globalCheckpoint) :
+            "shard copy " + allocationId + " that's already in-sync should have a local checkpoint " + lcps.localCheckpoint +
+                " that's above the global checkpoint " + globalCheckpoint;
+        if (lcps.localCheckpoint < globalCheckpoint) {
+            pendingInSync.add(allocationId);
+            try {
+                while (true) {
+                    if (pendingInSync.contains(allocationId)) {
+                        waitForLocalCheckpointToAdvance();
+                    } else {
+                        break;
+                    }
+                }
+            } finally {
+                pendingInSync.remove(allocationId);
+            }
+        } else {
+            lcps.inSync = true;
+            logger.trace("marked [{}] as in-sync", allocationId);
+            updateGlobalCheckpointOnPrimary();
+        }
+
+        assert invariant();
+    }
+
+    private boolean updateLocalCheckpoint(String allocationId, LocalCheckpointState lcps, long localCheckpoint) {
+        // a local checkpoint of PRE_60_NODE_LOCAL_CHECKPOINT cannot be overridden
+        assert lcps.localCheckpoint != SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT ||
+            localCheckpoint == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT :
+            "pre-6.0 shard copy " + allocationId + " unexpected to send valid local checkpoint " + localCheckpoint;
+        if (localCheckpoint > lcps.localCheckpoint) {
+            logger.trace("updated local checkpoint of [{}] from [{}] to [{}]", allocationId, lcps.localCheckpoint, localCheckpoint);
+            lcps.localCheckpoint = localCheckpoint;
+            return true;
+        } else {
+            logger.trace("skipped updating local checkpoint of [{}] from [{}] to [{}], current checkpoint is higher", allocationId,
+                lcps.localCheckpoint, localCheckpoint);
+            return false;
+        }
     }
 
     /**
-     * Updates the known allocation IDs and the local checkpoints for the corresponding allocations from a primary relocation source.
+     * Notifies the service to update the local checkpoint for the shard with the provided allocation ID. If the checkpoint is lower than
+     * the currently known one, this is a no-op. If the allocation ID is not tracked, it is ignored.
      *
-     * @param primaryContext the primary context
+     * @param allocationId    the allocation ID of the shard to update the local checkpoint for
+     * @param localCheckpoint the local checkpoint for the shard
      */
-    synchronized void updateAllocationIdsFromPrimaryContext(final PrimaryContext primaryContext) {
-        if (sealed) {
-            throw new IllegalStateException("global checkpoint tracker is sealed");
+    public synchronized void updateLocalCheckpoint(final String allocationId, final long localCheckpoint) {
+        assert invariant();
+        assert primaryMode;
+        assert handoffInProgress == false;
+        LocalCheckpointState lcps = localCheckpoints.get(allocationId);
+        if (lcps == null) {
+            // can happen if replica was removed from cluster but replication process is unaware of it yet
+            return;
         }
-        /*
-         * We are gathered here today to witness the relocation handoff transferring knowledge from the relocation source to the relocation
-         * target. We need to consider the possibility that the version of the cluster state on the relocation source when the primary
-         * context was sampled is different than the version of the cluster state on the relocation target at this exact moment. We define
-         * the following values:
-         *  - version(source) = the cluster state version on the relocation source used to ensure a minimum cluster state version on the
-         *    relocation target
-         *  - version(context) = the cluster state version on the relocation source when the primary context was sampled
-         *  - version(target) = the current cluster state version on the relocation target
-         *
-         * We know that version(source) <= version(target) and version(context) < version(target), version(context) = version(target), and
-         * version(target) < version(context) are all possibilities.
-         *
-         * The case of version(context) = version(target) causes no issues as in this case the knowledge of the in-sync and initializing
-         * shards the target receives from the master will be equal to the knowledge of the in-sync and initializing shards the target
-         * receives from the relocation source via the primary context.
-         *
-         * Let us now consider the case that version(context) < version(target). In this case, the active allocation IDs in the primary
-         * context can be a superset of the active allocation IDs contained in the applied cluster state. This is because no new shards can
-         * have been started as marking a shard as in-sync is blocked during relocation handoff. Note however that the relocation target
-         * itself will have been marked in-sync during recovery and therefore is an active allocation ID from the perspective of the primary
-         * context.
-         *
-         * Finally, we consider the case that version(target) < version(context). In this case, the active allocation IDs in the primary
-         * context can be a subset of the active allocation IDs contained the applied cluster state. This is again because no new shards can
-         * have been started. Moreover, existing active allocation IDs could have been removed from the cluster state.
-         *
-         * In each of these latter two cases, consider initializing shards that are contained in the primary context but not contained in
-         * the cluster state applied on the target.
-         *
-         * If version(context) < version(target) it means that the shard has been removed by a later cluster state update that is already
-         * applied on the target and we only need to ensure that we do not add it to the tracking map on the target. The call to
-         * GlobalCheckpointTracker#updateLocalCheckpoint(String, long) is a no-op for such shards and this is safe.
-         *
-         * If version(target) < version(context) it means that the shard has started initializing by a later cluster state update has not
-         * yet arrived on the target. However, there is a delay on recoveries before we ensure that version(source) <= version(target).
-         * Therefore, such a shard can never initialize from the relocation source and will have to await the handoff completing. As such,
-         * these shards are not problematic.
-         *
-         * Lastly, again in these two cases, what about initializing shards that are contained in cluster state applied on the target but
-         * not contained in the cluster state applied on the target.
-         *
-         * If version(context) < version(target) it means that a shard has started initializing by a later cluster state that is applied on
-         * the target but not yet known to what would be the relocation source. As recoveries are delayed at this time, these shards can not
-         * cause a problem and we do not mutate remove these shards from the tracking map, so we are safe here.
-         *
-         * If version(target) < version(context) it means that a shard has started initializing but was removed by a later cluster state. In
-         * this case, as the cluster state version on the primary context exceeds the applied cluster state version, we replace the tracking
-         * map and are safe here too.
-         */
-
-        assert StreamSupport
-                .stream(inSyncLocalCheckpoints.spliterator(), false)
-                .allMatch(e -> e.value == SequenceNumbersService.UNASSIGNED_SEQ_NO) : inSyncLocalCheckpoints;
-        assert StreamSupport
-                .stream(trackingLocalCheckpoints.spliterator(), false)
-                .allMatch(e -> e.value == SequenceNumbersService.UNASSIGNED_SEQ_NO) : trackingLocalCheckpoints;
-        assert pendingInSync.isEmpty() : pendingInSync;
-
-        if (primaryContext.clusterStateVersion() > appliedClusterStateVersion) {
-            final Set<String> activeAllocationIds =
-                    new HashSet<>(Arrays.asList(primaryContext.inSyncLocalCheckpoints().keys().toArray(String.class)));
-            final Set<String> initializingAllocationIds =
-                    new HashSet<>(Arrays.asList(primaryContext.trackingLocalCheckpoints().keys().toArray(String.class)));
-            updateAllocationIdsFromMaster(primaryContext.clusterStateVersion(), activeAllocationIds, initializingAllocationIds);
+        boolean increasedLocalCheckpoint = updateLocalCheckpoint(allocationId, lcps, localCheckpoint);
+        boolean pending = pendingInSync.contains(allocationId);
+        if (pending && lcps.localCheckpoint >= globalCheckpoint) {
+            pendingInSync.remove(allocationId);
+            pending = false;
+            lcps.inSync = true;
+            logger.trace("marked [{}] as in-sync", allocationId);
+            notifyAllWaiters();
         }
-
-        /*
-         * As we are updating the local checkpoints for the in-sync allocation IDs, the global checkpoint will advance in place; this means
-         * that we have to sort the incoming local checkpoints from smallest to largest lest we violate that the global checkpoint does not
-         * regress.
-         */
-
-        class AllocationIdLocalCheckpointPair {
-
-            private final String allocationId;
-
-            public String allocationId() {
-                return allocationId;
-            }
-
-            private final long localCheckpoint;
-
-            public long localCheckpoint() {
-                return localCheckpoint;
-            }
-
-            private AllocationIdLocalCheckpointPair(final String allocationId, final long localCheckpoint) {
-                this.allocationId = allocationId;
-                this.localCheckpoint = localCheckpoint;
-            }
-
+        if (increasedLocalCheckpoint && pending == false) {
+            updateGlobalCheckpointOnPrimary();
         }
+        assert invariant();
+    }
 
-        final List<AllocationIdLocalCheckpointPair> inSync =
-                StreamSupport
-                        .stream(primaryContext.inSyncLocalCheckpoints().spliterator(), false)
-                        .map(e -> new AllocationIdLocalCheckpointPair(e.key, e.value))
-                        .collect(Collectors.toList());
-        inSync.sort(Comparator.comparingLong(AllocationIdLocalCheckpointPair::localCheckpoint));
-
-        for (final AllocationIdLocalCheckpointPair cursor : inSync) {
-            assert cursor.localCheckpoint() >= globalCheckpoint
-                    : "local checkpoint [" + cursor.localCheckpoint() + "] "
-                    + "for allocation ID [" + cursor.allocationId() + "] "
-                    + "violates being at least the global checkpoint [" + globalCheckpoint + "]";
-            updateLocalCheckpoint(cursor.allocationId(), cursor.localCheckpoint());
-            if (trackingLocalCheckpoints.containsKey(cursor.allocationId())) {
-                moveAllocationIdFromTrackingToInSync(cursor.allocationId(), "relocation");
-                updateGlobalCheckpointOnPrimary();
+    /**
+     * Computes the global checkpoint based on the given local checkpoints. In cases where there are entries preventing the
+     * computation from happening (for example due to blocking), it returns the fallback value.
+     */
+    private static long computeGlobalCheckpoint(final Set<String> pendingInSync, final Collection<LocalCheckpointState> localCheckpoints,
+                                                final long fallback) {
+        long minLocalCheckpoint = Long.MAX_VALUE;
+        if (pendingInSync.isEmpty() == false) {
+            return fallback;
+        }
+        for (final LocalCheckpointState lcps : localCheckpoints) {
+            if (lcps.inSync) {
+                if (lcps.localCheckpoint == SequenceNumbersService.UNASSIGNED_SEQ_NO) {
+                    // unassigned in-sync replica
+                    return fallback;
+                } else if (lcps.localCheckpoint == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT) {
+                    // 5.x replica, ignore for global checkpoint calculation
+                } else {
+                    minLocalCheckpoint = Math.min(lcps.localCheckpoint, minLocalCheckpoint);
+                }
             }
         }
+        assert minLocalCheckpoint != Long.MAX_VALUE;
+        return minLocalCheckpoint;
+    }
 
-        for (final ObjectLongCursor<String> cursor : primaryContext.trackingLocalCheckpoints()) {
-            updateLocalCheckpoint(cursor.key, cursor.value);
+    /**
+     * Scans through the currently known local checkpoint and updates the global checkpoint accordingly.
+     */
+    private synchronized void updateGlobalCheckpointOnPrimary() {
+        assert primaryMode;
+        final long computedGlobalCheckpoint = computeGlobalCheckpoint(pendingInSync, localCheckpoints.values(), globalCheckpoint);
+        assert computedGlobalCheckpoint >= globalCheckpoint : "new global checkpoint [" + computedGlobalCheckpoint +
+            "] is lower than previous one [" + globalCheckpoint + "]";
+        if (globalCheckpoint != computedGlobalCheckpoint) {
+            logger.trace("global checkpoint updated to [{}]", computedGlobalCheckpoint);
+            globalCheckpoint = computedGlobalCheckpoint;
         }
     }
 
     /**
-     * Marks the shard with the provided allocation ID as in-sync with the primary shard. This method will block until the local checkpoint
-     * on the specified shard advances above the current global checkpoint.
-     *
-     * @param allocationId    the allocation ID of the shard to mark as in-sync
-     * @param localCheckpoint the current local checkpoint on the shard
-     *
-     * @throws InterruptedException if the thread is interrupted waiting for the local checkpoint on the shard to advance
+     * Initiates a relocation handoff and returns the corresponding primary context.
      */
-    public synchronized void markAllocationIdAsInSync(final String allocationId, final long localCheckpoint) throws InterruptedException {
-        if (sealed) {
-            throw new IllegalStateException("global checkpoint tracker is sealed");
-        }
-        if (!trackingLocalCheckpoints.containsKey(allocationId)) {
-            /*
-             * This can happen if the recovery target has been failed and the cluster state update from the master has triggered removing
-             * this allocation ID from the tracking map but this recovery thread has not yet been made aware that the recovery is
-             * cancelled.
-             */
-            return;
+    public synchronized PrimaryContext startRelocationHandoff() {
+        assert invariant();
+        assert primaryMode;
+        assert handoffInProgress == false;
+        assert pendingInSync.isEmpty() : "relocation handoff started while there are still shard copies pending in-sync: " + pendingInSync;
+        handoffInProgress = true;
+        // copy clusterStateVersion and localCheckpoints and return
+        // all the entries from localCheckpoints that are inSync: the reason we don't need to care about initializing non-insync entries
+        // is that they will have to undergo a recovery attempt on the relocation target, and will hence be supplied by the cluster state
+        // update on the relocation target once relocation completes. We could alternatively also copy the map as-is (it's safe), and it
+        // would be cleaned up on the target by cluster state updates.
+        Map<String, LocalCheckpointState> localCheckpointsCopy = new HashMap<>();
+        for (Map.Entry<String, LocalCheckpointState> entry : localCheckpoints.entrySet()) {
+            localCheckpointsCopy.put(entry.getKey(), entry.getValue().copy());
         }
+        assert invariant();
+        return new PrimaryContext(appliedClusterStateVersion, localCheckpointsCopy);
+    }
 
-        updateLocalCheckpoint(allocationId, localCheckpoint, trackingLocalCheckpoints, "tracking");
-        if (!pendingInSync.add(allocationId)) {
-            throw new IllegalStateException("there is already a pending sync in progress for allocation ID [" + allocationId + "]");
-        }
-        try {
-            waitForAllocationIdToBeInSync(allocationId);
-        } finally {
-            pendingInSync.remove(allocationId);
-            updateGlobalCheckpointOnPrimary();
-        }
+    /**
+     * Fails a relocation handoff attempt.
+     */
+    public synchronized void abortRelocationHandoff() {
+        assert invariant();
+        assert primaryMode;
+        assert handoffInProgress;
+        handoffInProgress = false;
+        assert invariant();
+    }
+
+    /**
+     * Marks a relocation handoff attempt as successful. Moves the tracker into replica mode.
+     */
+    public synchronized void completeRelocationHandoff() {
+        assert invariant();
+        assert primaryMode;
+        assert handoffInProgress;
+        primaryMode = false;
+        handoffInProgress = false;
+        // forget all checkpoint information
+        localCheckpoints.values().stream().forEach(lcps -> {
+            if (lcps.localCheckpoint != SequenceNumbersService.UNASSIGNED_SEQ_NO &&
+                lcps.localCheckpoint != SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT) {
+                lcps.localCheckpoint = SequenceNumbersService.UNASSIGNED_SEQ_NO;
+            }
+        });
+        assert invariant();
     }
 
     /**
-     * Wait for knowledge of the local checkpoint for the specified allocation ID to advance to the global checkpoint. Global checkpoint
-     * advancement is blocked while there are any allocation IDs waiting to catch up to the global checkpoint.
+     * Activates the global checkpoint tracker in primary mode (see {@link #primaryMode}). Called on the primary relocation target during
+     * primary relocation handoff.
      *
-     * @param allocationId the allocation ID
-     * @throws InterruptedException if this thread was interrupted before of during waiting
+     * @param primaryContext the primary context used to initialize the state
      */
-    private synchronized void waitForAllocationIdToBeInSync(final String allocationId) throws InterruptedException {
-        while (true) {
-            /*
-             * If the allocation has been cancelled and so removed from the tracking map from a cluster state update from the master it
-             * means that this recovery will be cancelled; we are here on a cancellable recovery thread and so this thread will throw an
-             * interrupted exception as soon as it tries to wait on the monitor.
-             */
-            final long current = trackingLocalCheckpoints.getOrDefault(allocationId, Long.MIN_VALUE);
-            if (current >= globalCheckpoint) {
-                /*
-                 * This is prematurely adding the allocation ID to the in-sync map as at this point recovery is not yet finished and could
-                 * still abort. At this point we will end up with a shard in the in-sync map holding back the global checkpoint because the
-                 * shard never recovered and we would have to wait until either the recovery retries and completes successfully, or the
-                 * master fails the shard and issues a cluster state update that removes the shard from the set of active allocation IDs.
-                 */
-                moveAllocationIdFromTrackingToInSync(allocationId, "recovery");
-                break;
+    public synchronized void activateWithPrimaryContext(PrimaryContext primaryContext) {
+        assert invariant();
+        assert primaryMode == false;
+        final Runnable runAfter = getMasterUpdateOperationFromCurrentState();
+        primaryMode = true;
+        // capture current state to possibly replay missed cluster state update
+        appliedClusterStateVersion = primaryContext.clusterStateVersion();
+        localCheckpoints.clear();
+        for (Map.Entry<String, LocalCheckpointState> entry : primaryContext.localCheckpoints.entrySet()) {
+            localCheckpoints.put(entry.getKey(), entry.getValue().copy());
+        }
+        updateGlobalCheckpointOnPrimary();
+        // reapply missed cluster state update
+        // note that if there was no cluster state update between start of the engine of this shard and the call to
+        // initializeWithPrimaryContext, we might still have missed a cluster state update. This is best effort.
+        runAfter.run();
+        assert invariant();
+    }
+
+    private Runnable getMasterUpdateOperationFromCurrentState() {
+        assert primaryMode == false;
+        final long lastAppliedClusterStateVersion = appliedClusterStateVersion;
+        final Set<String> inSyncAllocationIds = new HashSet<>();
+        final Set<String> initializingAllocationIds = new HashSet<>();
+        final Set<String> pre60AllocationIds = new HashSet<>();
+        localCheckpoints.entrySet().forEach(entry -> {
+            if (entry.getValue().inSync) {
+                inSyncAllocationIds.add(entry.getKey());
             } else {
-                waitForLocalCheckpointToAdvance();
+                initializingAllocationIds.add(entry.getKey());
             }
-        }
+            if (entry.getValue().getLocalCheckpoint() == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT) {
+                pre60AllocationIds.add(entry.getKey());
+            }
+        });
+        return () -> updateFromMaster(lastAppliedClusterStateVersion, inSyncAllocationIds, initializingAllocationIds, pre60AllocationIds);
     }
 
     /**
-     * Moves a tracking allocation ID to be in-sync. This can occur when a shard is recovering from the primary and its local checkpoint has
-     * advanced past the global checkpoint, or during relocation hand-off when the relocation target learns of an in-sync shard from the
-     * relocation source.
-     *
-     * @param allocationId the allocation ID to move
-     * @param reason       the reason for the transition
+     * Whether there are shards blocking global checkpoint advancement. Used by tests.
      */
-    private synchronized void moveAllocationIdFromTrackingToInSync(final String allocationId, final String reason) {
-        assert trackingLocalCheckpoints.containsKey(allocationId);
-        final long current = trackingLocalCheckpoints.remove(allocationId);
-        inSyncLocalCheckpoints.put(allocationId, current);
-        logger.trace("marked [{}] as in-sync with local checkpoint [{}] due to [{}]", allocationId, current, reason);
+    public synchronized boolean pendingInSync() {
+        assert primaryMode;
+        return pendingInSync.isEmpty() == false;
     }
 
     /**
-     * Wait for the local checkpoint to advance to the global checkpoint.
-     *
-     * @throws InterruptedException if this thread was interrupted before of during waiting
+     * Returns the local checkpoint information tracked for a specific shard. Used by tests.
      */
-    @SuppressForbidden(reason = "Object#wait for local checkpoint advancement")
-    private synchronized void waitForLocalCheckpointToAdvance() throws InterruptedException {
-        this.wait();
+    public synchronized LocalCheckpointState getTrackedLocalCheckpointForShard(String allocationId) {
+        assert primaryMode;
+        return localCheckpoints.get(allocationId);
     }
 
     /**
-     * Check if there are any recoveries pending in-sync.
-     *
-     * @return true if there is at least one shard pending in-sync, otherwise false
+     * Notify all threads waiting on the monitor on this tracker. These threads should be waiting for the local checkpoint on a specific
+     * allocation ID to catch up to the global checkpoint.
      */
-    boolean pendingInSync() {
-        return !pendingInSync.isEmpty();
+    @SuppressForbidden(reason = "Object#notifyAll waiters for local checkpoint advancement")
+    private synchronized void notifyAllWaiters() {
+        this.notifyAll();
     }
 
     /**
-     * Check if the tracker is sealed.
+     * Wait for the local checkpoint to advance to the global checkpoint.
      *
-     * @return true if the tracker is sealed, otherwise false.
+     * @throws InterruptedException if this thread was interrupted before or during waiting
      */
-    boolean sealed() {
-        return sealed;
+    @SuppressForbidden(reason = "Object#wait for local checkpoint advancement")
+    private synchronized void waitForLocalCheckpointToAdvance() throws InterruptedException {
+        this.wait();
     }
 
     /**
-     * Returns the local checkpoint for the shard with the specified allocation ID, or {@link SequenceNumbersService#UNASSIGNED_SEQ_NO} if
-     * the shard is not in-sync.
-     *
-     * @param allocationId the allocation ID of the shard to obtain the local checkpoint for
-     * @return the local checkpoint, or {@link SequenceNumbersService#UNASSIGNED_SEQ_NO}
+     * Represents the sequence number component of the primary context. This captures what the primary knows about the in-sync and
+     * initializing shards and their local checkpoints.
      */
-    synchronized long getLocalCheckpointForAllocationId(final String allocationId) {
-        if (inSyncLocalCheckpoints.containsKey(allocationId)) {
-            return inSyncLocalCheckpoints.get(allocationId);
+    public static class PrimaryContext implements Writeable {
+
+        private final long clusterStateVersion;
+        private final Map<String, LocalCheckpointState> localCheckpoints;
+
+        public PrimaryContext(long clusterStateVersion, Map<String, LocalCheckpointState> localCheckpoints) {
+            this.clusterStateVersion = clusterStateVersion;
+            this.localCheckpoints = localCheckpoints;
+        }
+
+        public PrimaryContext(StreamInput in) throws IOException {
+            clusterStateVersion = in.readVLong();
+            localCheckpoints = in.readMap(StreamInput::readString, LocalCheckpointState::new);
         }
-        return SequenceNumbersService.UNASSIGNED_SEQ_NO;
-    }
 
+        public long clusterStateVersion() {
+            return clusterStateVersion;
+        }
+
+        public Map<String, LocalCheckpointState> getLocalCheckpoints() {
+            return localCheckpoints;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeVLong(clusterStateVersion);
+            out.writeMap(localCheckpoints, (streamOutput, s) -> out.writeString(s), (streamOutput, lcps) -> lcps.writeTo(out));
+        }
+
+        @Override
+        public String toString() {
+            return "PrimaryContext{" +
+                    "clusterStateVersion=" + clusterStateVersion +
+                    ", localCheckpoints=" + localCheckpoints +
+                    '}';
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            PrimaryContext that = (PrimaryContext) o;
+
+            if (clusterStateVersion != that.clusterStateVersion) return false;
+            return localCheckpoints.equals(that.localCheckpoints);
+        }
+
+        @Override
+        public int hashCode() {
+            int result = (int) (clusterStateVersion ^ (clusterStateVersion >>> 32));
+            result = 31 * result + localCheckpoints.hashCode();
+            return result;
+        }
+    }
 }
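
The new PrimaryContext above is a plain snapshot of the tracker's state: a cluster state version plus a map from allocation ID to local-checkpoint state, streamed verbatim from the relocation source to the target. The self-contained sketch below illustrates that wire shape with plain JDK streams; PrimaryContextSketch and its fields are invented placeholders for illustration only, not Elasticsearch's StreamInput/StreamOutput-based implementation.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class PrimaryContextSketch {
    final long clusterStateVersion;
    final Map<String, Long> localCheckpoints; // allocation ID -> local checkpoint

    PrimaryContextSketch(long clusterStateVersion, Map<String, Long> localCheckpoints) {
        this.clusterStateVersion = clusterStateVersion;
        this.localCheckpoints = localCheckpoints;
    }

    // write the cluster state version followed by the checkpoint map
    void writeTo(DataOutputStream out) throws IOException {
        out.writeLong(clusterStateVersion);
        out.writeInt(localCheckpoints.size());
        for (Map.Entry<String, Long> entry : localCheckpoints.entrySet()) {
            out.writeUTF(entry.getKey());
            out.writeLong(entry.getValue());
        }
    }

    // read back exactly what writeTo produced
    static PrimaryContextSketch readFrom(DataInputStream in) throws IOException {
        final long version = in.readLong();
        final int size = in.readInt();
        final Map<String, Long> checkpoints = new HashMap<>();
        for (int i = 0; i < size; i++) {
            checkpoints.put(in.readUTF(), in.readLong());
        }
        return new PrimaryContextSketch(version, checkpoints);
    }

    public static void main(String[] args) throws IOException {
        final Map<String, Long> checkpoints = new HashMap<>();
        checkpoints.put("alloc-1", 17L);
        checkpoints.put("alloc-2", 12L);
        final PrimaryContextSketch source = new PrimaryContextSketch(42L, checkpoints);

        // round-trip: what the relocation source sends and the relocation target reads back
        final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        source.writeTo(new DataOutputStream(bytes));
        final PrimaryContextSketch target = PrimaryContextSketch.readFrom(
            new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(target.clusterStateVersion + " " + target.localCheckpoints);
    }
}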

+ 54 - 14
core/src/main/java/org/elasticsearch/index/seqno/SequenceNumbersService.java

@@ -21,7 +21,6 @@ package org.elasticsearch.index.seqno;
 
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.shard.AbstractIndexShardComponent;
-import org.elasticsearch.index.shard.PrimaryContext;
 import org.elasticsearch.index.shard.ShardId;
 
 import java.util.Set;
@@ -41,6 +40,11 @@ public class SequenceNumbersService extends AbstractIndexShardComponent {
      */
     public static final long NO_OPS_PERFORMED = -1L;
 
+    /**
+     * Represents a local checkpoint coming from a pre-6.0 node
+     */
+    public static final long PRE_60_NODE_LOCAL_CHECKPOINT = -3L;
+
     private final LocalCheckpointTracker localCheckpointTracker;
     private final GlobalCheckpointTracker globalCheckpointTracker;
 
@@ -135,6 +139,16 @@ public class SequenceNumbersService extends AbstractIndexShardComponent {
         globalCheckpointTracker.updateLocalCheckpoint(allocationId, checkpoint);
     }
 
+    /**
+     * Called when the recovery process for a shard is ready to open the engine on the target shard.
+     * See {@link GlobalCheckpointTracker#initiateTracking(String)} for details.
+     *
+     * @param allocationId  the allocation ID of the shard for which recovery was initiated
+     */
+    public void initiateTracking(final String allocationId) {
+        globalCheckpointTracker.initiateTracking(allocationId);
+    }
+
     /**
      * Marks the shard with the provided allocation ID as in-sync with the primary shard. See
      * {@link GlobalCheckpointTracker#markAllocationIdAsInSync(String, long)} for additional details.
@@ -173,26 +187,45 @@ public class SequenceNumbersService extends AbstractIndexShardComponent {
         globalCheckpointTracker.updateGlobalCheckpointOnReplica(globalCheckpoint);
     }
 
+    /**
+     * Returns the local checkpoint tracked for the shard with the specified allocation ID. Used by tests.
+     */
+    public synchronized long getTrackedLocalCheckpointForShard(final String allocationId) {
+        return globalCheckpointTracker.getTrackedLocalCheckpointForShard(allocationId).getLocalCheckpoint();
+    }
+
+    /**
+     * Activates the global checkpoint tracker in primary mode (see {@link GlobalCheckpointTracker#primaryMode}).
+     * Called on primary activation or promotion.
+     */
+    public void activatePrimaryMode(final String allocationId, final long localCheckpoint) {
+        globalCheckpointTracker.activatePrimaryMode(allocationId, localCheckpoint);
+    }
+
     /**
      * Notifies the service of the current allocation IDs in the cluster state. See
-     * {@link GlobalCheckpointTracker#updateAllocationIdsFromMaster(long, Set, Set)} for details.
+     * {@link GlobalCheckpointTracker#updateFromMaster(long, Set, Set, Set)} for details.
      *
      * @param applyingClusterStateVersion the cluster state version being applied when updating the allocation IDs from the master
-     * @param activeAllocationIds         the allocation IDs of the currently active shard copies
+     * @param inSyncAllocationIds         the allocation IDs of the currently in-sync shard copies
      * @param initializingAllocationIds   the allocation IDs of the currently initializing shard copies
+     * @param pre60AllocationIds          the allocation IDs of shards that are allocated to pre-6.0 nodes
      */
     public void updateAllocationIdsFromMaster(
-            final long applyingClusterStateVersion, final Set<String> activeAllocationIds, final Set<String> initializingAllocationIds) {
-        globalCheckpointTracker.updateAllocationIdsFromMaster(applyingClusterStateVersion, activeAllocationIds, initializingAllocationIds);
+            final long applyingClusterStateVersion, final Set<String> inSyncAllocationIds, final Set<String> initializingAllocationIds,
+            final Set<String> pre60AllocationIds) {
+        globalCheckpointTracker.updateFromMaster(applyingClusterStateVersion, inSyncAllocationIds, initializingAllocationIds,
+            pre60AllocationIds);
     }
 
     /**
-     * Updates the known allocation IDs and the local checkpoints for the corresponding allocations from a primary relocation source.
+     * Activates the global checkpoint tracker in primary mode (see {@link GlobalCheckpointTracker#primaryMode}).
+     * Called on primary relocation target during primary relocation handoff.
      *
-     * @param primaryContext the sequence number context
+     * @param primaryContext the primary context used to initialize the state
      */
-    public void updateAllocationIdsFromPrimaryContext(final PrimaryContext primaryContext) {
-        globalCheckpointTracker.updateAllocationIdsFromPrimaryContext(primaryContext);
+    public void activateWithPrimaryContext(final GlobalCheckpointTracker.PrimaryContext primaryContext) {
+        globalCheckpointTracker.activateWithPrimaryContext(primaryContext);
     }
 
     /**
@@ -209,15 +242,22 @@ public class SequenceNumbersService extends AbstractIndexShardComponent {
      *
      * @return the primary context
      */
-    public PrimaryContext primaryContext() {
-        return globalCheckpointTracker.primaryContext();
+    public GlobalCheckpointTracker.PrimaryContext startRelocationHandoff() {
+        return globalCheckpointTracker.startRelocationHandoff();
+    }
+
+    /**
+     * Marks a relocation handoff attempt as successful. Moves the tracker into replica mode.
+     */
+    public void completeRelocationHandoff() {
+        globalCheckpointTracker.completeRelocationHandoff();
     }
 
     /**
-     * Releases a previously acquired primary context.
+     * Fails a relocation handoff attempt.
      */
-    public void releasePrimaryContext() {
-        globalCheckpointTracker.releasePrimaryContext();
+    public void abortRelocationHandoff() {
+        globalCheckpointTracker.abortRelocationHandoff();
     }
 
 }
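
The service now exposes the relocation handoff as a start/complete/abort protocol: startRelocationHandoff captures the primary context, completeRelocationHandoff moves the tracker back to replica mode after a successful transfer, and abortRelocationHandoff resumes normal primary operation on failure. The toy class below shows only that control flow; ToyTracker and ToyPrimaryContext are simplified stand-ins under assumed semantics, not the real GlobalCheckpointTracker.

public class HandoffSketch {

    static class ToyPrimaryContext {} // placeholder for the captured tracking state

    static class ToyTracker {
        private boolean primaryMode = true;
        private boolean handoffInProgress = false;

        synchronized ToyPrimaryContext startRelocationHandoff() {
            if (primaryMode == false || handoffInProgress) {
                throw new IllegalStateException("handoff requires primary mode and no handoff in progress");
            }
            handoffInProgress = true;
            return new ToyPrimaryContext(); // snapshot handed to the relocation target
        }

        synchronized void completeRelocationHandoff() {
            primaryMode = false;       // the relocation target is now the authoritative primary
            handoffInProgress = false;
        }

        synchronized void abortRelocationHandoff() {
            handoffInProgress = false; // stay in primary mode and resume normal operation
        }
    }

    public static void main(String[] args) {
        ToyTracker tracker = new ToyTracker();
        ToyPrimaryContext context = tracker.startRelocationHandoff();
        try {
            // transfer `context` over the wire here
            tracker.completeRelocationHandoff();
        } catch (RuntimeException e) {
            tracker.abortRelocationHandoff();
            throw e;
        }
    }
}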

+ 147 - 127
core/src/main/java/org/elasticsearch/index/shard/IndexShard.java

@@ -43,6 +43,7 @@ import org.elasticsearch.action.admin.indices.forcemerge.ForceMergeRequest;
 import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeRequest;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.MappingMetaData;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RecoverySource.SnapshotRecoverySource;
 import org.elasticsearch.cluster.routing.ShardRouting;
@@ -102,6 +103,7 @@ import org.elasticsearch.index.recovery.RecoveryStats;
 import org.elasticsearch.index.refresh.RefreshStats;
 import org.elasticsearch.index.search.stats.SearchStats;
 import org.elasticsearch.index.search.stats.ShardSearchStats;
+import org.elasticsearch.index.seqno.GlobalCheckpointTracker;
 import org.elasticsearch.index.seqno.SeqNoStats;
 import org.elasticsearch.index.seqno.SequenceNumbersService;
 import org.elasticsearch.index.shard.PrimaryReplicaSyncer.ResyncTask;
@@ -362,140 +364,141 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
                                  final long newPrimaryTerm,
                                  final CheckedBiConsumer<IndexShard, ActionListener<ResyncTask>, IOException> primaryReplicaSyncer,
                                  final long applyingClusterStateVersion,
-                                 final Set<String> activeAllocationIds,
-                                 final Set<String> initializingAllocationIds) throws IOException {
+                                 final Set<String> inSyncAllocationIds,
+                                 final Set<String> initializingAllocationIds,
+                                 final Set<String> pre60AllocationIds) throws IOException {
         final ShardRouting currentRouting;
         synchronized (mutex) {
             currentRouting = this.shardRouting;
-            updateRoutingEntry(newRouting);
 
-            if (shardRouting.primary()) {
-                updatePrimaryTerm(newPrimaryTerm, primaryReplicaSyncer);
+            if (!newRouting.shardId().equals(shardId())) {
+                throw new IllegalArgumentException("Trying to set a routing entry with shardId " + newRouting.shardId() + " on a shard with shardId " + shardId());
+            }
+            if ((currentRouting == null || newRouting.isSameAllocation(currentRouting)) == false) {
+                throw new IllegalArgumentException("Trying to set a routing entry with a different allocation. Current " + currentRouting + ", new " + newRouting);
+            }
+            if (currentRouting != null && currentRouting.primary() && newRouting.primary() == false) {
+                throw new IllegalArgumentException("illegal state: trying to move shard from primary mode to replica mode. Current "
+                    + currentRouting + ", new " + newRouting);
+            }
 
+            if (newRouting.primary()) {
                 final Engine engine = getEngineOrNull();
-                // if the engine is not yet started, we are not ready yet and can just ignore this
                 if (engine != null) {
-                    engine.seqNoService().updateAllocationIdsFromMaster(
-                            applyingClusterStateVersion, activeAllocationIds, initializingAllocationIds);
+                    engine.seqNoService().updateAllocationIdsFromMaster(applyingClusterStateVersion, inSyncAllocationIds, initializingAllocationIds, pre60AllocationIds);
                 }
             }
-        }
-        if (currentRouting != null && currentRouting.active() == false && newRouting.active()) {
-            indexEventListener.afterIndexShardStarted(this);
-        }
-        if (newRouting.equals(currentRouting) == false) {
-            indexEventListener.shardRoutingChanged(this, currentRouting, newRouting);
-        }
-    }
 
-    private void updateRoutingEntry(ShardRouting newRouting) throws IOException {
-        assert Thread.holdsLock(mutex);
-        final ShardRouting currentRouting = this.shardRouting;
+            if (state == IndexShardState.POST_RECOVERY && newRouting.active()) {
+                assert currentRouting.active() == false : "we are in POST_RECOVERY, but our shard routing is active " + currentRouting;
+                // we want to refresh *before* we move to internal STARTED state
+                try {
+                    getEngine().refresh("cluster_state_started");
+                } catch (Exception e) {
+                    logger.debug("failed to refresh due to move to cluster wide started", e);
+                }
 
-        if (!newRouting.shardId().equals(shardId())) {
-            throw new IllegalArgumentException("Trying to set a routing entry with shardId " + newRouting.shardId() + " on a shard with shardId " + shardId());
-        }
-        if ((currentRouting == null || newRouting.isSameAllocation(currentRouting)) == false) {
-            throw new IllegalArgumentException("Trying to set a routing entry with a different allocation. Current " + currentRouting + ", new " + newRouting);
-        }
-        if (currentRouting != null && currentRouting.primary() && newRouting.primary() == false) {
-            throw new IllegalArgumentException("illegal state: trying to move shard from primary mode to replica mode. Current "
-                + currentRouting + ", new " + newRouting);
-        }
+                if (newRouting.primary()) {
+                    final DiscoveryNode recoverySourceNode = recoveryState.getSourceNode();
+                    if (currentRouting.isRelocationTarget() == false || recoverySourceNode.getVersion().before(Version.V_6_0_0_alpha1)) {
+                        // there was no primary context hand-off in < 6.0.0, need to manually activate the shard
+                        getEngine().seqNoService().activatePrimaryMode(currentRouting.allocationId().getId(), getEngine().seqNoService().getLocalCheckpoint());
+                    }
+                }
 
-        if (state == IndexShardState.POST_RECOVERY && newRouting.active()) {
-            assert currentRouting.active() == false : "we are in POST_RECOVERY, but our shard routing is active " + currentRouting;
-            // we want to refresh *before* we move to internal STARTED state
-            try {
-                getEngine().refresh("cluster_state_started");
-            } catch (Exception e) {
-                logger.debug("failed to refresh due to move to cluster wide started", e);
+                changeState(IndexShardState.STARTED, "global state is [" + newRouting.state() + "]");
+            } else if (state == IndexShardState.RELOCATED &&
+                (newRouting.relocating() == false || newRouting.equalsIgnoringMetaData(currentRouting) == false)) {
+                // if the shard is marked as RELOCATED we have to fail when any changes in shard routing occur (e.g. due to recovery
+                // failure / cancellation). The reason is that at the moment we cannot safely move back to STARTED without risking two
+                // active primaries.
+                throw new IndexShardRelocatedException(shardId(), "Shard is marked as relocated, cannot safely move to state " + newRouting.state());
             }
-            changeState(IndexShardState.STARTED, "global state is [" + newRouting.state() + "]");
-        } else if (state == IndexShardState.RELOCATED &&
-            (newRouting.relocating() == false || newRouting.equalsIgnoringMetaData(currentRouting) == false)) {
-            // if the shard is marked as RELOCATED we have to fail when any changes in shard routing occur (e.g. due to recovery
-            // failure / cancellation). The reason is that at the moment we cannot safely move back to STARTED without risking two
-            // active primaries.
-            throw new IndexShardRelocatedException(shardId(), "Shard is marked as relocated, cannot safely move to state " + newRouting.state());
-        }
-        assert newRouting.active() == false || state == IndexShardState.STARTED || state == IndexShardState.RELOCATED ||
-            state == IndexShardState.CLOSED :
-            "routing is active, but local shard state isn't. routing: " + newRouting + ", local state: " + state;
-        this.shardRouting = newRouting;
-        persistMetadata(path, indexSettings, newRouting, currentRouting, logger);
-    }
-
-    private void updatePrimaryTerm(
-            final long newPrimaryTerm, final CheckedBiConsumer<IndexShard, ActionListener<ResyncTask>, IOException> primaryReplicaSyncer) {
-        assert Thread.holdsLock(mutex);
-        assert shardRouting.primary() : "primary term can only be explicitly updated on a primary shard";
-        if (newPrimaryTerm != primaryTerm) {
-            /* Note that due to cluster state batching an initializing primary shard term can failed and re-assigned
-             * in one state causing it's term to be incremented. Note that if both current shard state and new
-             * shard state are initializing, we could replace the current shard and reinitialize it. It is however
-             * possible that this shard is being started. This can happen if:
-             * 1) Shard is post recovery and sends shard started to the master
-             * 2) Node gets disconnected and rejoins
-             * 3) Master assigns the shard back to the node
-             * 4) Master processes the shard started and starts the shard
-             * 5) The node process the cluster state where the shard is both started and primary term is incremented.
-             *
-             * We could fail the shard in that case, but this will cause it to be removed from the insync allocations list
-             * potentially preventing re-allocation.
-             */
-            assert shardRouting.initializing() == false :
-                "a started primary shard should never update its term; "
-                    + "shard " + shardRouting + ", "
-                    + "current term [" + primaryTerm + "], "
-                    + "new term [" + newPrimaryTerm + "]";
-            assert newPrimaryTerm > primaryTerm :
-                "primary terms can only go up; current term [" + primaryTerm + "], new term [" + newPrimaryTerm + "]";
-            /*
-             * Before this call returns, we are guaranteed that all future operations are delayed and so this happens before we
-             * increment the primary term. The latch is needed to ensure that we do not unblock operations before the primary term is
-             * incremented.
-             */
-            final CountDownLatch latch = new CountDownLatch(1);
-            // to prevent primary relocation handoff while resync is not completed
-            boolean resyncStarted = primaryReplicaResyncInProgress.compareAndSet(false, true);
-            if (resyncStarted == false) {
-                throw new IllegalStateException("cannot start resync while it's already in progress");
-            }
-            indexShardOperationPermits.asyncBlockOperations(
-                30,
-                TimeUnit.MINUTES,
-                () -> {
-                    latch.await();
-                    try {
-                        getEngine().fillSeqNoGaps(newPrimaryTerm);
-                        primaryReplicaSyncer.accept(IndexShard.this, new ActionListener<ResyncTask>() {
-                            @Override
-                            public void onResponse(ResyncTask resyncTask) {
-                                logger.info("primary-replica resync completed with {} operations",
-                                    resyncTask.getResyncedOperations());
-                                boolean resyncCompleted = primaryReplicaResyncInProgress.compareAndSet(true, false);
-                                assert resyncCompleted : "primary-replica resync finished but was not started";
-                            }
+            assert newRouting.active() == false || state == IndexShardState.STARTED || state == IndexShardState.RELOCATED ||
+                state == IndexShardState.CLOSED :
+                "routing is active, but local shard state isn't. routing: " + newRouting + ", local state: " + state;
+            this.shardRouting = newRouting;
+            persistMetadata(path, indexSettings, newRouting, currentRouting, logger);
 
-                            @Override
-                            public void onFailure(Exception e) {
-                                boolean resyncCompleted = primaryReplicaResyncInProgress.compareAndSet(true, false);
-                                assert resyncCompleted : "primary-replica resync finished but was not started";
-                                if (state == IndexShardState.CLOSED) {
-                                    // ignore, shutting down
-                                } else {
-                                    failShard("exception during primary-replica resync", e);
-                                }
-                            }
-                        });
-                    } catch (final AlreadyClosedException e) {
-                        // okay, the index was deleted
+            if (shardRouting.primary()) {
+                if (newPrimaryTerm != primaryTerm) {
+                    assert currentRouting.primary() == false : "term is only increased as part of primary promotion";
+                    /* Note that due to cluster state batching an initializing primary shard can be failed and re-assigned
+                     * in one state, causing its term to be incremented. Note that if both the current shard state and the new
+                     * shard state are initializing, we could replace the current shard and reinitialize it. It is however
+                     * possible that this shard is being started. This can happen if:
+                     * 1) Shard is post recovery and sends shard started to the master
+                     * 2) Node gets disconnected and rejoins
+                     * 3) Master assigns the shard back to the node
+                     * 4) Master processes the shard started and starts the shard
+                     * 5) The node processes the cluster state where the shard is both started and the primary term is incremented.
+                     *
+                     * We could fail the shard in that case, but this will cause it to be removed from the in-sync allocations list,
+                     * potentially preventing re-allocation.
+                     */
+                    assert shardRouting.initializing() == false :
+                        "a started primary shard should never update its term; "
+                            + "shard " + shardRouting + ", "
+                            + "current term [" + primaryTerm + "], "
+                            + "new term [" + newPrimaryTerm + "]";
+                    assert newPrimaryTerm > primaryTerm :
+                        "primary terms can only go up; current term [" + primaryTerm + "], new term [" + newPrimaryTerm + "]";
+                    /*
+                     * Before this call returns, we are guaranteed that all future operations are delayed and so this happens before we
+                     * increment the primary term. The latch is needed to ensure that we do not unblock operations before the primary term is
+                     * incremented.
+                     */
+                    final CountDownLatch latch = new CountDownLatch(1);
+                    // to prevent primary relocation handoff while resync is not completed
+                    boolean resyncStarted = primaryReplicaResyncInProgress.compareAndSet(false, true);
+                    if (resyncStarted == false) {
+                        throw new IllegalStateException("cannot start resync while it's already in progress");
                     }
-                },
-                e -> failShard("exception during primary term transition", e));
-            primaryTerm = newPrimaryTerm;
-            latch.countDown();
+                    indexShardOperationPermits.asyncBlockOperations(
+                        30,
+                        TimeUnit.MINUTES,
+                        () -> {
+                            latch.await();
+                            try {
+                                getEngine().fillSeqNoGaps(newPrimaryTerm);
+                                updateLocalCheckpointForShard(currentRouting.allocationId().getId(),
+                                    getEngine().seqNoService().getLocalCheckpoint());
+                                primaryReplicaSyncer.accept(this, new ActionListener<ResyncTask>() {
+                                    @Override
+                                    public void onResponse(ResyncTask resyncTask) {
+                                        logger.info("primary-replica resync completed with {} operations",
+                                            resyncTask.getResyncedOperations());
+                                        boolean resyncCompleted = primaryReplicaResyncInProgress.compareAndSet(true, false);
+                                        assert resyncCompleted : "primary-replica resync finished but was not started";
+                                    }
+
+                                    @Override
+                                    public void onFailure(Exception e) {
+                                        boolean resyncCompleted = primaryReplicaResyncInProgress.compareAndSet(true, false);
+                                        assert resyncCompleted : "primary-replica resync finished but was not started";
+                                        if (state == IndexShardState.CLOSED) {
+                                            // ignore, shutting down
+                                        } else {
+                                            failShard("exception during primary-replica resync", e);
+                                        }
+                                    }
+                                });
+                            } catch (final AlreadyClosedException e) {
+                                // okay, the index was deleted
+                            }
+                        },
+                        e -> failShard("exception during primary term transition", e));
+                    getEngine().seqNoService().activatePrimaryMode(currentRouting.allocationId().getId(), getEngine().seqNoService().getLocalCheckpoint());
+                    primaryTerm = newPrimaryTerm;
+                    latch.countDown();
+                }
+            }
+        }
+        if (currentRouting != null && currentRouting.active() == false && newRouting.active()) {
+            indexEventListener.afterIndexShardStarted(this);
+        }
+        if (newRouting.equals(currentRouting) == false) {
+            indexEventListener.shardRoutingChanged(this, currentRouting, newRouting);
         }
     }
 
@@ -537,7 +540,7 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
      * @throws InterruptedException            if blocking operations is interrupted
      */
     public void relocated(
-            final String reason, final Consumer<PrimaryContext> consumer) throws IllegalIndexShardStateException, InterruptedException {
+            final String reason, final Consumer<GlobalCheckpointTracker.PrimaryContext> consumer) throws IllegalIndexShardStateException, InterruptedException {
         assert shardRouting.primary() : "only primaries can be marked as relocated: " + shardRouting;
         try {
             indexShardOperationPermits.blockOperations(30, TimeUnit.MINUTES, () -> {
@@ -549,16 +552,17 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
                  * network operation. Doing this under the mutex can implicitly block the cluster state update thread on network operations.
                  */
                 verifyRelocatingState();
-                final PrimaryContext primaryContext = getEngine().seqNoService().primaryContext();
+                final GlobalCheckpointTracker.PrimaryContext primaryContext = getEngine().seqNoService().startRelocationHandoff();
                 try {
                     consumer.accept(primaryContext);
                     synchronized (mutex) {
                         verifyRelocatingState();
                         changeState(IndexShardState.RELOCATED, reason);
                     }
+                    getEngine().seqNoService().completeRelocationHandoff();
                 } catch (final Exception e) {
                     try {
-                        getEngine().seqNoService().releasePrimaryContext();
+                        getEngine().seqNoService().abortRelocationHandoff();
                     } catch (final Exception inner) {
                         e.addSuppressed(inner);
                     }
@@ -1644,6 +1648,22 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
         getEngine().seqNoService().waitForOpsToComplete(seqNo);
     }
 
+    /**
+     * Called when the recovery process for a shard is ready to open the engine on the target shard.
+     * See {@link GlobalCheckpointTracker#initiateTracking(String)} for details.
+     *
+     * @param allocationId  the allocation ID of the shard for which recovery was initiated
+     */
+    public void initiateTracking(final String allocationId) {
+        verifyPrimary();
+        getEngine().seqNoService().initiateTracking(allocationId);
+        /*
+         * We could have blocked so long waiting for the replica to catch up that we fell idle and there will not be a background sync to
+     * the replica; mark ourselves as active to force a future background sync.
+         */
+        active.compareAndSet(false, true);
+    }
+
     /**
      * Marks the shard with the provided allocation ID as in-sync with the primary shard. See
      * {@link org.elasticsearch.index.seqno.GlobalCheckpointTracker#markAllocationIdAsInSync(String, long)}
@@ -1710,13 +1730,13 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
      *
      * @param primaryContext the sequence number context
      */
-    public void updateAllocationIdsFromPrimaryContext(final PrimaryContext primaryContext) {
+    public void activateWithPrimaryContext(final GlobalCheckpointTracker.PrimaryContext primaryContext) {
         verifyPrimary();
         assert shardRouting.isRelocationTarget() : "only relocation target can update allocation IDs from primary context: " + shardRouting;
-        final Engine engine = getEngineOrNull();
-        if (engine != null) {
-            engine.seqNoService().updateAllocationIdsFromPrimaryContext(primaryContext);
-        }
+        assert primaryContext.getLocalCheckpoints().containsKey(routingEntry().allocationId().getId()) &&
+            getEngine().seqNoService().getLocalCheckpoint() ==
+                primaryContext.getLocalCheckpoints().get(routingEntry().allocationId().getId()).getLocalCheckpoint();
+        getEngine().seqNoService().activateWithPrimaryContext(primaryContext);
     }
 
     /**

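The promotion path above keeps the original ordering trick: operations are blocked asynchronously, the post-promotion work first waits on a latch, the primary term is then incremented, and only afterwards is the latch released so the seq-no gap filling and resync run under the new term. The sketch below reproduces just that ordering with a plain executor and a volatile field; it is a simplified illustration under assumed names, not the IndexShardOperationPermits machinery.

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class TermBumpSketch {
    private volatile long primaryTerm = 1;

    void bumpTerm(long newTerm, ExecutorService executor) {
        final CountDownLatch latch = new CountDownLatch(1);
        // submit the post-promotion work; it must not observe the old term
        executor.submit(() -> {
            latch.await();                 // parked until the term has been incremented
            System.out.println("running resync under term " + primaryTerm);
            return null;
        });
        primaryTerm = newTerm;             // bump the term while the work is still parked
        latch.countDown();                 // now release the post-promotion work
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        new TermBumpSketch().bumpTerm(2, executor);
        executor.shutdown();
        executor.awaitTermination(5, TimeUnit.SECONDS);
    }
}
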
+ 0 - 105
core/src/main/java/org/elasticsearch/index/shard/PrimaryContext.java

@@ -1,105 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.index.shard;
-
-import com.carrotsearch.hppc.ObjectLongHashMap;
-import com.carrotsearch.hppc.ObjectLongMap;
-import com.carrotsearch.hppc.cursors.ObjectLongCursor;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-
-import java.io.IOException;
-
-/**
- * Represents the sequence number component of the primary context. This is the knowledge on the primary of the in-sync and initializing
- * shards and their local checkpoints.
- */
-public class PrimaryContext implements Writeable {
-
-    private long clusterStateVersion;
-
-    public long clusterStateVersion() {
-        return clusterStateVersion;
-    }
-
-    private ObjectLongMap<String> inSyncLocalCheckpoints;
-
-    public ObjectLongMap<String> inSyncLocalCheckpoints() {
-        return inSyncLocalCheckpoints;
-    }
-
-    private ObjectLongMap<String> trackingLocalCheckpoints;
-
-    public ObjectLongMap<String> trackingLocalCheckpoints() {
-        return trackingLocalCheckpoints;
-    }
-
-    public PrimaryContext(
-            final long clusterStateVersion,
-            final ObjectLongMap<String> inSyncLocalCheckpoints,
-            final ObjectLongMap<String> trackingLocalCheckpoints) {
-        this.clusterStateVersion = clusterStateVersion;
-        this.inSyncLocalCheckpoints = inSyncLocalCheckpoints;
-        this.trackingLocalCheckpoints = trackingLocalCheckpoints;
-    }
-
-    public PrimaryContext(final StreamInput in) throws IOException {
-        clusterStateVersion = in.readVLong();
-        inSyncLocalCheckpoints = readMap(in);
-        trackingLocalCheckpoints = readMap(in);
-    }
-
-    private static ObjectLongMap<String> readMap(final StreamInput in) throws IOException {
-        final int length = in.readVInt();
-        final ObjectLongMap<String> map = new ObjectLongHashMap<>(length);
-        for (int i = 0; i < length; i++) {
-            final String key = in.readString();
-            final long value = in.readZLong();
-            map.addTo(key, value);
-        }
-        return map;
-    }
-
-    @Override
-    public void writeTo(final StreamOutput out) throws IOException {
-        out.writeVLong(clusterStateVersion);
-        writeMap(out, inSyncLocalCheckpoints);
-        writeMap(out, trackingLocalCheckpoints);
-    }
-
-    private static void writeMap(final StreamOutput out, final ObjectLongMap<String> map) throws IOException {
-        out.writeVInt(map.size());
-        for (ObjectLongCursor<String> cursor : map) {
-            out.writeString(cursor.key);
-            out.writeZLong(cursor.value);
-        }
-    }
-
-    @Override
-    public String toString() {
-        return "PrimaryContext{" +
-                "clusterStateVersion=" + clusterStateVersion +
-                ", inSyncLocalCheckpoints=" + inSyncLocalCheckpoints +
-                ", trackingLocalCheckpoints=" + trackingLocalCheckpoints +
-                '}';
-    }
-
-}

+ 21 - 25
core/src/main/java/org/elasticsearch/indices/cluster/IndicesClusterStateService.java

@@ -554,17 +554,23 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent imple
                 + "cluster state: " + shardRouting + " local: " + currentRoutingEntry;
 
         try {
-            final long primaryTerm = clusterState.metaData().index(shard.shardId().getIndex()).primaryTerm(shard.shardId().id());
+            final IndexMetaData indexMetaData = clusterState.metaData().index(shard.shardId().getIndex());
+            final long primaryTerm = indexMetaData.primaryTerm(shard.shardId().id());
+            final Set<String> inSyncIds = indexMetaData.inSyncAllocationIds(shard.shardId().id());
             final IndexShardRoutingTable indexShardRoutingTable = routingTable.shardRoutingTable(shardRouting.shardId());
-            /*
-             * Filter to shards that track sequence numbers and should be taken into consideration for checkpoint tracking. Shards on old
-             * nodes will go through a file-based recovery which will also transfer sequence number information.
-             */
-            final Set<String> activeIds = allocationIdsForShardsOnNodesThatUnderstandSeqNos(indexShardRoutingTable.activeShards(), nodes);
-            final Set<String> initializingIds =
-                    allocationIdsForShardsOnNodesThatUnderstandSeqNos(indexShardRoutingTable.getAllInitializingShards(), nodes);
-            shard.updateShardState(
-                    shardRouting, primaryTerm, primaryReplicaSyncer::resync, clusterState.version(), activeIds, initializingIds);
+            final Set<String> initializingIds = indexShardRoutingTable.getAllInitializingShards()
+                .stream()
+                .map(ShardRouting::allocationId)
+                .map(AllocationId::getId)
+                .collect(Collectors.toSet());
+            final Set<String> pre60AllocationIds = indexShardRoutingTable.assignedShards()
+                .stream()
+                .filter(shr -> nodes.get(shr.currentNodeId()).getVersion().before(Version.V_6_0_0_alpha1))
+                .map(ShardRouting::allocationId)
+                .map(AllocationId::getId)
+                .collect(Collectors.toSet());
+            shard.updateShardState(shardRouting, primaryTerm, primaryReplicaSyncer::resync, clusterState.version(),
+                inSyncIds, initializingIds, pre60AllocationIds);
         } catch (Exception e) {
             failAndRemoveShard(shardRouting, true, "failed updating shard routing entry", e, clusterState);
             return;
@@ -587,17 +593,6 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent imple
         }
     }
 
-    private Set<String> allocationIdsForShardsOnNodesThatUnderstandSeqNos(
-            final List<ShardRouting> shardRoutings,
-            final DiscoveryNodes nodes) {
-        return shardRoutings
-                .stream()
-                .filter(sr -> nodes.get(sr.currentNodeId()).getVersion().onOrAfter(Version.V_6_0_0_alpha1))
-                .map(ShardRouting::allocationId)
-                .map(AllocationId::getId)
-                .collect(Collectors.toSet());
-    }
-
     /**
      * Finds the routing source node for peer recovery, return null if its not found. Note, this method expects the shard
      * routing to *require* peer recovery, use {@link ShardRouting#recoverySource()} to check if its needed or not.
@@ -735,13 +730,13 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent imple
          * - Updates and persists the new routing value.
          * - Updates the primary term if this shard is a primary.
          * - Updates the allocation ids that are tracked by the shard if it is a primary.
-         * See {@link GlobalCheckpointTracker#updateAllocationIdsFromMaster(long, Set, Set)} for details.
+         *   See {@link GlobalCheckpointTracker#updateFromMaster(long, Set, Set, Set)} for details.
          *
          * @param shardRouting                the new routing entry
          * @param primaryTerm                 the new primary term
          * @param primaryReplicaSyncer        the primary-replica resync action to trigger when a term is increased on a primary
          * @param applyingClusterStateVersion the cluster state version being applied when updating the allocation IDs from the master
-         * @param activeAllocationIds         the allocation ids of the currently active shard copies
+         * @param inSyncAllocationIds         the allocation ids of the currently in-sync shard copies
          * @param initializingAllocationIds   the allocation ids of the currently initializing shard copies
          * @throws IndexShardRelocatedException if shard is marked as relocated and relocation aborted
          * @throws IOException                  if shard state could not be persisted
@@ -750,8 +745,9 @@ public class IndicesClusterStateService extends AbstractLifecycleComponent imple
                               long primaryTerm,
                               CheckedBiConsumer<IndexShard, ActionListener<ResyncTask>, IOException> primaryReplicaSyncer,
                               long applyingClusterStateVersion,
-                              Set<String> activeAllocationIds,
-                              Set<String> initializingAllocationIds) throws IOException;
+                              Set<String> inSyncAllocationIds,
+                              Set<String> initializingAllocationIds,
+                              Set<String> pre60AllocationIds) throws IOException;
     }
 
     public interface AllocatedIndex<T extends Shard> extends Iterable<T>, IndexComponent {

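The cluster state update above now derives three sets for the shard: the in-sync allocation IDs from the index metadata, the initializing IDs from the routing table, and the IDs of copies sitting on pre-6.0 nodes. The fragment below mirrors that stream shape over a made-up ShardInfo type; ShardInfo and its fields are hypothetical placeholders for ShardRouting/DiscoveryNode, not real classes.

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class AllocationIdSetsSketch {

    static final class ShardInfo {
        final String allocationId;
        final boolean initializing;
        final int nodeMajorVersion;

        ShardInfo(String allocationId, boolean initializing, int nodeMajorVersion) {
            this.allocationId = allocationId;
            this.initializing = initializing;
            this.nodeMajorVersion = nodeMajorVersion;
        }
    }

    public static void main(String[] args) {
        final List<ShardInfo> assignedShards = Arrays.asList(
            new ShardInfo("alloc-1", false, 6),
            new ShardInfo("alloc-2", true, 6),
            new ShardInfo("alloc-3", false, 5));

        // allocation IDs of copies that are still initializing
        final Set<String> initializingIds = assignedShards.stream()
            .filter(shard -> shard.initializing)
            .map(shard -> shard.allocationId)
            .collect(Collectors.toSet());

        // allocation IDs of copies assigned to nodes older than 6.0
        final Set<String> pre60AllocationIds = assignedShards.stream()
            .filter(shard -> shard.nodeMajorVersion < 6)
            .map(shard -> shard.allocationId)
            .collect(Collectors.toSet());

        System.out.println(initializingIds + " " + pre60AllocationIds);
    }
}
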
+ 6 - 5
core/src/main/java/org/elasticsearch/indices/recovery/RecoveryHandoffPrimaryContextRequest.java

@@ -21,7 +21,7 @@ package org.elasticsearch.indices.recovery;
 
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.index.shard.PrimaryContext;
+import org.elasticsearch.index.seqno.GlobalCheckpointTracker;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.transport.TransportRequest;
 
@@ -34,7 +34,7 @@ class RecoveryHandoffPrimaryContextRequest extends TransportRequest {
 
     private long recoveryId;
     private ShardId shardId;
-    private PrimaryContext primaryContext;
+    private GlobalCheckpointTracker.PrimaryContext primaryContext;
 
     /**
      * Initialize an empty request (used to serialize into when reading from a stream).
@@ -49,7 +49,8 @@ class RecoveryHandoffPrimaryContextRequest extends TransportRequest {
      * @param shardId        the shard ID of the relocation
      * @param primaryContext the primary context
      */
-    RecoveryHandoffPrimaryContextRequest(final long recoveryId, final ShardId shardId, final PrimaryContext primaryContext) {
+    RecoveryHandoffPrimaryContextRequest(final long recoveryId, final ShardId shardId,
+                                         final GlobalCheckpointTracker.PrimaryContext primaryContext) {
         this.recoveryId = recoveryId;
         this.shardId = shardId;
         this.primaryContext = primaryContext;
@@ -63,7 +64,7 @@ class RecoveryHandoffPrimaryContextRequest extends TransportRequest {
         return shardId;
     }
 
-    PrimaryContext primaryContext() {
+    GlobalCheckpointTracker.PrimaryContext primaryContext() {
         return primaryContext;
     }
 
@@ -72,7 +73,7 @@ class RecoveryHandoffPrimaryContextRequest extends TransportRequest {
         super.readFrom(in);
         recoveryId = in.readLong();
         shardId = ShardId.readShardId(in);
-        primaryContext = new PrimaryContext(in);
+        primaryContext = new GlobalCheckpointTracker.PrimaryContext(in);
     }
 
     @Override

+ 14 - 12
core/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java

@@ -42,7 +42,6 @@ import org.elasticsearch.common.lucene.store.InputStreamIndexInput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.CancellableThreads;
-import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.index.engine.Engine;
 import org.elasticsearch.index.engine.RecoveryEngineException;
 import org.elasticsearch.index.seqno.LocalCheckpointTracker;
@@ -63,9 +62,7 @@ import java.io.OutputStream;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
-import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.StreamSupport;
@@ -171,6 +168,8 @@ public class RecoverySourceHandler {
                 }
             }
 
+            cancellableThreads.execute(() -> runUnderOperationPermit(() -> shard.initiateTracking(request.targetAllocationId())));
+
             try {
                 prepareTargetForTranslog(translogView.estimateTotalOperations(startingSeqNo));
             } catch (final Exception e) {
@@ -208,6 +207,17 @@ public class RecoverySourceHandler {
         return response;
     }
 
+    private void runUnderOperationPermit(CancellableThreads.Interruptable runnable) throws InterruptedException {
+        final PlainActionFuture<Releasable> onAcquired = new PlainActionFuture<>();
+        shard.acquirePrimaryOperationPermit(onAcquired, ThreadPool.Names.SAME);
+        try (Releasable ignored = onAcquired.actionGet()) {
+            if (shard.state() == IndexShardState.RELOCATED) {
+                throw new IndexShardRelocatedException(shard.shardId());
+            }
+            runnable.run();
+        }
+    }
+
     /**
      * Determines if the source translog is ready for a sequence-number-based peer recovery. The main condition here is that the source
      * translog contains all operations between the local checkpoint on the target and the current maximum sequence number on the source.
@@ -465,15 +475,7 @@ public class RecoverySourceHandler {
              * marking the shard as in-sync. If the relocation handoff holds all the permits then after the handoff completes and we acquire
              * the permit then the state of the shard will be relocated and this recovery will fail.
              */
-            final PlainActionFuture<Releasable> onAcquired = new PlainActionFuture<>();
-            shard.acquirePrimaryOperationPermit(onAcquired, ThreadPool.Names.SAME);
-            try (Releasable ignored = onAcquired.actionGet()) {
-                if (shard.state() == IndexShardState.RELOCATED) {
-                    throw new IndexShardRelocatedException(shard.shardId());
-                }
-                shard.markAllocationIdAsInSync(request.targetAllocationId(), targetLocalCheckpoint);
-            }
-
+            runUnderOperationPermit(() -> shard.markAllocationIdAsInSync(request.targetAllocationId(), targetLocalCheckpoint));
             recoveryTarget.finalizeRecovery(shard.getGlobalCheckpoint());
         });
 

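runUnderOperationPermit factors out a guard used twice in this handler: acquire a primary operation permit, re-check that the shard has not already been handed off (RELOCATED), and only then touch tracker state via initiateTracking or markAllocationIdAsInSync. The sketch below mimics that guard with a Semaphore and a boolean flag; it is a simplified stand-in, not the IndexShardOperationPermits API.

import java.util.concurrent.Semaphore;

public class PermitGuardSketch {
    private final Semaphore operationPermits = new Semaphore(Integer.MAX_VALUE);
    private volatile boolean relocated = false;

    void runUnderOperationPermit(Runnable action) throws InterruptedException {
        operationPermits.acquire();
        try {
            if (relocated) {
                // the primary context was already handed off; this recovery must fail
                throw new IllegalStateException("shard is relocated");
            }
            action.run();
        } finally {
            operationPermits.release();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        PermitGuardSketch shard = new PermitGuardSketch();
        shard.runUnderOperationPermit(() -> System.out.println("initiate tracking / mark in-sync"));
    }
}
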
+ 3 - 3
core/src/main/java/org/elasticsearch/indices/recovery/RecoveryTarget.java

@@ -41,10 +41,10 @@ import org.elasticsearch.common.util.concurrent.AbstractRefCounted;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.index.engine.Engine;
 import org.elasticsearch.index.mapper.MapperException;
+import org.elasticsearch.index.seqno.GlobalCheckpointTracker;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.IndexShardNotRecoveringException;
 import org.elasticsearch.index.shard.IndexShardState;
-import org.elasticsearch.index.shard.PrimaryContext;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.store.Store;
 import org.elasticsearch.index.store.StoreFileMetaData;
@@ -379,8 +379,8 @@ public class RecoveryTarget extends AbstractRefCounted implements RecoveryTarget
     }
 
     @Override
-    public void handoffPrimaryContext(final PrimaryContext primaryContext) {
-        indexShard.updateAllocationIdsFromPrimaryContext(primaryContext);
+    public void handoffPrimaryContext(final GlobalCheckpointTracker.PrimaryContext primaryContext) {
+        indexShard.activateWithPrimaryContext(primaryContext);
     }
 
     @Override

+ 2 - 2
core/src/main/java/org/elasticsearch/indices/recovery/RecoveryTargetHandler.java

@@ -19,7 +19,7 @@
 package org.elasticsearch.indices.recovery;
 
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.index.shard.PrimaryContext;
+import org.elasticsearch.index.seqno.GlobalCheckpointTracker;
 import org.elasticsearch.index.store.Store;
 import org.elasticsearch.index.store.StoreFileMetaData;
 import org.elasticsearch.index.translog.Translog;
@@ -55,7 +55,7 @@ public interface RecoveryTargetHandler {
      *
      * @param primaryContext the primary context from the relocation source
      */
-    void handoffPrimaryContext(PrimaryContext primaryContext);
+    void handoffPrimaryContext(GlobalCheckpointTracker.PrimaryContext primaryContext);
 
     /**
      * Index a set of translog operations on the target

+ 2 - 2
core/src/main/java/org/elasticsearch/indices/recovery/RemoteRecoveryTargetHandler.java

@@ -23,7 +23,7 @@ import org.apache.lucene.store.RateLimiter;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.index.shard.PrimaryContext;
+import org.elasticsearch.index.seqno.GlobalCheckpointTracker;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.store.Store;
 import org.elasticsearch.index.store.StoreFileMetaData;
@@ -100,7 +100,7 @@ public class RemoteRecoveryTargetHandler implements RecoveryTargetHandler {
     }
 
     @Override
-    public void handoffPrimaryContext(final PrimaryContext primaryContext) {
+    public void handoffPrimaryContext(final GlobalCheckpointTracker.PrimaryContext primaryContext) {
         transportService.submitRequest(
                 targetNode,
                 PeerRecoveryTargetService.Actions.HANDOFF_PRIMARY_CONTEXT,

+ 0 - 1
core/src/test/java/org/elasticsearch/cluster/MinimumMasterNodesIT.java

@@ -188,7 +188,6 @@ public class MinimumMasterNodesIT extends ESIntegTestCase {
         }
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/25415")
     public void testMultipleNodesShutdownNonMasterNodes() throws Exception {
         Settings settings = Settings.builder()
                 .put("discovery.zen.minimum_master_nodes", 3)

+ 3 - 6
core/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java

@@ -2020,12 +2020,9 @@ public class InternalEngineTests extends ESTestCase {
 
         try {
             initialEngine = engine;
-            initialEngine
-                .seqNoService()
-                .updateAllocationIdsFromMaster(
-                        randomNonNegativeLong(),
-                        new HashSet<>(Arrays.asList("primary", "replica")),
-                        Collections.emptySet());
+            initialEngine.seqNoService().updateAllocationIdsFromMaster(1L, new HashSet<>(Arrays.asList("primary", "replica")),
+                Collections.emptySet(), Collections.emptySet());
+            initialEngine.seqNoService().activatePrimaryMode("primary", primarySeqNo);
             for (int op = 0; op < opCount; op++) {
                 final String id;
                 // mostly index, sometimes delete

+ 7 - 18
core/src/test/java/org/elasticsearch/index/replication/ESIndexLevelReplicationTestCase.java

@@ -123,7 +123,6 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
     }
 
     protected class ReplicationGroup implements AutoCloseable, Iterable<IndexShard> {
-        private long clusterStateVersion;
         private IndexShard primary;
         private IndexMetaData indexMetaData;
         private final List<IndexShard> replicas;
@@ -144,7 +143,6 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
             primary = newShard(primaryRouting, indexMetaData, null, getEngineFactory(primaryRouting));
             replicas = new ArrayList<>();
             this.indexMetaData = indexMetaData;
-            clusterStateVersion = 1;
             updateAllocationIDsOnPrimary();
             for (int i = 0; i < indexMetaData.getNumberOfReplicas(); i++) {
                 addReplica();
@@ -231,7 +229,7 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
             initializingIds.addAll(initializingIds());
             initializingIds.remove(primary.routingEntry().allocationId().getId());
             primary.updateShardState(ShardRoutingHelper.moveToStarted(primary.routingEntry()), primary.getPrimaryTerm(), null,
-                ++clusterStateVersion, activeIds, initializingIds);
+                currentClusterStateVersion.incrementAndGet(), activeIds, initializingIds, Collections.emptySet());
             for (final IndexShard replica : replicas) {
                 recoverReplica(replica);
             }
@@ -250,7 +248,6 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
                 .filter(shardRouting -> shardRouting.isSameAllocation(replica.routingEntry())).findFirst().isPresent() == false :
                 "replica with aId [" + replica.routingEntry().allocationId() + "] already exists";
             replicas.add(replica);
-            clusterStateVersion++;
             updateAllocationIDsOnPrimary();
         }
 
@@ -265,7 +262,6 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
             final IndexShard newReplica = newShard(shardRouting, shardPath, indexMetaData, null,
                     getEngineFactory(shardRouting));
             replicas.add(newReplica);
-            clusterStateVersion++;
             updateAllocationIDsOnPrimary();
             return newReplica;
         }
@@ -284,13 +280,8 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
             assertTrue(replicas.remove(replica));
             closeShards(primary);
             primary = replica;
+            assert primary.routingEntry().active() : "only active replicas can be promoted to primary: " + primary.routingEntry();
             PlainActionFuture<PrimaryReplicaSyncer.ResyncTask> fut = new PlainActionFuture<>();
-            HashSet<String> activeIds = new HashSet<>();
-            activeIds.addAll(activeIds());
-            activeIds.add(replica.routingEntry().allocationId().getId());
-            HashSet<String> initializingIds = new HashSet<>();
-            initializingIds.addAll(initializingIds());
-            initializingIds.remove(replica.routingEntry().allocationId().getId());
             primary.updateShardState(replica.routingEntry().moveActiveReplicaToPrimary(),
                 newTerm, (shard, listener) -> primaryReplicaSyncer.resync(shard,
                     new ActionListener<PrimaryReplicaSyncer.ResyncTask>() {
@@ -305,7 +296,7 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
                             listener.onFailure(e);
                             fut.onFailure(e);
                         }
-                    }), ++clusterStateVersion, activeIds, initializingIds);
+                    }), currentClusterStateVersion.incrementAndGet(), activeIds(), initializingIds(), Collections.emptySet());
 
             return fut;
         }
@@ -323,7 +314,6 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
         synchronized boolean removeReplica(IndexShard replica) throws IOException {
             final boolean removed = replicas.remove(replica);
             if (removed) {
-                clusterStateVersion++;
                 updateAllocationIDsOnPrimary();
             }
             return removed;
@@ -342,9 +332,8 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
             IndexShard replica,
             BiFunction<IndexShard, DiscoveryNode, RecoveryTarget> targetSupplier,
             boolean markAsRecovering) throws IOException {
-            ESIndexLevelReplicationTestCase.this.recoverReplica(replica, primary, targetSupplier, markAsRecovering);
-            clusterStateVersion++;
-            updateAllocationIDsOnPrimary();
+            ESIndexLevelReplicationTestCase.this.recoverReplica(replica, primary, targetSupplier, markAsRecovering, activeIds(),
+                initializingIds());
         }
 
         public synchronized DiscoveryNode getPrimaryNode() {
@@ -422,8 +411,8 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
         }
 
         private void updateAllocationIDsOnPrimary() throws IOException {
-            primary.updateShardState(primary.routingEntry(), primary.getPrimaryTerm(), null, clusterStateVersion,
-                activeIds(), initializingIds());
+            primary.updateShardState(primary.routingEntry(), primary.getPrimaryTerm(), null, currentClusterStateVersion.incrementAndGet(),
+                activeIds(), initializingIds(), Collections.emptySet());
         }
     }
 

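A sketch, not part of the commit, of the call shape updateShardState takes after this change. The parameter roles are inferred from how the tests above invoke it, and currentClusterStateVersion is assumed to be an AtomicLong supplied by the shared IndexShardTestCase base class (see the last file in this commit):

import java.io.IOException;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.index.shard.IndexShard;

class UpdateShardStateSketch {
    static void publish(IndexShard primary, ShardRouting routing, AtomicLong currentClusterStateVersion,
                        Set<String> activeIds, Set<String> initializingIds) throws IOException {
        primary.updateShardState(routing, primary.getPrimaryTerm(),
            null,                                          // primary-replica resync consumer, only relevant on promotion
            currentClusterStateVersion.incrementAndGet(),  // replaces the old per-group clusterStateVersion field
            activeIds,                                     // in-sync allocation ids
            initializingIds,                               // initializing allocation ids
            Collections.emptySet());                       // additional id set introduced by this commit, empty in these tests
    }
}
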
+ 13 - 8
core/src/test/java/org/elasticsearch/index/replication/IndexLevelReplicationTests.java

@@ -25,6 +25,7 @@ import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.action.DocWriteResponse;
 import org.elasticsearch.action.bulk.BulkItemResponse;
+import org.elasticsearch.action.bulk.BulkShardRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.xcontent.XContentType;
@@ -46,6 +47,7 @@ import org.hamcrest.Matcher;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Future;
@@ -149,7 +151,6 @@ public class IndexLevelReplicationTests extends ESIndexLevelReplicationTestCase
                 startedShards = shards.startReplicas(randomIntBetween(1, 2));
             } while (startedShards > 0);
 
-            final long unassignedSeqNo = SequenceNumbersService.UNASSIGNED_SEQ_NO;
             for (IndexShard shard : shards) {
                 final SeqNoStats shardStats = shard.seqNoStats();
                 final ShardRouting shardRouting = shard.routingEntry();
@@ -164,9 +165,10 @@ public class IndexLevelReplicationTests extends ESIndexLevelReplicationTestCase
                  */
                 final Matcher<Long> globalCheckpointMatcher;
                 if (shardRouting.primary()) {
-                    globalCheckpointMatcher = numDocs == 0 ? equalTo(unassignedSeqNo) : equalTo(numDocs - 1L);
+                    globalCheckpointMatcher = numDocs == 0 ? equalTo(SequenceNumbersService.NO_OPS_PERFORMED) : equalTo(numDocs - 1L);
                 } else {
-                    globalCheckpointMatcher = numDocs == 0 ? equalTo(unassignedSeqNo) : anyOf(equalTo(numDocs - 1L), equalTo(numDocs - 2L));
+                    globalCheckpointMatcher = numDocs == 0 ? equalTo(SequenceNumbersService.NO_OPS_PERFORMED)
+                        : anyOf(equalTo(numDocs - 1L), equalTo(numDocs - 2L));
                 }
                 assertThat(shardRouting + " global checkpoint mismatch", shardStats.getGlobalCheckpoint(), globalCheckpointMatcher);
                 assertThat(shardRouting + " max seq no mismatch", shardStats.getMaxSeqNo(), equalTo(numDocs - 1L));
@@ -194,12 +196,15 @@ public class IndexLevelReplicationTests extends ESIndexLevelReplicationTestCase
                 Collections.singletonMap("type", "{ \"type\": { \"properties\": { \"f\": { \"type\": \"keyword\"} }}}");
         try (ReplicationGroup shards = new ReplicationGroup(buildIndexMetaData(2, mappings))) {
             shards.startAll();
-            IndexShard replica1 = shards.getReplicas().get(0);
-            logger.info("--> isolated replica " + replica1.routingEntry());
-            shards.removeReplica(replica1);
+            List<IndexShard> replicas = shards.getReplicas();
+            IndexShard replica1 = replicas.get(0);
             IndexRequest indexRequest = new IndexRequest(index.getName(), "type", "1").source("{ \"f\": \"1\"}", XContentType.JSON);
-            shards.index(indexRequest);
-            shards.addReplica(replica1);
+            logger.info("--> isolated replica " + replica1.routingEntry());
+            BulkShardRequest replicationRequest = indexOnPrimary(indexRequest, shards.getPrimary());
+            for (int i = 1; i < replicas.size(); i++) {
+                indexOnReplica(replicationRequest, replicas.get(i));
+            }
+
             logger.info("--> promoting replica to primary " + replica1.routingEntry());
             shards.promoteReplicaToPrimary(replica1);
             indexRequest = new IndexRequest(index.getName(), "type", "1").source("{ \"f\": \"2\"}", XContentType.JSON);

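For reference, a stripped-down sketch (not from the commit) of the isolation pattern the rewritten test uses: index on the primary, then replay the resulting replication request on every replica except the one treated as isolated. indexOnPrimary and indexOnReplica are assumed to be the ESIndexLevelReplicationTestCase helpers seen above, so this fragment only makes sense inside such a test:

// inside a test method of a subclass of ESIndexLevelReplicationTestCase
IndexShard isolatedReplica = shards.getReplicas().get(0);
IndexRequest indexRequest = new IndexRequest(index.getName(), "type", "1")
    .source("{ \"f\": \"1\"}", XContentType.JSON);
// execute on the primary and capture the request that would normally be replicated
BulkShardRequest replicationRequest = indexOnPrimary(indexRequest, shards.getPrimary());
// replay it on every replica except the copy we pretend is isolated
for (IndexShard replica : shards.getReplicas()) {
    if (replica != isolatedReplica) {
        indexOnReplica(replicationRequest, replica);
    }
}
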
+ 7 - 0
core/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java

@@ -39,6 +39,7 @@ import org.elasticsearch.index.engine.EngineFactory;
 import org.elasticsearch.index.engine.InternalEngineTests;
 import org.elasticsearch.index.mapper.SourceToParse;
 import org.elasticsearch.index.shard.IndexShard;
+import org.elasticsearch.index.shard.IndexShardTestCase;
 import org.elasticsearch.index.shard.PrimaryReplicaSyncer;
 import org.elasticsearch.index.store.Store;
 import org.elasticsearch.index.translog.Translog;
@@ -245,6 +246,12 @@ public class RecoveryDuringReplicationTests extends ESIndexLevelReplicationTestC
             }
 
             shards.promoteReplicaToPrimary(newPrimary);
+
+            // check that local checkpoint of new primary is properly tracked after primary promotion
+            assertThat(newPrimary.getLocalCheckpoint(), equalTo(totalDocs - 1L));
+            assertThat(IndexShardTestCase.getEngine(newPrimary).seqNoService()
+                .getTrackedLocalCheckpointForShard(newPrimary.routingEntry().allocationId().getId()), equalTo(totalDocs - 1L));
+
             // index some more
             totalDocs += shards.indexDocs(randomIntBetween(0, 5));
 

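The assertions above rely on the explicit per-copy lifecycle that the GlobalCheckpointTracker tests below exercise directly. A condensed, illustrative sketch of that lifecycle for a recovering replica, assembled from the tracker calls in this commit; the checkpoint values are arbitrary, and markAllocationIdAsInSync appears to block until the copy has caught up, which is why the tests wrap it in a quiet helper:

import org.elasticsearch.index.seqno.GlobalCheckpointTracker;

class ReplicaTrackingSketch {
    // tracker is assumed to have been set up via updateFromMaster(...) with replicaAllocationId
    // listed among the initializing ids, and activatePrimaryMode(...) called for the primary itself
    static void bringReplicaInSync(GlobalCheckpointTracker tracker, String replicaAllocationId)
            throws InterruptedException {
        tracker.initiateTracking(replicaAllocationId);              // start tracking the copy's local checkpoint
        tracker.updateLocalCheckpoint(replicaAllocationId, 5L);     // replication traffic reports progress
        tracker.markAllocationIdAsInSync(replicaAllocationId, 5L);  // waits until caught up, after which the copy
                                                                    // counts towards the global checkpoint
    }
}
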
+ 256 - 248
core/src/test/java/org/elasticsearch/index/seqno/GlobalCheckpointTrackerTests.java

@@ -19,18 +19,18 @@
 
 package org.elasticsearch.index.seqno;
 
-import com.carrotsearch.hppc.ObjectLongHashMap;
-import com.carrotsearch.hppc.ObjectLongMap;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.set.Sets;
-import org.elasticsearch.index.shard.PrimaryContext;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.IndexSettingsModule;
 import org.junit.Before;
 
+import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -46,15 +46,13 @@ import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
-import java.util.stream.StreamSupport;
 
+import static java.util.Collections.emptySet;
+import static org.elasticsearch.index.seqno.SequenceNumbersService.NO_OPS_PERFORMED;
 import static org.elasticsearch.index.seqno.SequenceNumbersService.UNASSIGNED_SEQ_NO;
-import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.hasToString;
 import static org.hamcrest.Matchers.not;
-import static org.mockito.Mockito.mock;
 
 public class GlobalCheckpointTrackerTests extends ESTestCase {
 
@@ -88,7 +86,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
     public void testGlobalCheckpointUpdate() {
         final long initialClusterStateVersion = randomNonNegativeLong();
         Map<String, Long> allocations = new HashMap<>();
-        Map<String, Long> activeWithCheckpoints = randomAllocationsWithLocalCheckpoints(0, 5);
+        Map<String, Long> activeWithCheckpoints = randomAllocationsWithLocalCheckpoints(1, 5);
         Set<String> active = new HashSet<>(activeWithCheckpoints.keySet());
         allocations.putAll(activeWithCheckpoints);
         Map<String, Long> initializingWithCheckpoints = randomAllocationsWithLocalCheckpoints(0, 5);
@@ -115,8 +113,9 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
             logger.info("  - [{}], local checkpoint [{}], [{}]", aId, allocations.get(aId), type);
         });
 
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion, active, initializing);
-        initializing.forEach(aId -> markAllocationIdAsInSyncQuietly(tracker, aId, tracker.getGlobalCheckpoint()));
+        tracker.updateFromMaster(initialClusterStateVersion, active, initializing, emptySet());
+        tracker.activatePrimaryMode(active.iterator().next(), NO_OPS_PERFORMED);
+        initializing.forEach(aId -> markAllocationIdAsInSyncQuietly(tracker, aId, NO_OPS_PERFORMED));
         allocations.keySet().forEach(aId -> tracker.updateLocalCheckpoint(aId, allocations.get(aId)));
 
         assertThat(tracker.getGlobalCheckpoint(), equalTo(minLocalCheckpoint));
@@ -134,30 +133,37 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         // first check that adding it without the master blessing doesn't change anything.
         tracker.updateLocalCheckpoint(extraId, minLocalCheckpointAfterUpdates + 1 + randomInt(4));
-        assertThat(tracker.getLocalCheckpointForAllocationId(extraId), equalTo(UNASSIGNED_SEQ_NO));
+        assertNull(tracker.localCheckpoints.get(extraId));
+        expectThrows(IllegalStateException.class, () -> tracker.initiateTracking(extraId));
 
-        Set<String> newActive = new HashSet<>(active);
-        newActive.add(extraId);
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion + 1, newActive, initializing);
+        Set<String> newInitializing = new HashSet<>(initializing);
+        newInitializing.add(extraId);
+        tracker.updateFromMaster(initialClusterStateVersion + 1, active, newInitializing, emptySet());
+
+        tracker.initiateTracking(extraId);
 
         // now notify for the new id
-        tracker.updateLocalCheckpoint(extraId, minLocalCheckpointAfterUpdates + 1 + randomInt(4));
+        if (randomBoolean()) {
+            tracker.updateLocalCheckpoint(extraId, minLocalCheckpointAfterUpdates + 1 + randomInt(4));
+            markAllocationIdAsInSyncQuietly(tracker, extraId, randomInt((int) minLocalCheckpointAfterUpdates));
+        } else {
+            markAllocationIdAsInSyncQuietly(tracker, extraId, minLocalCheckpointAfterUpdates + 1 + randomInt(4));
+        }
 
         // now it should be incremented
         assertThat(tracker.getGlobalCheckpoint(), greaterThan(minLocalCheckpoint));
     }
 
     public void testMissingActiveIdsPreventAdvance() {
-        final Map<String, Long> active = randomAllocationsWithLocalCheckpoints(1, 5);
+        final Map<String, Long> active = randomAllocationsWithLocalCheckpoints(2, 5);
         final Map<String, Long> initializing = randomAllocationsWithLocalCheckpoints(0, 5);
         final Map<String, Long> assigned = new HashMap<>();
         assigned.putAll(active);
         assigned.putAll(initializing);
-        tracker.updateAllocationIdsFromMaster(
-                randomNonNegativeLong(),
-                active.keySet(),
-                initializing.keySet());
-        randomSubsetOf(initializing.keySet()).forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, tracker.getGlobalCheckpoint()));
+        tracker.updateFromMaster(randomNonNegativeLong(), active.keySet(), initializing.keySet(), emptySet());
+        String primary = active.keySet().iterator().next();
+        tracker.activatePrimaryMode(primary, NO_OPS_PERFORMED);
+        randomSubsetOf(initializing.keySet()).forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, NO_OPS_PERFORMED));
         final String missingActiveID = randomFrom(active.keySet());
         assigned
                 .entrySet()
@@ -165,24 +171,27 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
                 .filter(e -> !e.getKey().equals(missingActiveID))
                 .forEach(e -> tracker.updateLocalCheckpoint(e.getKey(), e.getValue()));
 
-        assertThat(tracker.getGlobalCheckpoint(), equalTo(UNASSIGNED_SEQ_NO));
-
+        if (missingActiveID.equals(primary) == false) {
+            assertThat(tracker.getGlobalCheckpoint(), equalTo(UNASSIGNED_SEQ_NO));
+        }
         // now update all knowledge of all shards
         assigned.forEach(tracker::updateLocalCheckpoint);
         assertThat(tracker.getGlobalCheckpoint(), not(equalTo(UNASSIGNED_SEQ_NO)));
     }
 
     public void testMissingInSyncIdsPreventAdvance() {
-        final Map<String, Long> active = randomAllocationsWithLocalCheckpoints(0, 5);
-        final Map<String, Long> initializing = randomAllocationsWithLocalCheckpoints(1, 5);
-        tracker.updateAllocationIdsFromMaster(randomNonNegativeLong(), active.keySet(), initializing.keySet());
-        initializing.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, tracker.getGlobalCheckpoint()));
-        randomSubsetOf(randomInt(initializing.size() - 1),
-            initializing.keySet()).forEach(aId -> tracker.updateLocalCheckpoint(aId, initializing.get(aId)));
+        final Map<String, Long> active = randomAllocationsWithLocalCheckpoints(1, 5);
+        final Map<String, Long> initializing = randomAllocationsWithLocalCheckpoints(2, 5);
+        logger.info("active: {}, initializing: {}", active, initializing);
+        tracker.updateFromMaster(randomNonNegativeLong(), active.keySet(), initializing.keySet(), emptySet());
+        String primary = active.keySet().iterator().next();
+        tracker.activatePrimaryMode(primary, NO_OPS_PERFORMED);
+        randomSubsetOf(randomIntBetween(1, initializing.size() - 1),
+            initializing.keySet()).forEach(aId -> markAllocationIdAsInSyncQuietly(tracker, aId, NO_OPS_PERFORMED));
 
         active.forEach(tracker::updateLocalCheckpoint);
 
-        assertThat(tracker.getGlobalCheckpoint(), equalTo(UNASSIGNED_SEQ_NO));
+        assertThat(tracker.getGlobalCheckpoint(), equalTo(NO_OPS_PERFORMED));
 
         // update again
         initializing.forEach(tracker::updateLocalCheckpoint);
@@ -193,9 +202,11 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final Map<String, Long> active = randomAllocationsWithLocalCheckpoints(1, 5);
         final Map<String, Long> initializing = randomAllocationsWithLocalCheckpoints(1, 5);
         final Map<String, Long> nonApproved = randomAllocationsWithLocalCheckpoints(1, 5);
-        tracker.updateAllocationIdsFromMaster(randomNonNegativeLong(), active.keySet(), initializing.keySet());
-        initializing.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, tracker.getGlobalCheckpoint()));
-        nonApproved.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, tracker.getGlobalCheckpoint()));
+        tracker.updateFromMaster(randomNonNegativeLong(), active.keySet(), initializing.keySet(), emptySet());
+        tracker.activatePrimaryMode(active.keySet().iterator().next(), NO_OPS_PERFORMED);
+        initializing.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, NO_OPS_PERFORMED));
+        nonApproved.keySet().forEach(k ->
+            expectThrows(IllegalStateException.class, () -> markAllocationIdAsInSyncQuietly(tracker, k, NO_OPS_PERFORMED)));
 
         List<Map<String, Long>> allocations = Arrays.asList(active, initializing, nonApproved);
         Collections.shuffle(allocations, random());
@@ -221,11 +232,12 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         if (randomBoolean()) {
             allocations.putAll(initializingToBeRemoved);
         }
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion, active, initializing);
+        tracker.updateFromMaster(initialClusterStateVersion, active, initializing, emptySet());
+        tracker.activatePrimaryMode(active.iterator().next(), NO_OPS_PERFORMED);
         if (randomBoolean()) {
-            initializingToStay.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, tracker.getGlobalCheckpoint()));
+            initializingToStay.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, NO_OPS_PERFORMED));
         } else {
-            initializing.forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, tracker.getGlobalCheckpoint()));
+            initializing.forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k, NO_OPS_PERFORMED));
         }
         if (randomBoolean()) {
             allocations.forEach(tracker::updateLocalCheckpoint);
@@ -233,11 +245,13 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         // now remove shards
         if (randomBoolean()) {
-            tracker.updateAllocationIdsFromMaster(initialClusterStateVersion + 1, activeToStay.keySet(), initializingToStay.keySet());
+            tracker.updateFromMaster(initialClusterStateVersion + 1, activeToStay.keySet(), initializingToStay.keySet(),
+                emptySet());
             allocations.forEach((aid, ckp) -> tracker.updateLocalCheckpoint(aid, ckp + 10L));
         } else {
             allocations.forEach((aid, ckp) -> tracker.updateLocalCheckpoint(aid, ckp + 10L));
-            tracker.updateAllocationIdsFromMaster(initialClusterStateVersion + 2, activeToStay.keySet(), initializingToStay.keySet());
+            tracker.updateFromMaster(initialClusterStateVersion + 2, activeToStay.keySet(), initializingToStay.keySet(),
+                emptySet());
         }
 
         final long checkpoint = Stream.concat(activeToStay.values().stream(), initializingToStay.values().stream())
@@ -246,16 +260,16 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         assertThat(tracker.getGlobalCheckpoint(), equalTo(checkpoint));
     }
 
-    public void testWaitForAllocationIdToBeInSync() throws BrokenBarrierException, InterruptedException {
+    public void testWaitForAllocationIdToBeInSync() throws Exception {
         final int localCheckpoint = randomIntBetween(1, 32);
         final int globalCheckpoint = randomIntBetween(localCheckpoint + 1, 64);
         final CyclicBarrier barrier = new CyclicBarrier(2);
         final AtomicBoolean complete = new AtomicBoolean();
         final String inSyncAllocationId = randomAlphaOfLength(16);
         final String trackingAllocationId = randomAlphaOfLength(16);
-        tracker.updateAllocationIdsFromMaster(
-                randomNonNegativeLong(), Collections.singleton(inSyncAllocationId), Collections.singleton(trackingAllocationId));
-        tracker.updateLocalCheckpoint(inSyncAllocationId, globalCheckpoint);
+        tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(inSyncAllocationId),
+            Collections.singleton(trackingAllocationId), emptySet());
+        tracker.activatePrimaryMode(inSyncAllocationId, globalCheckpoint);
         final Thread thread = new Thread(() -> {
             try {
                 // synchronize starting with the test thread
@@ -279,18 +293,16 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         for (int i = 0; i < elements.size(); i++) {
             tracker.updateLocalCheckpoint(trackingAllocationId, elements.get(i));
             assertFalse(complete.get());
-            assertTrue(awaitBusy(() -> tracker.trackingLocalCheckpoints.containsKey(trackingAllocationId)));
-            assertTrue(awaitBusy(() -> tracker.pendingInSync.contains(trackingAllocationId)));
-            assertFalse(tracker.inSyncLocalCheckpoints.containsKey(trackingAllocationId));
+            assertFalse(tracker.getTrackedLocalCheckpointForShard(trackingAllocationId).inSync);
+            assertBusy(() -> assertTrue(tracker.pendingInSync.contains(trackingAllocationId)));
         }
 
         tracker.updateLocalCheckpoint(trackingAllocationId, randomIntBetween(globalCheckpoint, 64));
         // synchronize with the waiting thread to mark that it is complete
         barrier.await();
         assertTrue(complete.get());
-        assertTrue(tracker.trackingLocalCheckpoints.isEmpty());
-        assertTrue(tracker.pendingInSync.isEmpty());
-        assertTrue(tracker.inSyncLocalCheckpoints.containsKey(trackingAllocationId));
+        assertTrue(tracker.getTrackedLocalCheckpointForShard(trackingAllocationId).inSync);
+        assertFalse(tracker.pendingInSync.contains(trackingAllocationId));
 
         thread.join();
     }
@@ -302,9 +314,9 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final AtomicBoolean interrupted = new AtomicBoolean();
         final String inSyncAllocationId = randomAlphaOfLength(16);
         final String trackingAllocationId = randomAlphaOfLength(32);
-        tracker.updateAllocationIdsFromMaster(
-                randomNonNegativeLong(), Collections.singleton(inSyncAllocationId), Collections.singleton(trackingAllocationId));
-        tracker.updateLocalCheckpoint(inSyncAllocationId, globalCheckpoint);
+        tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(inSyncAllocationId),
+            Collections.singleton(trackingAllocationId), emptySet());
+        tracker.activatePrimaryMode(inSyncAllocationId, globalCheckpoint);
         final Thread thread = new Thread(() -> {
             try {
                 // synchronize starting with the test thread
@@ -348,19 +360,25 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
                 randomActiveAndInitializingAllocationIds(numberOfActiveAllocationsIds, numberOfInitializingIds);
         final Set<String> activeAllocationIds = activeAndInitializingAllocationIds.v1();
         final Set<String> initializingIds = activeAndInitializingAllocationIds.v2();
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion, activeAllocationIds, initializingIds);
+        tracker.updateFromMaster(initialClusterStateVersion, activeAllocationIds, initializingIds, emptySet());
+        String primaryId = activeAllocationIds.iterator().next();
+        tracker.activatePrimaryMode(primaryId, NO_OPS_PERFORMED);
 
         // first we assert that the in-sync and tracking sets are set up correctly
-        assertTrue(activeAllocationIds.stream().allMatch(a -> tracker.inSyncLocalCheckpoints.containsKey(a)));
+        assertTrue(activeAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).inSync));
         assertTrue(
                 activeAllocationIds
                         .stream()
-                        .allMatch(a -> tracker.inSyncLocalCheckpoints.get(a) == SequenceNumbersService.UNASSIGNED_SEQ_NO));
-        assertTrue(initializingIds.stream().allMatch(a -> tracker.trackingLocalCheckpoints.containsKey(a)));
+                        .filter(a -> a.equals(primaryId) == false)
+                        .allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).getLocalCheckpoint()
+                            == SequenceNumbersService.UNASSIGNED_SEQ_NO));
+        assertTrue(initializingIds.stream().noneMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).inSync));
         assertTrue(
                 initializingIds
                         .stream()
-                        .allMatch(a -> tracker.trackingLocalCheckpoints.get(a) == SequenceNumbersService.UNASSIGNED_SEQ_NO));
+                        .filter(a -> a.equals(primaryId) == false)
+                        .allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).getLocalCheckpoint()
+                            == SequenceNumbersService.UNASSIGNED_SEQ_NO));
 
         // now we will remove some allocation IDs from these and ensure that they propagate through
         final List<String> removingActiveAllocationIds = randomSubsetOf(activeAllocationIds);
@@ -369,29 +387,32 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final List<String> removingInitializingAllocationIds = randomSubsetOf(initializingIds);
         final Set<String> newInitializingAllocationIds =
                 initializingIds.stream().filter(a -> !removingInitializingAllocationIds.contains(a)).collect(Collectors.toSet());
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion + 1, newActiveAllocationIds, newInitializingAllocationIds);
-        assertTrue(newActiveAllocationIds.stream().allMatch(a -> tracker.inSyncLocalCheckpoints.containsKey(a)));
-        assertTrue(removingActiveAllocationIds.stream().noneMatch(a -> tracker.inSyncLocalCheckpoints.containsKey(a)));
-        assertTrue(newInitializingAllocationIds.stream().allMatch(a -> tracker.trackingLocalCheckpoints.containsKey(a)));
-        assertTrue(removingInitializingAllocationIds.stream().noneMatch(a -> tracker.trackingLocalCheckpoints.containsKey(a)));
+        tracker.updateFromMaster(initialClusterStateVersion + 1, newActiveAllocationIds, newInitializingAllocationIds,
+            emptySet());
+        assertTrue(newActiveAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).inSync));
+        assertTrue(removingActiveAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a) == null));
+        assertTrue(newInitializingAllocationIds.stream().noneMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).inSync));
+        assertTrue(removingInitializingAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a) == null));
 
         /*
          * Now we will add an allocation ID to each of active and initializing and ensure they propagate through. Using different lengths
          * than we have been using above ensures that we can not collide with a previous allocation ID
          */
-        newActiveAllocationIds.add(randomAlphaOfLength(32));
         newInitializingAllocationIds.add(randomAlphaOfLength(64));
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion + 2, newActiveAllocationIds, newInitializingAllocationIds);
-        assertTrue(newActiveAllocationIds.stream().allMatch(a -> tracker.inSyncLocalCheckpoints.containsKey(a)));
+        tracker.updateFromMaster(initialClusterStateVersion + 2, newActiveAllocationIds, newInitializingAllocationIds, emptySet());
+        assertTrue(newActiveAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).inSync));
         assertTrue(
                 newActiveAllocationIds
                         .stream()
-                        .allMatch(a -> tracker.inSyncLocalCheckpoints.get(a) == SequenceNumbersService.UNASSIGNED_SEQ_NO));
-        assertTrue(newInitializingAllocationIds.stream().allMatch(a -> tracker.trackingLocalCheckpoints.containsKey(a)));
+                        .filter(a -> a.equals(primaryId) == false)
+                        .allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).getLocalCheckpoint()
+                            == SequenceNumbersService.UNASSIGNED_SEQ_NO));
+        assertTrue(newInitializingAllocationIds.stream().noneMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).inSync));
         assertTrue(
                 newInitializingAllocationIds
                         .stream()
-                        .allMatch(a -> tracker.trackingLocalCheckpoints.get(a) == SequenceNumbersService.UNASSIGNED_SEQ_NO));
+                        .allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a).getLocalCheckpoint()
+                            == SequenceNumbersService.UNASSIGNED_SEQ_NO));
 
         // the tracking allocation IDs should play no role in determining the global checkpoint
         final Map<String, Integer> activeLocalCheckpoints =
@@ -404,12 +425,12 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
                 activeLocalCheckpoints
                         .entrySet()
                         .stream()
-                        .allMatch(e -> tracker.getLocalCheckpointForAllocationId(e.getKey()) == e.getValue()));
+                        .allMatch(e -> tracker.getTrackedLocalCheckpointForShard(e.getKey()).getLocalCheckpoint() == e.getValue()));
         assertTrue(
                 initializingLocalCheckpoints
                         .entrySet()
                         .stream()
-                        .allMatch(e -> tracker.trackingLocalCheckpoints.get(e.getKey()) == e.getValue()));
+                        .allMatch(e -> tracker.getTrackedLocalCheckpointForShard(e.getKey()).getLocalCheckpoint() == e.getValue()));
         final long minimumActiveLocalCheckpoint = (long) activeLocalCheckpoints.values().stream().min(Integer::compareTo).get();
         assertThat(tracker.getGlobalCheckpoint(), equalTo(minimumActiveLocalCheckpoint));
         final long minimumInitializingLocalCheckpoint = (long) initializingLocalCheckpoints.values().stream().min(Integer::compareTo).get();
@@ -421,7 +442,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         // using a different length than we have been using above ensures that we can not collide with a previous allocation ID
         final String newSyncingAllocationId = randomAlphaOfLength(128);
         newInitializingAllocationIds.add(newSyncingAllocationId);
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion + 3, newActiveAllocationIds, newInitializingAllocationIds);
+        tracker.updateFromMaster(initialClusterStateVersion + 3, newActiveAllocationIds, newInitializingAllocationIds, emptySet());
         final CyclicBarrier barrier = new CyclicBarrier(2);
         final Thread thread = new Thread(() -> {
             try {
@@ -439,7 +460,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         assertBusy(() -> {
             assertTrue(tracker.pendingInSync.contains(newSyncingAllocationId));
-            assertTrue(tracker.trackingLocalCheckpoints.containsKey(newSyncingAllocationId));
+            assertFalse(tracker.getTrackedLocalCheckpointForShard(newSyncingAllocationId).inSync);
         });
 
         tracker.updateLocalCheckpoint(newSyncingAllocationId, randomIntBetween(Math.toIntExact(minimumActiveLocalCheckpoint), 1024));
@@ -447,17 +468,16 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         barrier.await();
 
         assertFalse(tracker.pendingInSync.contains(newSyncingAllocationId));
-        assertFalse(tracker.trackingLocalCheckpoints.containsKey(newSyncingAllocationId));
-        assertTrue(tracker.inSyncLocalCheckpoints.containsKey(newSyncingAllocationId));
+        assertTrue(tracker.getTrackedLocalCheckpointForShard(newSyncingAllocationId).inSync);
 
         /*
          * The new in-sync allocation ID is in the in-sync set now yet the master does not know this; the allocation ID should still be in
          * the in-sync set even if we receive a cluster state update that does not reflect this.
          *
          */
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion + 4, newActiveAllocationIds, newInitializingAllocationIds);
-        assertFalse(tracker.trackingLocalCheckpoints.containsKey(newSyncingAllocationId));
-        assertTrue(tracker.inSyncLocalCheckpoints.containsKey(newSyncingAllocationId));
+        tracker.updateFromMaster(initialClusterStateVersion + 4, newActiveAllocationIds, newInitializingAllocationIds, emptySet());
+        assertTrue(tracker.getTrackedLocalCheckpointForShard(newSyncingAllocationId).inSync);
+        assertFalse(tracker.pendingInSync.contains(newSyncingAllocationId));
     }
 
     /**
@@ -476,12 +496,11 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         final String active = randomAlphaOfLength(16);
         final String initializing = randomAlphaOfLength(32);
-        tracker.updateAllocationIdsFromMaster(randomNonNegativeLong(), Collections.singleton(active), Collections.singleton(initializing));
-
         final CyclicBarrier barrier = new CyclicBarrier(4);
 
         final int activeLocalCheckpoint = randomIntBetween(0, Integer.MAX_VALUE - 1);
-        tracker.updateLocalCheckpoint(active, activeLocalCheckpoint);
+        tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(active), Collections.singleton(initializing), emptySet());
+        tracker.activatePrimaryMode(active, activeLocalCheckpoint);
         final int nextActiveLocalCheckpoint = randomIntBetween(activeLocalCheckpoint + 1, Integer.MAX_VALUE);
         final Thread activeThread = new Thread(() -> {
             try {
@@ -523,205 +542,194 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         assertThat(tracker.getGlobalCheckpoint(), equalTo((long) nextActiveLocalCheckpoint));
     }
 
-    public void testPrimaryContextOlderThanAppliedClusterState() {
-        final long initialClusterStateVersion = randomIntBetween(0, Integer.MAX_VALUE - 1) + 1;
-        final int numberOfActiveAllocationsIds = randomIntBetween(0, 8);
-        final int numberOfInitializingIds = randomIntBetween(0, 8);
-        final Tuple<Set<String>, Set<String>> activeAndInitializingAllocationIds =
-                randomActiveAndInitializingAllocationIds(numberOfActiveAllocationsIds, numberOfInitializingIds);
-        final Set<String> activeAllocationIds = activeAndInitializingAllocationIds.v1();
-        final Set<String> initializingAllocationIds = activeAndInitializingAllocationIds.v2();
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion, activeAllocationIds, initializingAllocationIds);
+    public void testPrimaryContextHandoff() throws IOException {
+        GlobalCheckpointTracker oldPrimary = new GlobalCheckpointTracker(new ShardId("test", "_na_", 0),
+            IndexSettingsModule.newIndexSettings("test", Settings.EMPTY), UNASSIGNED_SEQ_NO);
+        GlobalCheckpointTracker newPrimary = new GlobalCheckpointTracker(new ShardId("test", "_na_", 0),
+            IndexSettingsModule.newIndexSettings("test", Settings.EMPTY), UNASSIGNED_SEQ_NO);
 
-        /*
-         * We are going to establish a primary context from a cluster state version older than the applied cluster state version on the
-         * tracker. Because of recovery barriers established during relocation handoff, we know that the set of active allocation IDs in the
-         * newer cluster state is a superset of the allocation IDs in the applied cluster state with the caveat that an existing
-         * initializing allocation ID could have moved to an in-sync allocation ID within the tracker due to recovery finalization, and the
-         * set of initializing allocation IDs is otherwise arbitrary.
-         */
-        final int numberOfAdditionalInitializingAllocationIds = randomIntBetween(0, 8);
-        final Set<String> initializedAllocationIds = new HashSet<>(randomSubsetOf(initializingAllocationIds));
-        final Set<String> newInitializingAllocationIds =
-                randomAllocationIdsExcludingExistingIds(
-                        Sets.union(activeAllocationIds, initializingAllocationIds), numberOfAdditionalInitializingAllocationIds);
-        final Set<String> contextInitializingIds = Sets.union(
-                new HashSet<>(randomSubsetOf(Sets.difference(initializingAllocationIds, initializedAllocationIds))),
-                newInitializingAllocationIds);
-
-        final int numberOfAdditionalActiveAllocationIds = randomIntBetween(0, 8);
-        final Set<String> contextActiveAllocationIds = Sets.union(
-                Sets.union(
-                        activeAllocationIds,
-                        randomAllocationIdsExcludingExistingIds(activeAllocationIds, numberOfAdditionalActiveAllocationIds)),
-                initializedAllocationIds);
-
-        final ObjectLongMap<String> activeAllocationIdsLocalCheckpoints = new ObjectLongHashMap<>();
-        for (final String allocationId : contextActiveAllocationIds) {
-            activeAllocationIdsLocalCheckpoints.put(allocationId, randomNonNegativeLong());
-        }
-        final ObjectLongMap<String> initializingAllocationIdsLocalCheckpoints = new ObjectLongHashMap<>();
-        for (final String allocationId : contextInitializingIds) {
-            initializingAllocationIdsLocalCheckpoints.put(allocationId, randomNonNegativeLong());
+        FakeClusterState clusterState = initialState();
+        clusterState.apply(oldPrimary);
+        clusterState.apply(newPrimary);
+
+        activatePrimary(clusterState, oldPrimary);
+
+        final int numUpdates = randomInt(10);
+        for (int i = 0; i < numUpdates; i++) {
+            if (rarely()) {
+                clusterState = randomUpdateClusterState(clusterState);
+                clusterState.apply(oldPrimary);
+                clusterState.apply(newPrimary);
+            }
+            if (randomBoolean()) {
+                randomLocalCheckpointUpdate(oldPrimary);
+            }
+            if (randomBoolean()) {
+                randomMarkInSync(oldPrimary);
+            }
         }
 
-        final PrimaryContext primaryContext = new PrimaryContext(
-                initialClusterStateVersion - randomIntBetween(0, Math.toIntExact(initialClusterStateVersion) - 1),
-                activeAllocationIdsLocalCheckpoints,
-                initializingAllocationIdsLocalCheckpoints);
-
-        tracker.updateAllocationIdsFromPrimaryContext(primaryContext);
-
-        // the primary context carries an older cluster state version
-        assertThat(tracker.appliedClusterStateVersion, equalTo(initialClusterStateVersion));
-
-        // only existing active allocation IDs and initializing allocation IDs that moved to initialized should be in-sync
-        assertThat(
-                Sets.union(activeAllocationIds, initializedAllocationIds),
-                equalTo(
-                        StreamSupport
-                                .stream(tracker.inSyncLocalCheckpoints.keys().spliterator(), false)
-                                .map(e -> e.value)
-                                .collect(Collectors.toSet())));
-
-        // the local checkpoints known to the tracker for in-sync shards should match what is known in the primary context
-        for (final String allocationId : Sets.union(activeAllocationIds, initializedAllocationIds)) {
-            assertThat(
-                    tracker.inSyncLocalCheckpoints.get(allocationId), equalTo(primaryContext.inSyncLocalCheckpoints().get(allocationId)));
+        GlobalCheckpointTracker.PrimaryContext primaryContext = oldPrimary.startRelocationHandoff();
+
+        if (randomBoolean()) {
+            // cluster state update after primary context handoff
+            if (randomBoolean()) {
+                clusterState = randomUpdateClusterState(clusterState);
+                clusterState.apply(oldPrimary);
+                clusterState.apply(newPrimary);
+            }
+
+            // abort handoff, check that we can continue updates and retry handoff
+            oldPrimary.abortRelocationHandoff();
+
+            if (rarely()) {
+                clusterState = randomUpdateClusterState(clusterState);
+                clusterState.apply(oldPrimary);
+                clusterState.apply(newPrimary);
+            }
+            if (randomBoolean()) {
+                randomLocalCheckpointUpdate(oldPrimary);
+            }
+            if (randomBoolean()) {
+                randomMarkInSync(oldPrimary);
+            }
+
+            // do another handoff
+            primaryContext = oldPrimary.startRelocationHandoff();
         }
 
-        // only existing initializing allocation IDs that did not moved to initialized should be tracked
-        assertThat(
-                Sets.difference(initializingAllocationIds, initializedAllocationIds),
-                equalTo(
-                        StreamSupport
-                        .stream(tracker.trackingLocalCheckpoints.keys().spliterator(), false)
-                        .map(e -> e.value)
-                        .collect(Collectors.toSet())));
-
-        // the local checkpoints known to the tracker for initializing shards should match what is known in the primary context
-        for (final String allocationId : Sets.difference(initializingAllocationIds, initializedAllocationIds)) {
-            if (primaryContext.trackingLocalCheckpoints().containsKey(allocationId)) {
-                assertThat(
-                        tracker.trackingLocalCheckpoints.get(allocationId),
-                        equalTo(primaryContext.trackingLocalCheckpoints().get(allocationId)));
-            } else {
-                assertThat(tracker.trackingLocalCheckpoints.get(allocationId), equalTo(SequenceNumbersService.UNASSIGNED_SEQ_NO));
+        // send primary context through the wire
+        BytesStreamOutput output = new BytesStreamOutput();
+        primaryContext.writeTo(output);
+        StreamInput streamInput = output.bytes().streamInput();
+        primaryContext = new GlobalCheckpointTracker.PrimaryContext(streamInput);
+
+        switch (randomInt(3)) {
+            case 0: {
+                // apply cluster state update on old primary while primary context is being transferred
+                clusterState = randomUpdateClusterState(clusterState);
+                clusterState.apply(oldPrimary);
+                // activate new primary
+                newPrimary.activateWithPrimaryContext(primaryContext);
+                // apply cluster state update on new primary so that the states on old and new primary are comparable
+                clusterState.apply(newPrimary);
+                break;
+            }
+            case 1: {
+                // apply cluster state update on new primary while primary context is being transferred
+                clusterState = randomUpdateClusterState(clusterState);
+                clusterState.apply(newPrimary);
+                // activate new primary
+                newPrimary.activateWithPrimaryContext(primaryContext);
+                // apply cluster state update on old primary so that the states on old and new primary are comparable
+                clusterState.apply(oldPrimary);
+                break;
+            }
+            case 2: {
+                // apply cluster state update on both copies while primary context is being transferred
+                clusterState = randomUpdateClusterState(clusterState);
+                clusterState.apply(oldPrimary);
+                clusterState.apply(newPrimary);
+                newPrimary.activateWithPrimaryContext(primaryContext);
+                break;
+            }
+            case 3: {
+                // no cluster state update
+                newPrimary.activateWithPrimaryContext(primaryContext);
+                break;
             }
         }
 
-        // the global checkpoint can only be computed from active allocation IDs and initializing allocation IDs that moved to initializing
-        final long globalCheckpoint =
-                StreamSupport
-                        .stream(activeAllocationIdsLocalCheckpoints.spliterator(), false)
-                        .filter(e -> tracker.inSyncLocalCheckpoints.containsKey(e.key) || initializedAllocationIds.contains(e.key))
-                        .mapToLong(e -> e.value)
-                        .min()
-                        .orElse(SequenceNumbersService.UNASSIGNED_SEQ_NO);
-        assertThat(tracker.getGlobalCheckpoint(), equalTo(globalCheckpoint));
+        assertTrue(oldPrimary.primaryMode);
+        assertTrue(newPrimary.primaryMode);
+        assertThat(newPrimary.appliedClusterStateVersion, equalTo(oldPrimary.appliedClusterStateVersion));
+        assertThat(newPrimary.localCheckpoints, equalTo(oldPrimary.localCheckpoints));
+        assertThat(newPrimary.globalCheckpoint, equalTo(oldPrimary.globalCheckpoint));
+
+        oldPrimary.completeRelocationHandoff();
+        assertFalse(oldPrimary.primaryMode);
     }
 
-    public void testPrimaryContextNewerThanAppliedClusterState() {
-        final long initialClusterStateVersion = randomIntBetween(0, Integer.MAX_VALUE);
-        final int numberOfActiveAllocationsIds = randomIntBetween(0, 8);
-        final int numberOfInitializingIds = randomIntBetween(0, 8);
-        final Tuple<Set<String>, Set<String>> activeAndInitializingAllocationIds =
-                randomActiveAndInitializingAllocationIds(numberOfActiveAllocationsIds, numberOfInitializingIds);
-        final Set<String> activeAllocationIds = activeAndInitializingAllocationIds.v1();
-        final Set<String> initializingAllocationIds = activeAndInitializingAllocationIds.v2();
-        tracker.updateAllocationIdsFromMaster(initialClusterStateVersion, activeAllocationIds, initializingAllocationIds);
+    public void testIllegalStateExceptionIfUnknownAllocationId() {
+        final String active = randomAlphaOfLength(16);
+        final String initializing = randomAlphaOfLength(32);
+        tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(active), Collections.singleton(initializing), emptySet());
+        tracker.activatePrimaryMode(active, NO_OPS_PERFORMED);
 
-        /*
-         * We are going to establish a primary context from a cluster state version older than the applied cluster state version on the
-         * tracker. Because of recovery barriers established during relocation handoff, we know that the set of active allocation IDs in the
-         * newer cluster state is a subset of the allocation IDs in the applied cluster state with the caveat that an existing initializing
-         * allocation ID could have moved to an in-sync allocation ID within the tracker due to recovery finalization, and the set of
-         * initializing allocation IDs is otherwise arbitrary.
-         */
-        final int numberOfNewInitializingAllocationIds = randomIntBetween(0, 8);
-        final Set<String> initializedAllocationIds = new HashSet<>(randomSubsetOf(initializingAllocationIds));
-        final Set<String> newInitializingAllocationIds =
-                randomAllocationIdsExcludingExistingIds(
-                        Sets.union(activeAllocationIds, initializingAllocationIds), numberOfNewInitializingAllocationIds);
+        expectThrows(IllegalStateException.class, () -> tracker.initiateTracking(randomAlphaOfLength(10)));
+        expectThrows(IllegalStateException.class, () -> tracker.markAllocationIdAsInSync(randomAlphaOfLength(10), randomNonNegativeLong()));
+    }
 
-        final ObjectLongMap<String> activeAllocationIdsLocalCheckpoints = new ObjectLongHashMap<>();
-        for (final String allocationId : Sets.union(new HashSet<>(randomSubsetOf(activeAllocationIds)), initializedAllocationIds)) {
-            activeAllocationIdsLocalCheckpoints.put(allocationId, randomNonNegativeLong());
-        }
-        final ObjectLongMap<String> initializingIdsLocalCheckpoints = new ObjectLongHashMap<>();
-        final Set<String> contextInitializingAllocationIds = Sets.union(
-                new HashSet<>(randomSubsetOf(Sets.difference(initializingAllocationIds, initializedAllocationIds))),
-                newInitializingAllocationIds);
-        for (final String allocationId : contextInitializingAllocationIds) {
-            initializingIdsLocalCheckpoints.put(allocationId, randomNonNegativeLong());
-        }
+    private static class FakeClusterState {
+        final long version;
+        final Set<String> inSyncIds;
+        final Set<String> initializingIds;
 
-        final PrimaryContext primaryContext =
-                new PrimaryContext(
-                        initialClusterStateVersion + randomIntBetween(0, Integer.MAX_VALUE) + 1,
-                        activeAllocationIdsLocalCheckpoints,
-                        initializingIdsLocalCheckpoints);
+        private FakeClusterState(long version, Set<String> inSyncIds, Set<String> initializingIds) {
+            this.version = version;
+            this.inSyncIds = Collections.unmodifiableSet(inSyncIds);
+            this.initializingIds = Collections.unmodifiableSet(initializingIds);
+        }
 
-        tracker.updateAllocationIdsFromPrimaryContext(primaryContext);
+        public Set<String> allIds() {
+            return Sets.union(initializingIds, inSyncIds);
+        }
 
-        final PrimaryContext trackerPrimaryContext = tracker.primaryContext();
-        try {
-            assertTrue(tracker.sealed());
-            final long globalCheckpoint =
-                    StreamSupport
-                            .stream(activeAllocationIdsLocalCheckpoints.values().spliterator(), false)
-                            .mapToLong(e -> e.value)
-                            .min()
-                            .orElse(SequenceNumbersService.UNASSIGNED_SEQ_NO);
-
-            // the primary context contains knowledge of the state of the entire universe
-            assertThat(primaryContext.clusterStateVersion(), equalTo(trackerPrimaryContext.clusterStateVersion()));
-            assertThat(primaryContext.inSyncLocalCheckpoints(), equalTo(trackerPrimaryContext.inSyncLocalCheckpoints()));
-            assertThat(primaryContext.trackingLocalCheckpoints(), equalTo(trackerPrimaryContext.trackingLocalCheckpoints()));
-            assertThat(tracker.getGlobalCheckpoint(), equalTo(globalCheckpoint));
-        } finally {
-            tracker.releasePrimaryContext();
-            assertFalse(tracker.sealed());
+        public void apply(GlobalCheckpointTracker gcp) {
+            gcp.updateFromMaster(version, inSyncIds, initializingIds, Collections.emptySet());
         }
     }
 
-    public void testPrimaryContextSealing() {
-        // the tracker should start in the state of not being sealed
-        assertFalse(tracker.sealed());
+    private static FakeClusterState initialState() {
+        final long initialClusterStateVersion = randomIntBetween(1, Integer.MAX_VALUE);
+        final int numberOfActiveAllocationsIds = randomIntBetween(1, 8);
+        final int numberOfInitializingIds = randomIntBetween(0, 8);
+        final Tuple<Set<String>, Set<String>> activeAndInitializingAllocationIds =
+            randomActiveAndInitializingAllocationIds(numberOfActiveAllocationsIds, numberOfInitializingIds);
+        final Set<String> activeAllocationIds = activeAndInitializingAllocationIds.v1();
+        final Set<String> initializingAllocationIds = activeAndInitializingAllocationIds.v2();
+        return new FakeClusterState(initialClusterStateVersion, activeAllocationIds, initializingAllocationIds);
+    }
+
+    private static void activatePrimary(FakeClusterState clusterState, GlobalCheckpointTracker gcp) {
+        gcp.activatePrimaryMode(randomFrom(clusterState.inSyncIds), randomIntBetween(Math.toIntExact(NO_OPS_PERFORMED), 10));
+    }
 
-        // sampling the primary context should seal the tracker
-        tracker.primaryContext();
-        assertTrue(tracker.sealed());
+    private static void randomLocalCheckpointUpdate(GlobalCheckpointTracker gcp) {
+        String allocationId = randomFrom(gcp.localCheckpoints.keySet());
+        long currentLocalCheckpoint = gcp.localCheckpoints.get(allocationId).getLocalCheckpoint();
+        gcp.updateLocalCheckpoint(allocationId, currentLocalCheckpoint + randomInt(5));
+    }
 
-        /*
-         * Invoking methods that mutates the state of the tracker should fail (with the exception of updating allocation IDs and updating
-         * global checkpoint on replica which can happen on the relocation source).
-         */
-        assertIllegalStateExceptionWhenSealed(() -> tracker.updateLocalCheckpoint(randomAlphaOfLength(16), randomNonNegativeLong()));
-        assertIllegalStateExceptionWhenSealed(() -> tracker.updateAllocationIdsFromPrimaryContext(mock(PrimaryContext.class)));
-        assertIllegalStateExceptionWhenSealed(() -> tracker.primaryContext());
-        assertIllegalStateExceptionWhenSealed(() -> tracker.markAllocationIdAsInSync(randomAlphaOfLength(16), randomNonNegativeLong()));
-
-        // closing the releasable should unseal the tracker
-        tracker.releasePrimaryContext();
-        assertFalse(tracker.sealed());
+    private static void randomMarkInSync(GlobalCheckpointTracker gcp) {
+        String allocationId = randomFrom(gcp.localCheckpoints.keySet());
+        long newLocalCheckpoint = Math.max(NO_OPS_PERFORMED, gcp.getGlobalCheckpoint() + randomInt(5));
+        markAllocationIdAsInSyncQuietly(gcp, allocationId, newLocalCheckpoint);
     }
 
-    private void assertIllegalStateExceptionWhenSealed(final ThrowingRunnable runnable) {
-        final IllegalStateException e = expectThrows(IllegalStateException.class, runnable);
-        assertThat(e, hasToString(containsString("global checkpoint tracker is sealed")));
+    private static FakeClusterState randomUpdateClusterState(FakeClusterState clusterState) {
+        final Set<String> initializingIdsToAdd = randomAllocationIdsExcludingExistingIds(clusterState.allIds(), randomInt(2));
+        final Set<String> initializingIdsToRemove = new HashSet<>(
+            randomSubsetOf(randomInt(clusterState.initializingIds.size()), clusterState.initializingIds));
+        final Set<String> inSyncIdsToRemove = new HashSet<>(
+            randomSubsetOf(randomInt(clusterState.inSyncIds.size()), clusterState.inSyncIds));
+        final Set<String> remainingInSyncIds = Sets.difference(clusterState.inSyncIds, inSyncIdsToRemove);
+        return new FakeClusterState(clusterState.version + randomIntBetween(1, 5),
+            remainingInSyncIds.isEmpty() ? clusterState.inSyncIds : remainingInSyncIds,
+            Sets.difference(Sets.union(clusterState.initializingIds, initializingIdsToAdd), initializingIdsToRemove));
     }
 
-    private Tuple<Set<String>, Set<String>> randomActiveAndInitializingAllocationIds(
+    private static Tuple<Set<String>, Set<String>> randomActiveAndInitializingAllocationIds(
             final int numberOfActiveAllocationsIds,
             final int numberOfInitializingIds) {
         final Set<String> activeAllocationIds =
-                IntStream.range(0, numberOfActiveAllocationsIds).mapToObj(i -> randomAlphaOfLength(16) + i).collect(Collectors.toSet());
+            IntStream.range(0, numberOfActiveAllocationsIds).mapToObj(i -> randomAlphaOfLength(16) + i).collect(Collectors.toSet());
         final Set<String> initializingIds = randomAllocationIdsExcludingExistingIds(activeAllocationIds, numberOfInitializingIds);
         return Tuple.tuple(activeAllocationIds, initializingIds);
     }
 
-    private Set<String> randomAllocationIdsExcludingExistingIds(final Set<String> existingAllocationIds, final int numberOfAllocationIds) {
+    private static Set<String> randomAllocationIdsExcludingExistingIds(final Set<String> existingAllocationIds,
+                                                                       final int numberOfAllocationIds) {
         return IntStream.range(0, numberOfAllocationIds).mapToObj(i -> {
             do {
                 final String newAllocationId = randomAlphaOfLength(16);
@@ -733,7 +741,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         }).collect(Collectors.toSet());
     }
 
-    private void markAllocationIdAsInSyncQuietly(
+    private static void markAllocationIdAsInSyncQuietly(
             final GlobalCheckpointTracker tracker, final String allocationId, final long localCheckpoint) {
         try {
             tracker.markAllocationIdAsInSync(allocationId, localCheckpoint);

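The randomized helpers above drive the tracker through arbitrary in-sync and initializing sets; the invariant they ultimately exercise is that the global checkpoint is the minimum of the local checkpoints of all in-sync copies and never moves backwards. A minimal Java sketch of that computation, illustrative only and not the production GlobalCheckpointTracker:

import java.util.Map;

final class GlobalCheckpointInvariantSketch {

    // Illustrative only: the global checkpoint is the minimum of the in-sync local checkpoints,
    // clamped so that it never regresses below the value already published.
    static long computeGlobalCheckpoint(final Map<String, Long> inSyncLocalCheckpoints, final long currentGlobalCheckpoint) {
        long min = Long.MAX_VALUE;
        for (final long localCheckpoint : inSyncLocalCheckpoints.values()) {
            min = Math.min(min, localCheckpoint);
        }
        if (min == Long.MAX_VALUE) {
            return currentGlobalCheckpoint; // no in-sync copies known yet, keep the current value
        }
        return Math.max(currentGlobalCheckpoint, min);
    }
}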
+ 39 - 10
core/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java

@@ -347,7 +347,7 @@ public class IndexShardTests extends IndexShardTestCase {
                         ShardRoutingState.STARTED,
                         replicaRouting.allocationId());
         indexShard.updateShardState(primaryRouting, indexShard.getPrimaryTerm() + 1, (shard, listener) -> {},
-            0L, Collections.emptySet(), Collections.emptySet());
+            0L, Collections.singleton(primaryRouting.allocationId().getId()), Collections.emptySet(), Collections.emptySet());
 
         final int delayedOperations = scaledRandomIntBetween(1, 64);
         final CyclicBarrier delayedOperationsBarrier = new CyclicBarrier(1 + delayedOperations);
@@ -422,7 +422,7 @@ public class IndexShardTests extends IndexShardTestCase {
                         ShardRoutingState.STARTED,
                         replicaRouting.allocationId());
         indexShard.updateShardState(primaryRouting, indexShard.getPrimaryTerm() + 1, (shard, listener) -> {},
-            0L, Collections.emptySet(), Collections.emptySet());
+            0L, Collections.singleton(primaryRouting.allocationId().getId()), Collections.emptySet(), Collections.emptySet());
 
         /*
          * This operation completing means that the delay operation executed as part of increasing the primary term has completed and the
@@ -463,8 +463,8 @@ public class IndexShardTests extends IndexShardTestCase {
             ShardRouting replicaRouting = indexShard.routingEntry();
             ShardRouting primaryRouting = TestShardRouting.newShardRouting(replicaRouting.shardId(), replicaRouting.currentNodeId(), null,
                 true, ShardRoutingState.STARTED, replicaRouting.allocationId());
-            indexShard.updateShardState(primaryRouting, indexShard.getPrimaryTerm() + 1, (shard, listener) -> {},
-                0L, Collections.emptySet(), Collections.emptySet());
+            indexShard.updateShardState(primaryRouting, indexShard.getPrimaryTerm() + 1, (shard, listener) -> {}, 0L,
+                Collections.singleton(indexShard.routingEntry().allocationId().getId()), Collections.emptySet(), Collections.emptySet());
         } else {
             indexShard = newStartedShard(true);
         }
@@ -598,11 +598,11 @@ public class IndexShardTests extends IndexShardTestCase {
             final long newPrimaryTerm = primaryTerm + 1 + randomInt(20);
             if (engineClosed == false) {
                 assertThat(indexShard.getLocalCheckpoint(), equalTo(SequenceNumbersService.NO_OPS_PERFORMED));
-                assertThat(indexShard.getGlobalCheckpoint(), equalTo(SequenceNumbersService.UNASSIGNED_SEQ_NO));
+                assertThat(indexShard.getGlobalCheckpoint(), equalTo(SequenceNumbersService.NO_OPS_PERFORMED));
             }
             final long newGlobalCheckPoint;
             if (engineClosed || randomBoolean()) {
-                newGlobalCheckPoint = SequenceNumbersService.UNASSIGNED_SEQ_NO;
+                newGlobalCheckPoint = SequenceNumbersService.NO_OPS_PERFORMED;
             } else {
                 long localCheckPoint = indexShard.getGlobalCheckpoint() + randomInt(100);
                 // advance local checkpoint
@@ -1267,7 +1267,6 @@ public class IndexShardTests extends IndexShardTestCase {
         closeShards(shard);
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/25419")
     public void testRelocatedShardCanNotBeRevivedConcurrently() throws IOException, InterruptedException, BrokenBarrierException {
         final IndexShard shard = newStartedShard(true);
         final ShardRouting originalRouting = shard.routingEntry();
@@ -1321,8 +1320,11 @@ public class IndexShardTests extends IndexShardTestCase {
 
     public void testRecoverFromStore() throws IOException {
         final IndexShard shard = newStartedShard(true);
-        int translogOps = 1;
-        indexDoc(shard, "test", "0");
+        int totalOps = randomInt(10);
+        int translogOps = totalOps;
+        for (int i = 0; i < totalOps; i++) {
+            indexDoc(shard, "test", Integer.toString(i));
+        }
         if (randomBoolean()) {
             flushShard(shard);
             translogOps = 0;
@@ -1336,10 +1338,33 @@ public class IndexShardTests extends IndexShardTestCase {
         assertEquals(translogOps, newShard.recoveryState().getTranslog().totalOperationsOnStart());
         assertEquals(100.0f, newShard.recoveryState().getTranslog().recoveredPercent(), 0.01f);
         IndexShardTestCase.updateRoutingEntry(newShard, newShard.routingEntry().moveToStarted());
-        assertDocCount(newShard, 1);
+        // check that local checkpoint of new primary is properly tracked after recovery
+        assertThat(newShard.getLocalCheckpoint(), equalTo(totalOps - 1L));
+        assertThat(IndexShardTestCase.getEngine(newShard).seqNoService()
+            .getTrackedLocalCheckpointForShard(newShard.routingEntry().allocationId().getId()), equalTo(totalOps - 1L));
+        assertDocCount(newShard, totalOps);
         closeShards(newShard);
     }
 
+    public void testPrimaryHandOffUpdatesLocalCheckpoint() throws IOException {
+        final IndexShard primarySource = newStartedShard(true);
+        int totalOps = randomInt(10);
+        for (int i = 0; i < totalOps; i++) {
+            indexDoc(primarySource, "test", Integer.toString(i));
+        }
+        IndexShardTestCase.updateRoutingEntry(primarySource, primarySource.routingEntry().relocate("n2", -1));
+        final IndexShard primaryTarget = newShard(primarySource.routingEntry().getTargetRelocatingShard());
+        updateMappings(primaryTarget, primarySource.indexSettings().getIndexMetaData());
+        recoverReplica(primaryTarget, primarySource);
+
+        // check that local checkpoint of new primary is properly tracked after primary relocation
+        assertThat(primaryTarget.getLocalCheckpoint(), equalTo(totalOps - 1L));
+        assertThat(IndexShardTestCase.getEngine(primaryTarget).seqNoService()
+            .getTrackedLocalCheckpointForShard(primaryTarget.routingEntry().allocationId().getId()), equalTo(totalOps - 1L));
+        assertDocCount(primaryTarget, totalOps);
+        closeShards(primarySource, primaryTarget);
+    }
+
     /* This test just verifies that we fill up the local checkpoint up to the max seen sequence ID on primary recovery */
     public void testRecoverFromStoreWithNoOps() throws IOException {
         final IndexShard shard = newStartedShard(true);
@@ -1888,6 +1913,10 @@ public class IndexShardTests extends IndexShardTestCase {
                 }
             }
             IndexShardTestCase.updateRoutingEntry(targetShard, ShardRoutingHelper.moveToStarted(targetShard.routingEntry()));
+            // check that local checkpoint of new primary is properly tracked after recovery
+            assertThat(targetShard.getLocalCheckpoint(), equalTo(1L));
+            assertThat(IndexShardTestCase.getEngine(targetShard).seqNoService()
+                .getTrackedLocalCheckpointForShard(targetShard.routingEntry().allocationId().getId()), equalTo(1L));
             assertDocCount(targetShard, 2);
         }
         // now check that it's persistent ie. that the added shards are committed

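The recovery and handoff tests above assert a tracked local checkpoint of totalOps - 1. That value follows from the definition of the local checkpoint: the highest sequence number such that it and every lower sequence number has been processed, starting from NO_OPS_PERFORMED (-1) on an empty shard. A small self-contained sketch of that definition (illustrative, not the production code):

import java.util.Set;

final class LocalCheckpointSketch {

    // Illustrative only: processing ops 0..totalOps-1 yields a local checkpoint of totalOps - 1,
    // and an empty set yields -1 (NO_OPS_PERFORMED), matching the assertions in the tests above.
    static long localCheckpoint(final Set<Long> processedSeqNos) {
        long checkpoint = -1L;
        while (processedSeqNos.contains(checkpoint + 1)) {
            checkpoint++;
        }
        return checkpoint;
    }
}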
+ 1 - 1
core/src/test/java/org/elasticsearch/index/shard/PrimaryReplicaSyncerTests.java

@@ -63,7 +63,7 @@ public class PrimaryReplicaSyncerTests extends IndexShardTestCase {
 
         String allocationId = shard.routingEntry().allocationId().getId();
         shard.updateShardState(shard.routingEntry(), shard.getPrimaryTerm(), null, 1000L, Collections.singleton(allocationId),
-            Collections.emptySet());
+            Collections.emptySet(), Collections.emptySet());
         shard.updateLocalCheckpointForShard(allocationId, globalCheckPoint);
         assertEquals(globalCheckPoint, shard.getGlobalCheckpoint());
 

+ 9 - 9
core/src/test/java/org/elasticsearch/indices/cluster/AbstractIndicesClusterStateServiceTestCase.java

@@ -35,7 +35,6 @@ import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.shard.IndexEventListener;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.IndexShardState;
-import org.elasticsearch.index.shard.PrimaryReplicaSyncer;
 import org.elasticsearch.index.shard.PrimaryReplicaSyncer.ResyncTask;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.IndicesService;
@@ -141,12 +140,12 @@ public abstract class AbstractIndicesClusterStateServiceTestCase extends ESTestC
 
                         if (shard.routingEntry().primary() && shard.routingEntry().active()) {
                             IndexShardRoutingTable shardRoutingTable = state.routingTable().shardRoutingTable(shard.shardId());
-                            Set<String> activeIds = shardRoutingTable.activeShards().stream()
-                                .map(r -> r.allocationId().getId()).collect(Collectors.toSet());
+                            Set<String> inSyncIds = state.metaData().index(shard.shardId().getIndex())
+                                .inSyncAllocationIds(shard.shardId().id());
                             Set<String> initializingIds = shardRoutingTable.getAllInitializingShards().stream()
                                 .map(r -> r.allocationId().getId()).collect(Collectors.toSet());
-                            assertThat(shard.routingEntry() + " isn't updated with active aIDs", shard.activeAllocationIds,
-                                equalTo(activeIds));
+                            assertThat(shard.routingEntry() + " isn't updated with in-sync aIDs", shard.inSyncAllocationIds,
+                                equalTo(inSyncIds));
                             assertThat(shard.routingEntry() + " isn't updated with init aIDs", shard.initializingAllocationIds,
                                 equalTo(initializingIds));
                         }
@@ -326,7 +325,7 @@ public abstract class AbstractIndicesClusterStateServiceTestCase extends ESTestC
         private volatile long clusterStateVersion;
         private volatile ShardRouting shardRouting;
         private volatile RecoveryState recoveryState;
-        private volatile Set<String> activeAllocationIds;
+        private volatile Set<String> inSyncAllocationIds;
         private volatile Set<String> initializingAllocationIds;
         private volatile long term;
 
@@ -350,8 +349,9 @@ public abstract class AbstractIndicesClusterStateServiceTestCase extends ESTestC
                                      long newPrimaryTerm,
                                      CheckedBiConsumer<IndexShard, ActionListener<ResyncTask>, IOException> primaryReplicaSyncer,
                                      long applyingClusterStateVersion,
-                                     Set<String> activeAllocationIds,
-                                     Set<String> initializingAllocationIds) throws IOException {
+                                     Set<String> inSyncAllocationIds,
+                                     Set<String> initializingAllocationIds,
+                                     Set<String> pre60AllocationIds) throws IOException {
             failRandomly();
             assertThat(this.shardId(), equalTo(shardRouting.shardId()));
             assertTrue("current: " + this.shardRouting + ", got: " + shardRouting, this.shardRouting.isSameAllocation(shardRouting));
@@ -363,7 +363,7 @@ public abstract class AbstractIndicesClusterStateServiceTestCase extends ESTestC
             if (shardRouting.primary()) {
                 term = newPrimaryTerm;
                 this.clusterStateVersion = applyingClusterStateVersion;
-                this.activeAllocationIds = activeAllocationIds;
+                this.inSyncAllocationIds = inSyncAllocationIds;
                 this.initializingAllocationIds = initializingAllocationIds;
             }
         }

+ 6 - 1
core/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java

@@ -86,6 +86,7 @@ import static java.util.Collections.emptySet;
 import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -378,6 +379,10 @@ public class RecoverySourceHandlerTests extends ESTestCase {
         when(shard.acquireTranslogView()).thenReturn(translogView);
         when(shard.state()).thenReturn(IndexShardState.RELOCATED);
         when(shard.acquireIndexCommit(anyBoolean())).thenReturn(mock(Engine.IndexCommitRef.class));
+        doAnswer(invocation -> {
+            ((ActionListener<Releasable>)invocation.getArguments()[0]).onResponse(() -> {});
+            return null;
+        }).when(shard).acquirePrimaryOperationPermit(any(), anyString());
         final AtomicBoolean phase1Called = new AtomicBoolean();
 //        final Engine.IndexCommitRef indexCommitRef = mock(Engine.IndexCommitRef.class);
 //        when(shard.acquireIndexCommit(anyBoolean())).thenReturn(indexCommitRef);
@@ -420,7 +425,7 @@ public class RecoverySourceHandlerTests extends ESTestCase {
         expectThrows(IndexShardRelocatedException.class, handler::recoverToTarget);
         // phase1 should only be attempted if we are not doing a sequence-number-based recovery
         assertThat(phase1Called.get(), equalTo(!isTranslogReadyForSequenceNumberBasedRecovery));
-        assertTrue(prepareTargetForTranslogCalled.get());
+        assertFalse(prepareTargetForTranslogCalled.get());
         assertFalse(phase2Called.get());
     }
 

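The stub added above completes the permit listener immediately with a no-op Releasable so the handler under test never blocks waiting for a permit. The same doAnswer pattern applies to any callback-style method; in the sketch below, SomeService and its start(ActionListener) method are hypothetical stand-ins, and only the Mockito calls mirror the diff:

import org.elasticsearch.action.ActionListener;

import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

class CallbackStubSketch {

    // Hypothetical callback-style API used only for illustration.
    interface SomeService {
        void start(ActionListener<Void> listener);
    }

    static SomeService stubbedService() {
        final SomeService someService = mock(SomeService.class);
        doAnswer(invocation -> {
            @SuppressWarnings("unchecked")
            final ActionListener<Void> listener = (ActionListener<Void>) invocation.getArguments()[0];
            listener.onResponse(null); // complete the callback right away so the caller proceeds
            return null;
        }).when(someService).start(any());
        return someService;
    }
}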
+ 0 - 1
core/src/test/java/org/elasticsearch/recovery/FullRollingRestartIT.java

@@ -54,7 +54,6 @@ public class FullRollingRestartIT extends ESIntegTestCase {
         return 1;
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/25420")
     public void testFullRollingRestart() throws Exception {
         Settings settings = Settings.builder().put(ZenDiscovery.JOIN_TIMEOUT_SETTING.getKey(), "30s").build();
         internalCluster().startNode(settings);

+ 45 - 5
test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java

@@ -76,6 +76,7 @@ import org.elasticsearch.test.DummyShardLock;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.junit.Before;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -84,6 +85,7 @@ import java.util.EnumSet;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 
@@ -110,11 +112,13 @@ public abstract class IndexShardTestCase extends ESTestCase {
     };
 
     protected ThreadPool threadPool;
+    private long primaryTerm;
 
     @Override
     public void setUp() throws Exception {
         super.setUp();
         threadPool = new TestThreadPool(getClass().getName());
+        primaryTerm = randomIntBetween(1, 100); // use random but fixed term for creating shards
     }
 
     @Override
@@ -164,7 +168,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
             .build();
         IndexMetaData.Builder metaData = IndexMetaData.builder(shardRouting.getIndexName())
             .settings(settings)
-            .primaryTerm(0, randomIntBetween(1, 100));
+            .primaryTerm(0, primaryTerm);
         return newShard(shardRouting, metaData.build(), listeners);
     }
 
@@ -360,8 +364,15 @@ public abstract class IndexShardTestCase extends ESTestCase {
         updateRoutingEntry(primary, ShardRoutingHelper.moveToStarted(primary.routingEntry()));
     }
 
+    protected static AtomicLong currentClusterStateVersion = new AtomicLong();
+
     public static void updateRoutingEntry(IndexShard shard, ShardRouting shardRouting) throws IOException {
-        shard.updateShardState(shardRouting, shard.getPrimaryTerm(), null, 0L, Collections.emptySet(), Collections.emptySet());
+        Set<String> inSyncIds =
+            shardRouting.active() ? Collections.singleton(shardRouting.allocationId().getId()) : Collections.emptySet();
+        Set<String> initializingIds =
+            shardRouting.initializing() ? Collections.singleton(shardRouting.allocationId().getId()) : Collections.emptySet();
+        shard.updateShardState(shardRouting, shard.getPrimaryTerm(), null, currentClusterStateVersion.incrementAndGet(),
+            inSyncIds, initializingIds, Collections.emptySet());
     }
 
     protected void recoveryEmptyReplica(IndexShard replica) throws IOException {
@@ -387,6 +398,16 @@ public abstract class IndexShardTestCase extends ESTestCase {
             true);
     }
 
+    /** recovers a replica from the given primary **/
+    protected void recoverReplica(final IndexShard replica,
+                                  final IndexShard primary,
+                                  final BiFunction<IndexShard, DiscoveryNode, RecoveryTarget> targetSupplier,
+                                  final boolean markAsRecovering) throws IOException {
+        recoverReplica(replica, primary, targetSupplier, markAsRecovering,
+            Collections.singleton(primary.routingEntry().allocationId().getId()),
+            Collections.singleton(replica.routingEntry().allocationId().getId()));
+    }
+
     /**
      * Recovers a replica from the given primary, allowing the user to supply a custom recovery target. A typical usage of a custom recovery
      * target is to assert things in the various stages of recovery.
@@ -398,7 +419,9 @@ public abstract class IndexShardTestCase extends ESTestCase {
     protected final void recoverReplica(final IndexShard replica,
                                         final IndexShard primary,
                                         final BiFunction<IndexShard, DiscoveryNode, RecoveryTarget> targetSupplier,
-                                        final boolean markAsRecovering) throws IOException {
+                                        final boolean markAsRecovering,
+                                        final Set<String> inSyncIds,
+                                        final Set<String> initializingIds) throws IOException {
         final DiscoveryNode pNode = getFakeDiscoNode(primary.routingEntry().currentNodeId());
         final DiscoveryNode rNode = getFakeDiscoNode(replica.routingEntry().currentNodeId());
         if (markAsRecovering) {
@@ -419,7 +442,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         }
 
         final StartRecoveryRequest request = new StartRecoveryRequest(replica.shardId(), targetAllocationId,
-            pNode, rNode, snapshot, false, 0, startingSeqNo);
+            pNode, rNode, snapshot, replica.routingEntry().primary(), 0, startingSeqNo);
         final RecoverySourceHandler recovery = new RecoverySourceHandler(
             primary,
             recoveryTarget,
@@ -428,9 +451,19 @@ public abstract class IndexShardTestCase extends ESTestCase {
             e -> () -> {},
             (int) ByteSizeUnit.MB.toBytes(1),
             Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), pNode.getName()).build());
+        primary.updateShardState(primary.routingEntry(), primary.getPrimaryTerm(), null, currentClusterStateVersion.incrementAndGet(),
+            inSyncIds, initializingIds, Collections.emptySet());
         recovery.recoverToTarget();
         recoveryTarget.markAsDone();
-        updateRoutingEntry(replica, ShardRoutingHelper.moveToStarted(replica.routingEntry()));
+        Set<String> initializingIdsWithoutReplica = new HashSet<>(initializingIds);
+        initializingIdsWithoutReplica.remove(replica.routingEntry().allocationId().getId());
+        Set<String> inSyncIdsWithReplica = new HashSet<>(inSyncIds);
+        inSyncIdsWithReplica.add(replica.routingEntry().allocationId().getId());
+        // update both primary and replica shard state
+        primary.updateShardState(primary.routingEntry(), primary.getPrimaryTerm(), null, currentClusterStateVersion.incrementAndGet(),
+            inSyncIdsWithReplica, initializingIdsWithoutReplica, Collections.emptySet());
+        replica.updateShardState(replica.routingEntry().moveToStarted(), replica.getPrimaryTerm(), null,
+            currentClusterStateVersion.get(), inSyncIdsWithReplica, initializingIdsWithoutReplica, Collections.emptySet());
     }
 
     private Store.MetadataSnapshot getMetadataSnapshotOrEmpty(IndexShard replica) throws IOException {
@@ -529,4 +562,11 @@ public abstract class IndexShardTestCase extends ESTestCase {
     protected void flushShard(IndexShard shard, boolean force) {
         shard.flush(new FlushRequest(shard.shardId().getIndexName()).force(force));
     }
+
+    /**
+     * Helper method to access (package-protected) engine from tests
+     */
+    public static Engine getEngine(IndexShard indexShard) {
+        return indexShard.getEngine();
+    }
 }
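
For reference, a hedged usage fragment for the getEngine accessor added above, mirroring how IndexShardTests uses it; the primary and replica variables and the asserted relation assume a quiet test with no concurrent indexing:

// inside a test extending IndexShardTestCase, after recoverReplica(replica, primary):
final String replicaAllocationId = replica.routingEntry().allocationId().getId();
final long trackedOnPrimary = IndexShardTestCase.getEngine(primary).seqNoService()
    .getTrackedLocalCheckpointForShard(replicaAllocationId);
// with no concurrent indexing the primary tracks the replica at its own local checkpoint
assertThat(trackedOnPrimary, equalTo(primary.getLocalCheckpoint()));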