1
0
Эх сурвалжийг харах

Add global checkpoint tracking on the primary

This commit adds local tracking of the global checkpoints on all shard
copies when a global checkpoint tracker is operating in primary
mode. With this, we relay the global checkpoint on a shard copy back to
the primary shard during replication operations. This serves as another
step towards adding a background sync of the global checkpoint to the
shard copies.

Relates #26666
Jason Tedor 8 жил өмнө
parent
commit
c238b79cf4
17 өөрчлөгдсөн 544 нэмэгдсэн , 210 устгасан
  1. 1 1
      build.gradle
  2. 2 1
      core/src/main/java/org/elasticsearch/action/resync/TransportResyncReplicationAction.java
  3. 24 2
      core/src/main/java/org/elasticsearch/action/support/replication/ReplicationOperation.java
  4. 28 9
      core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java
  5. 2 1
      core/src/main/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncAction.java
  6. 243 112
      core/src/main/java/org/elasticsearch/index/seqno/GlobalCheckpointTracker.java
  7. 17 2
      core/src/main/java/org/elasticsearch/index/seqno/SequenceNumbersService.java
  8. 23 4
      core/src/main/java/org/elasticsearch/index/shard/IndexShard.java
  9. 3 1
      core/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java
  10. 23 5
      core/src/test/java/org/elasticsearch/action/support/replication/ReplicationOperationTests.java
  11. 2 1
      core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java
  12. 2 1
      core/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java
  13. 2 3
      core/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java
  14. 6 1
      core/src/test/java/org/elasticsearch/index/replication/ESIndexLevelReplicationTestCase.java
  15. 0 1
      core/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java
  16. 151 64
      core/src/test/java/org/elasticsearch/index/seqno/GlobalCheckpointTrackerTests.java
  17. 15 1
      core/src/test/java/org/elasticsearch/recovery/RelocationIT.java

+ 1 - 1
build.gradle

@@ -186,7 +186,7 @@ task verifyVersions {
  * after the backport of the backcompat code is complete.
  */
 allprojects {
-  ext.bwc_tests_enabled = true
+  ext.bwc_tests_enabled = false
 }
 
 task verifyBwcTestsEnabled {

+ 2 - 1
core/src/main/java/org/elasticsearch/action/resync/TransportResyncReplicationAction.java

@@ -93,7 +93,8 @@ public class TransportResyncReplicationAction extends TransportWriteAction<Resyn
         if (node.getVersion().onOrAfter(Version.V_6_0_0_alpha1)) {
             super.sendReplicaRequest(replicaRequest, node, listener);
         } else {
-            listener.onResponse(new ReplicaResponse(SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT));
+            final long pre60NodeCheckpoint = SequenceNumbersService.PRE_60_NODE_CHECKPOINT;
+            listener.onResponse(new ReplicaResponse(pre60NodeCheckpoint, pre60NodeCheckpoint));
         }
     }
 

+ 24 - 2
core/src/main/java/org/elasticsearch/action/support/replication/ReplicationOperation.java

@@ -32,6 +32,7 @@ import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.index.seqno.SequenceNumbersService;
 import org.elasticsearch.index.shard.ReplicationGroup;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.rest.RestStatus;
@@ -173,6 +174,7 @@ public class ReplicationOperation<
                 successfulShards.incrementAndGet();
                 try {
                     primary.updateLocalCheckpointForShard(shard.allocationId().getId(), response.localCheckpoint());
+                    primary.updateGlobalCheckpointForShard(shard.allocationId().getId(), response.globalCheckpoint());
                 } catch (final AlreadyClosedException e) {
                     // okay, the index was deleted or this shard was never activated after a relocation; fall through and finish normally
                 } catch (final Exception e) {
@@ -315,6 +317,14 @@ public class ReplicationOperation<
          */
         void updateLocalCheckpointForShard(String allocationId, long checkpoint);
 
+        /**
+         * Update the local knowledge of the global checkpoint for the specified allocation ID.
+         *
+         * @param allocationId     the allocation ID to update the global checkpoint for
+         * @param globalCheckpoint the global checkpoint
+         */
+        void updateGlobalCheckpointForShard(String allocationId, long globalCheckpoint);
+
         /**
          * Returns the local checkpoint on the primary shard.
          *
@@ -385,12 +395,24 @@ public class ReplicationOperation<
     }
 
     /**
-     * An interface to encapsulate the metadata needed from replica shards when they respond to operations performed on them
+     * An interface to encapsulate the metadata needed from replica shards when they respond to operations performed on them.
      */
     public interface ReplicaResponse {
 
-        /** the local check point for the shard. see {@link org.elasticsearch.index.seqno.SequenceNumbersService#getLocalCheckpoint()} */
+        /**
+         * The local checkpoint for the shard. See {@link SequenceNumbersService#getLocalCheckpoint()}.
+         *
+         * @return the local checkpoint
+         **/
         long localCheckpoint();
+
+        /**
+         * The global checkpoint for the shard. See {@link SequenceNumbersService#getGlobalCheckpoint()}.
+         *
+         * @return the global checkpoint
+         **/
+        long globalCheckpoint();
+
     }
 
     public static class RetryOnPrimaryException extends ElasticsearchException {

+ 28 - 9
core/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java

@@ -531,7 +531,8 @@ public abstract class TransportReplicationAction<
             try {
                 final ReplicaResult replicaResult = shardOperationOnReplica(request, replica);
                 releasable.close(); // release shard operation lock before responding to caller
-                final TransportReplicationAction.ReplicaResponse response = new ReplicaResponse(replica.getLocalCheckpoint());
+                final TransportReplicationAction.ReplicaResponse response =
+                        new ReplicaResponse(replica.getLocalCheckpoint(), replica.getGlobalCheckpoint());
                 replicaResult.respond(new ResponseListener(response));
             } catch (final Exception e) {
                 Releasables.closeWhileHandlingException(releasable); // release shard operation lock before responding to caller
@@ -1006,6 +1007,11 @@ public abstract class TransportReplicationAction<
             indexShard.updateLocalCheckpointForShard(allocationId, checkpoint);
         }
 
+        @Override
+        public void updateGlobalCheckpointForShard(final String allocationId, final long globalCheckpoint) {
+            indexShard.updateGlobalCheckpointForShard(allocationId, globalCheckpoint);
+        }
+
         @Override
         public long localCheckpoint() {
             return indexShard.getLocalCheckpoint();
@@ -1025,40 +1031,47 @@ public abstract class TransportReplicationAction<
 
     public static class ReplicaResponse extends ActionResponse implements ReplicationOperation.ReplicaResponse {
         private long localCheckpoint;
+        private long globalCheckpoint;
 
         ReplicaResponse() {
 
         }
 
-        public ReplicaResponse(long localCheckpoint) {
+        public ReplicaResponse(long localCheckpoint, long globalCheckpoint) {
             /*
-             * A replica should always know its own local checkpoint so this should always be a valid sequence number or the pre-6.0 local
+             * A replica should always know its own local checkpoints so this should always be a valid sequence number or the pre-6.0
              * checkpoint value when simulating responses to replication actions that pre-6.0 nodes are not aware of (e.g., the global
              * checkpoint background sync, and the primary/replica resync).
              */
             assert localCheckpoint != SequenceNumbers.UNASSIGNED_SEQ_NO;
             this.localCheckpoint = localCheckpoint;
+            this.globalCheckpoint = globalCheckpoint;
         }
 
         @Override
         public void readFrom(StreamInput in) throws IOException {
+            super.readFrom(in);
             if (in.getVersion().onOrAfter(Version.V_6_0_0_alpha1)) {
-                super.readFrom(in);
                 localCheckpoint = in.readZLong();
             } else {
                 // 5.x used to read empty responses, which don't really read anything off the stream, so just do nothing.
-                localCheckpoint = SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT;
+                localCheckpoint = SequenceNumbersService.PRE_60_NODE_CHECKPOINT;
+            }
+            if (in.getVersion().onOrAfter(Version.V_6_0_0_rc1)) {
+                globalCheckpoint = in.readZLong();
+            } else {
+                globalCheckpoint = SequenceNumbersService.PRE_60_NODE_CHECKPOINT;
             }
         }
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
             if (out.getVersion().onOrAfter(Version.V_6_0_0_alpha1)) {
-                super.writeTo(out);
                 out.writeZLong(localCheckpoint);
-            } else {
-                // we use to write empty responses
-                Empty.INSTANCE.writeTo(out);
+            }
+            if (out.getVersion().onOrAfter(Version.V_6_0_0_rc1)) {
+                out.writeZLong(globalCheckpoint);
             }
         }
 
@@ -1066,6 +1079,12 @@ public abstract class TransportReplicationAction<
         public long localCheckpoint() {
             return localCheckpoint;
         }
+
+        @Override
+        public long globalCheckpoint() {
+            return globalCheckpoint;
+        }
+
     }
 
     /**

+ 2 - 1
core/src/main/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncAction.java

@@ -89,7 +89,8 @@ public class GlobalCheckpointSyncAction extends TransportReplicationAction<
         if (node.getVersion().onOrAfter(Version.V_6_0_0_alpha1)) {
             super.sendReplicaRequest(replicaRequest, node, listener);
         } else {
-            listener.onResponse(new ReplicaResponse(SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT));
+            final long pre60NodeCheckpoint = SequenceNumbersService.PRE_60_NODE_CHECKPOINT;
+            listener.onResponse(new ReplicaResponse(pre60NodeCheckpoint, pre60NodeCheckpoint));
         }
     }
 

+ 243 - 112
core/src/main/java/org/elasticsearch/index/seqno/GlobalCheckpointTracker.java

@@ -19,6 +19,8 @@
 
 package org.elasticsearch.index.seqno;
 
+import com.carrotsearch.hppc.ObjectLongHashMap;
+import com.carrotsearch.hppc.ObjectLongMap;
 import org.elasticsearch.cluster.routing.AllocationId;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
@@ -36,8 +38,13 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.OptionalLong;
 import java.util.Set;
+import java.util.function.Function;
+import java.util.function.LongConsumer;
+import java.util.function.ToLongFunction;
 import java.util.stream.Collectors;
+import java.util.stream.LongStream;
 
 /**
  * This class is responsible of tracking the global checkpoint. The global checkpoint is the highest sequence number for which all lower (or
@@ -50,7 +57,10 @@ import java.util.stream.Collectors;
  */
 public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
 
-    private final String allocationId;
+    /**
+     * The allocation ID for the shard to which this tracker is a component of.
+     */
+    final String shardAllocationId;
 
     /**
      * The global checkpoint tracker can operate in two modes:
@@ -103,9 +113,9 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
     /**
      * Local checkpoint information for all shard copies that are tracked. Has an entry for all shard copies that are either initializing
      * and / or in-sync, possibly also containing information about unassigned in-sync shard copies. The information that is tracked for
-     * each shard copy is explained in the docs for the {@link LocalCheckpointState} class.
+     * each shard copy is explained in the docs for the {@link CheckpointState} class.
      */
-    final Map<String, LocalCheckpointState> localCheckpoints;
+    final Map<String, CheckpointState> checkpoints;
 
     /**
      * This set contains allocation IDs for which there is a thread actively waiting for the local checkpoint to advance to at least the
@@ -113,60 +123,67 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
      */
     final Set<String> pendingInSync;
 
-    /**
-     * The global checkpoint:
-     * - computed based on local checkpoints, if the tracker is in primary mode
-     * - received from the primary, if the tracker is in replica mode
-     */
-    volatile long globalCheckpoint;
-
     /**
      * Cached value for the last replication group that was computed
      */
     volatile ReplicationGroup replicationGroup;
 
-    public static class LocalCheckpointState implements Writeable {
+    public static class CheckpointState implements Writeable {
 
         /**
          * the last local checkpoint information that we have for this shard
          */
         long localCheckpoint;
+
+        /**
+         * the last global checkpoint information that we have for this shard. This information is computed for the primary if
+         * the tracker is in primary mode and received from the primary if in replica mode.
+         */
+        long globalCheckpoint;
         /**
          * whether this shard is treated as in-sync and thus contributes to the global checkpoint calculation
          */
         boolean inSync;
 
-        public LocalCheckpointState(long localCheckpoint, boolean inSync) {
+        public CheckpointState(long localCheckpoint, long globalCheckpoint, boolean inSync) {
             this.localCheckpoint = localCheckpoint;
+            this.globalCheckpoint = globalCheckpoint;
             this.inSync = inSync;
         }
 
-        public LocalCheckpointState(StreamInput in) throws IOException {
+        public CheckpointState(StreamInput in) throws IOException {
             this.localCheckpoint = in.readZLong();
+            this.globalCheckpoint = in.readZLong();
             this.inSync = in.readBoolean();
         }
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeZLong(localCheckpoint);
+            out.writeZLong(globalCheckpoint);
             out.writeBoolean(inSync);
         }
 
         /**
          * Returns a full copy of this object
          */
-        public LocalCheckpointState copy() {
-            return new LocalCheckpointState(localCheckpoint, inSync);
+        public CheckpointState copy() {
+            return new CheckpointState(localCheckpoint, globalCheckpoint, inSync);
         }
 
         public long getLocalCheckpoint() {
             return localCheckpoint;
         }
 
+        public long getGlobalCheckpoint() {
+            return globalCheckpoint;
+        }
+
         @Override
         public String toString() {
             return "LocalCheckpointState{" +
                 "localCheckpoint=" + localCheckpoint +
+                ", globalCheckpoint=" + globalCheckpoint +
                 ", inSync=" + inSync +
                 '}';
         }
@@ -176,40 +193,71 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
 
-            LocalCheckpointState that = (LocalCheckpointState) o;
+            CheckpointState that = (CheckpointState) o;
 
             if (localCheckpoint != that.localCheckpoint) return false;
+            if (globalCheckpoint != that.globalCheckpoint) return false;
             return inSync == that.inSync;
         }
 
         @Override
         public int hashCode() {
-            int result = (int) (localCheckpoint ^ (localCheckpoint >>> 32));
-            result = 31 * result + (inSync ? 1 : 0);
+            int result = Long.hashCode(localCheckpoint);
+            result = 31 * result + Long.hashCode(globalCheckpoint);
+            result = 31 * result + Boolean.hashCode(inSync);
             return result;
         }
     }
 
+    synchronized ObjectLongMap<String> getGlobalCheckpoints() {
+        assert primaryMode;
+        assert handoffInProgress == false;
+        final ObjectLongMap<String> globalCheckpoints = new ObjectLongHashMap<>(checkpoints.size());
+        for (final Map.Entry<String, CheckpointState> cps : checkpoints.entrySet()) {
+            globalCheckpoints.put(cps.getKey(), cps.getValue().globalCheckpoint);
+        }
+        return globalCheckpoints;
+    }
+
     /**
      * Class invariant that should hold before and after every invocation of public methods on this class. As Java lacks implication
      * as a logical operator, many of the invariants are written under the form (!A || B), they should be read as (A implies B) however.
      */
     private boolean invariant() {
+        assert checkpoints.get(shardAllocationId) != null :
+            "checkpoints map should always have an entry for the current shard";
+
         // local checkpoints only set during primary mode
-        assert primaryMode || localCheckpoints.values().stream()
+        assert primaryMode || checkpoints.values().stream()
             .allMatch(lcps -> lcps.localCheckpoint == SequenceNumbers.UNASSIGNED_SEQ_NO ||
-                lcps.localCheckpoint == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT);
+                lcps.localCheckpoint == SequenceNumbersService.PRE_60_NODE_CHECKPOINT);
+
+        // global checkpoints for other shards only set during primary mode
+        assert primaryMode
+                || checkpoints
+                .entrySet()
+                .stream()
+                .filter(e -> e.getKey().equals(shardAllocationId) == false)
+                .map(Map.Entry::getValue)
+                .allMatch(cps ->
+                        (cps.globalCheckpoint == SequenceNumbers.UNASSIGNED_SEQ_NO
+                                || cps.globalCheckpoint == SequenceNumbersService.PRE_60_NODE_CHECKPOINT));
 
         // relocation handoff can only occur in primary mode
         assert !handoffInProgress || primaryMode;
 
-        // there is at least one in-sync shard copy when the global checkpoint tracker operates in primary mode (i.e. the shard itself)
-        assert !primaryMode || localCheckpoints.values().stream().anyMatch(lcps -> lcps.inSync);
+        // the current shard is marked as in-sync when the global checkpoint tracker operates in primary mode
+        assert !primaryMode || checkpoints.get(shardAllocationId).inSync;
 
         // the routing table and replication group is set when the global checkpoint tracker operates in primary mode
         assert !primaryMode || (routingTable != null && replicationGroup != null) :
             "primary mode but routing table is " + routingTable + " and replication group is " + replicationGroup;
 
+        // when in primary mode, the current allocation ID is the allocation ID of the primary or the relocation allocation ID
+        assert !primaryMode
+                || (routingTable.primaryShard().allocationId().getId().equals(shardAllocationId)
+                || routingTable.primaryShard().allocationId().getRelocationId().equals(shardAllocationId));
+
         // during relocation handoff there are no entries blocking global checkpoint advancement
         assert !handoffInProgress || pendingInSync.isEmpty() :
             "entries blocking global checkpoint advancement during relocation handoff: " + pendingInSync;
@@ -218,9 +266,24 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         assert pendingInSync.isEmpty() || (primaryMode && !handoffInProgress);
 
         // the computed global checkpoint is always up-to-date
-        assert !primaryMode || globalCheckpoint == computeGlobalCheckpoint(pendingInSync, localCheckpoints.values(), globalCheckpoint) :
-            "global checkpoint is not up-to-date, expected: " +
-                computeGlobalCheckpoint(pendingInSync, localCheckpoints.values(), globalCheckpoint) + " but was: " + globalCheckpoint;
+        assert !primaryMode
+                || getGlobalCheckpoint() == computeGlobalCheckpoint(pendingInSync, checkpoints.values(), getGlobalCheckpoint())
+                : "global checkpoint is not up-to-date, expected: " +
+                computeGlobalCheckpoint(pendingInSync, checkpoints.values(), getGlobalCheckpoint()) + " but was: " + getGlobalCheckpoint();
+
+        // when in primary mode, the global checkpoint is at most the minimum local checkpoint on all in-sync shard copies
+        assert !primaryMode
+                || getGlobalCheckpoint() <= inSyncCheckpointStates(checkpoints, CheckpointState::getLocalCheckpoint, LongStream::min)
+                : "global checkpoint [" + getGlobalCheckpoint() + "] "
+                + "for primary mode allocation ID [" + shardAllocationId + "] "
+                + "more than in-sync local checkpoints [" + checkpoints + "]";
+
+        // when in primary mode, the local knowledge of the global checkpoints on shard copies is bounded by the global checkpoint
+        assert !primaryMode
+                || getGlobalCheckpoint() >= inSyncCheckpointStates(checkpoints, CheckpointState::getGlobalCheckpoint, LongStream::max)
+                : "global checkpoint [" + getGlobalCheckpoint() + "] "
+                + "for primary mode allocation ID [" + shardAllocationId + "] "
+                + "less than in-sync global checkpoints [" + checkpoints + "]";
 
         // we have a routing table iff we have a replication group
         assert (routingTable == null) == (replicationGroup == null) :
@@ -230,10 +293,10 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
             "cached replication group out of sync: expected: " + calculateReplicationGroup() + " but was: " + replicationGroup;
 
         // all assigned shards from the routing table are tracked
-        assert routingTable == null || localCheckpoints.keySet().containsAll(routingTable.getAllAllocationIds()) :
-            "local checkpoints " + localCheckpoints + " not in-sync with routing table " + routingTable;
+        assert routingTable == null || checkpoints.keySet().containsAll(routingTable.getAllAllocationIds()) :
+            "local checkpoints " + checkpoints + " not in-sync with routing table " + routingTable;
 
-        for (Map.Entry<String, LocalCheckpointState> entry : localCheckpoints.entrySet()) {
+        for (Map.Entry<String, CheckpointState> entry : checkpoints.entrySet()) {
             // blocking global checkpoint advancement only happens for shards that are not in-sync
             assert !pendingInSync.contains(entry.getKey()) || !entry.getValue().inSync :
                 "shard copy " + entry.getKey() + " blocks global checkpoint advancement but is in-sync";
@@ -242,6 +305,21 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         return true;
     }
 
+    private static long inSyncCheckpointStates(
+            final Map<String, CheckpointState> checkpoints,
+            ToLongFunction<CheckpointState> function,
+            Function<LongStream, OptionalLong> reducer) {
+        final OptionalLong value =
+                reducer.apply(
+                        checkpoints
+                                .values()
+                                .stream()
+                                .filter(cps -> cps.inSync)
+                                .mapToLong(function)
+                                .filter(v -> v != SequenceNumbers.UNASSIGNED_SEQ_NO));
+        return value.isPresent() ? value.getAsLong() : SequenceNumbers.UNASSIGNED_SEQ_NO;
+    }
+
     /**
      * Initialize the global checkpoint service. The specified global checkpoint should be set to the last known global checkpoint, or
      * {@link SequenceNumbers#UNASSIGNED_SEQ_NO}.
@@ -258,12 +336,12 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
             final long globalCheckpoint) {
         super(shardId, indexSettings);
         assert globalCheckpoint >= SequenceNumbers.UNASSIGNED_SEQ_NO : "illegal initial global checkpoint: " + globalCheckpoint;
-        this.allocationId = allocationId;
+        this.shardAllocationId = allocationId;
         this.primaryMode = false;
         this.handoffInProgress = false;
         this.appliedClusterStateVersion = -1L;
-        this.globalCheckpoint = globalCheckpoint;
-        this.localCheckpoints = new HashMap<>(1 + indexSettings.getNumberOfReplicas());
+        this.checkpoints = new HashMap<>(1 + indexSettings.getNumberOfReplicas());
+        checkpoints.put(allocationId, new CheckpointState(SequenceNumbers.UNASSIGNED_SEQ_NO, globalCheckpoint, false));
         this.pendingInSync = new HashSet<>();
         this.routingTable = null;
         this.replicationGroup = null;
@@ -282,7 +360,7 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
 
     private ReplicationGroup calculateReplicationGroup() {
         return new ReplicationGroup(routingTable,
-            localCheckpoints.entrySet().stream().filter(e -> e.getValue().inSync).map(Map.Entry::getKey).collect(Collectors.toSet()));
+            checkpoints.entrySet().stream().filter(e -> e.getValue().inSync).map(Map.Entry::getKey).collect(Collectors.toSet()));
     }
 
     /**
@@ -290,8 +368,10 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
      *
      * @return the global checkpoint
      */
-    public long getGlobalCheckpoint() {
-        return globalCheckpoint;
+    public synchronized long getGlobalCheckpoint() {
+        final CheckpointState cps = checkpoints.get(shardAllocationId);
+        assert cps != null;
+        return cps.globalCheckpoint;
     }
 
     /**
@@ -306,27 +386,58 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         /*
          * The global checkpoint here is a local knowledge which is updated under the mandate of the primary. It can happen that the primary
          * information is lagging compared to a replica (e.g., if a replica is promoted to primary but has stale info relative to other
-         * replica shards). In these cases, the local knowledge of the global checkpoint could be higher than sync from the lagging primary.
+         * replica shards). In these cases, the local knowledge of the global checkpoint could be higher than the sync from the lagging
+         * primary.
          */
-        if (this.globalCheckpoint <= globalCheckpoint) {
-            logger.trace("updating global checkpoint from [{}] to [{}] due to [{}]", this.globalCheckpoint, globalCheckpoint, reason);
-            this.globalCheckpoint = globalCheckpoint;
-        }
+        updateGlobalCheckpoint(
+                shardAllocationId,
+                globalCheckpoint,
+                current -> logger.trace("updating global checkpoint from [{}] to [{}] due to [{}]", current, globalCheckpoint, reason));
+        assert invariant();
+    }
+
+    /**
+     * Update the local knowledge of the global checkpoint for the specified allocation ID.
+     *
+     * @param allocationId     the allocation ID to update the global checkpoint for
+     * @param globalCheckpoint the global checkpoint
+     */
+    public synchronized void updateGlobalCheckpointForShard(final String allocationId, final long globalCheckpoint) {
+        assert primaryMode;
+        assert handoffInProgress == false;
+        assert invariant();
+        updateGlobalCheckpoint(
+                allocationId,
+                globalCheckpoint,
+                current -> logger.trace(
+                        "updating local knowledge for [{}] on the primary of the global checkpoint from [{}] to [{}]",
+                        allocationId,
+                        current,
+                        globalCheckpoint));
         assert invariant();
     }
 
+    private void updateGlobalCheckpoint(final String allocationId, final long globalCheckpoint, LongConsumer ifUpdated) {
+        final CheckpointState cps = checkpoints.get(allocationId);
+        assert !this.shardAllocationId.equals(allocationId) || cps != null;
+        if (cps != null && globalCheckpoint > cps.globalCheckpoint) {
+            ifUpdated.accept(cps.globalCheckpoint);
+            cps.globalCheckpoint = globalCheckpoint;
+        }
+    }
+
     /**
      * Initializes the global checkpoint tracker in primary mode (see {@link #primaryMode}. Called on primary activation or promotion.
      */
     public synchronized void activatePrimaryMode(final long localCheckpoint) {
         assert invariant();
         assert primaryMode == false;
-        assert localCheckpoints.get(allocationId) != null && localCheckpoints.get(allocationId).inSync &&
-            localCheckpoints.get(allocationId).localCheckpoint == SequenceNumbers.UNASSIGNED_SEQ_NO :
-            "expected " + allocationId + " to have initialized entry in " + localCheckpoints + " when activating primary";
+        assert checkpoints.get(shardAllocationId) != null && checkpoints.get(shardAllocationId).inSync &&
+            checkpoints.get(shardAllocationId).localCheckpoint == SequenceNumbers.UNASSIGNED_SEQ_NO :
+            "expected " + shardAllocationId + " to have initialized entry in " + checkpoints + " when activating primary";
         assert localCheckpoint >= SequenceNumbers.NO_OPS_PERFORMED;
         primaryMode = true;
-        updateLocalCheckpoint(allocationId, localCheckpoints.get(allocationId), localCheckpoint);
+        updateLocalCheckpoint(shardAllocationId, checkpoints.get(shardAllocationId), localCheckpoint);
         updateGlobalCheckpointOnPrimary();
         assert invariant();
     }
@@ -345,37 +456,47 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         if (applyingClusterStateVersion > appliedClusterStateVersion) {
             // check that the master does not fabricate new in-sync entries out of thin air once we are in primary mode
             assert !primaryMode || inSyncAllocationIds.stream().allMatch(
-                inSyncId -> localCheckpoints.containsKey(inSyncId) && localCheckpoints.get(inSyncId).inSync) :
+                inSyncId -> checkpoints.containsKey(inSyncId) && checkpoints.get(inSyncId).inSync) :
                 "update from master in primary mode contains in-sync ids " + inSyncAllocationIds +
-                    " that have no matching entries in " + localCheckpoints;
+                    " that have no matching entries in " + checkpoints;
             // remove entries which don't exist on master
             Set<String> initializingAllocationIds = routingTable.getAllInitializingShards().stream()
                 .map(ShardRouting::allocationId).map(AllocationId::getId).collect(Collectors.toSet());
-            boolean removedEntries = localCheckpoints.keySet().removeIf(
+            boolean removedEntries = checkpoints.keySet().removeIf(
                 aid -> !inSyncAllocationIds.contains(aid) && !initializingAllocationIds.contains(aid));
 
             if (primaryMode) {
                 // add new initializingIds that are missing locally. These are fresh shard copies - and not in-sync
                 for (String initializingId : initializingAllocationIds) {
-                    if (localCheckpoints.containsKey(initializingId) == false) {
+                    if (checkpoints.containsKey(initializingId) == false) {
                         final boolean inSync = inSyncAllocationIds.contains(initializingId);
                         assert inSync == false : "update from master in primary mode has " + initializingId +
                             " as in-sync but it does not exist locally";
                         final long localCheckpoint = pre60AllocationIds.contains(initializingId) ?
-                            SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT : SequenceNumbers.UNASSIGNED_SEQ_NO;
-                        localCheckpoints.put(initializingId, new LocalCheckpointState(localCheckpoint, inSync));
+                            SequenceNumbersService.PRE_60_NODE_CHECKPOINT : SequenceNumbers.UNASSIGNED_SEQ_NO;
+                        final long globalCheckpoint = localCheckpoint;
+                        checkpoints.put(initializingId, new CheckpointState(localCheckpoint, globalCheckpoint, inSync));
                     }
                 }
             } else {
                 for (String initializingId : initializingAllocationIds) {
-                    final long localCheckpoint = pre60AllocationIds.contains(initializingId) ?
-                        SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT : SequenceNumbers.UNASSIGNED_SEQ_NO;
-                    localCheckpoints.put(initializingId, new LocalCheckpointState(localCheckpoint, false));
+                    if (shardAllocationId.equals(initializingId) == false) {
+                        final long localCheckpoint = pre60AllocationIds.contains(initializingId) ?
+                            SequenceNumbersService.PRE_60_NODE_CHECKPOINT : SequenceNumbers.UNASSIGNED_SEQ_NO;
+                        final long globalCheckpoint = localCheckpoint;
+                        checkpoints.put(initializingId, new CheckpointState(localCheckpoint, globalCheckpoint, false));
+                    }
                 }
                 for (String inSyncId : inSyncAllocationIds) {
-                    final long localCheckpoint = pre60AllocationIds.contains(inSyncId) ?
-                        SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT : SequenceNumbers.UNASSIGNED_SEQ_NO;
-                    localCheckpoints.put(inSyncId, new LocalCheckpointState(localCheckpoint, true));
+                    if (shardAllocationId.equals(inSyncId)) {
+                        // current shard is initially marked as not in-sync because we don't know better at that point
+                        checkpoints.get(shardAllocationId).inSync = true;
+                    } else {
+                        final long localCheckpoint = pre60AllocationIds.contains(inSyncId) ?
+                            SequenceNumbersService.PRE_60_NODE_CHECKPOINT : SequenceNumbers.UNASSIGNED_SEQ_NO;
+                        final long globalCheckpoint = localCheckpoint;
+                        checkpoints.put(inSyncId, new CheckpointState(localCheckpoint, globalCheckpoint, true));
+                    }
                 }
             }
             appliedClusterStateVersion = applyingClusterStateVersion;
@@ -397,8 +518,8 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
     public synchronized void initiateTracking(final String allocationId) {
         assert invariant();
         assert primaryMode;
-        LocalCheckpointState lcps = localCheckpoints.get(allocationId);
-        if (lcps == null) {
+        CheckpointState cps = checkpoints.get(allocationId);
+        if (cps == null) {
             // can happen if replica was removed from cluster but recovery process is unaware of it yet
             throw new IllegalStateException("no local checkpoint tracking information available");
         }
@@ -416,21 +537,21 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         assert invariant();
         assert primaryMode;
         assert handoffInProgress == false;
-        LocalCheckpointState lcps = localCheckpoints.get(allocationId);
-        if (lcps == null) {
+        CheckpointState cps = checkpoints.get(allocationId);
+        if (cps == null) {
             // can happen if replica was removed from cluster but recovery process is unaware of it yet
             throw new IllegalStateException("no local checkpoint tracking information available for " + allocationId);
         }
         assert localCheckpoint >= SequenceNumbers.NO_OPS_PERFORMED :
             "expected known local checkpoint for " + allocationId + " but was " + localCheckpoint;
         assert pendingInSync.contains(allocationId) == false : "shard copy " + allocationId + " is already marked as pending in-sync";
-        updateLocalCheckpoint(allocationId, lcps, localCheckpoint);
+        updateLocalCheckpoint(allocationId, cps, localCheckpoint);
         // if it was already in-sync (because of a previously failed recovery attempt), global checkpoint must have been
         // stuck from advancing
-        assert !lcps.inSync || (lcps.localCheckpoint >= globalCheckpoint) :
-            "shard copy " + allocationId + " that's already in-sync should have a local checkpoint " + lcps.localCheckpoint +
-                " that's above the global checkpoint " + globalCheckpoint;
-        if (lcps.localCheckpoint < globalCheckpoint) {
+        assert !cps.inSync || (cps.localCheckpoint >= getGlobalCheckpoint()) :
+            "shard copy " + allocationId + " that's already in-sync should have a local checkpoint " + cps.localCheckpoint +
+                " that's above the global checkpoint " + getGlobalCheckpoint();
+        if (cps.localCheckpoint < getGlobalCheckpoint()) {
             pendingInSync.add(allocationId);
             try {
                 while (true) {
@@ -444,7 +565,7 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
                 pendingInSync.remove(allocationId);
             }
         } else {
-            lcps.inSync = true;
+            cps.inSync = true;
             replicationGroup = calculateReplicationGroup();
             logger.trace("marked [{}] as in-sync", allocationId);
             updateGlobalCheckpointOnPrimary();
@@ -453,21 +574,21 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         assert invariant();
     }
 
-    private boolean updateLocalCheckpoint(String allocationId, LocalCheckpointState lcps, long localCheckpoint) {
-        // a local checkpoint of PRE_60_NODE_LOCAL_CHECKPOINT cannot be overridden
-        assert lcps.localCheckpoint != SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT ||
-            localCheckpoint == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT :
+    private boolean updateLocalCheckpoint(String allocationId, CheckpointState cps, long localCheckpoint) {
+        // a local checkpoint of PRE_60_NODE_CHECKPOINT cannot be overridden
+        assert cps.localCheckpoint != SequenceNumbersService.PRE_60_NODE_CHECKPOINT ||
+            localCheckpoint == SequenceNumbersService.PRE_60_NODE_CHECKPOINT :
             "pre-6.0 shard copy " + allocationId + " unexpected to send valid local checkpoint " + localCheckpoint;
         // a local checkpoint for a shard copy should be a valid sequence number or the pre-6.0 sequence number indicator
         assert localCheckpoint != SequenceNumbers.UNASSIGNED_SEQ_NO :
                 "invalid local checkpoint for shard copy [" + allocationId + "]";
-        if (localCheckpoint > lcps.localCheckpoint) {
-            logger.trace("updated local checkpoint of [{}] from [{}] to [{}]", allocationId, lcps.localCheckpoint, localCheckpoint);
-            lcps.localCheckpoint = localCheckpoint;
+        if (localCheckpoint > cps.localCheckpoint) {
+            logger.trace("updated local checkpoint of [{}] from [{}] to [{}]", allocationId, cps.localCheckpoint, localCheckpoint);
+            cps.localCheckpoint = localCheckpoint;
             return true;
         } else {
             logger.trace("skipped updating local checkpoint of [{}] from [{}] to [{}], current checkpoint is higher", allocationId,
-                lcps.localCheckpoint, localCheckpoint);
+                cps.localCheckpoint, localCheckpoint);
             return false;
         }
     }
@@ -483,17 +604,17 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         assert invariant();
         assert primaryMode;
         assert handoffInProgress == false;
-        LocalCheckpointState lcps = localCheckpoints.get(allocationId);
-        if (lcps == null) {
+        CheckpointState cps = checkpoints.get(allocationId);
+        if (cps == null) {
             // can happen if replica was removed from cluster but replication process is unaware of it yet
             return;
         }
-        boolean increasedLocalCheckpoint = updateLocalCheckpoint(allocationId, lcps, localCheckpoint);
+        boolean increasedLocalCheckpoint = updateLocalCheckpoint(allocationId, cps, localCheckpoint);
         boolean pending = pendingInSync.contains(allocationId);
-        if (pending && lcps.localCheckpoint >= globalCheckpoint) {
+        if (pending && cps.localCheckpoint >= getGlobalCheckpoint()) {
             pendingInSync.remove(allocationId);
             pending = false;
-            lcps.inSync = true;
+            cps.inSync = true;
             replicationGroup = calculateReplicationGroup();
             logger.trace("marked [{}] as in-sync", allocationId);
             notifyAllWaiters();
@@ -508,21 +629,21 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
      * Computes the global checkpoint based on the given local checkpoints. In case where there are entries preventing the
      * computation to happen (for example due to blocking), it returns the fallback value.
      */
-    private static long computeGlobalCheckpoint(final Set<String> pendingInSync, final Collection<LocalCheckpointState> localCheckpoints,
+    private static long computeGlobalCheckpoint(final Set<String> pendingInSync, final Collection<CheckpointState> localCheckpoints,
                                                 final long fallback) {
         long minLocalCheckpoint = Long.MAX_VALUE;
         if (pendingInSync.isEmpty() == false) {
             return fallback;
         }
-        for (final LocalCheckpointState lcps : localCheckpoints) {
-            if (lcps.inSync) {
-                if (lcps.localCheckpoint == SequenceNumbers.UNASSIGNED_SEQ_NO) {
+        for (final CheckpointState cps : localCheckpoints) {
+            if (cps.inSync) {
+                if (cps.localCheckpoint == SequenceNumbers.UNASSIGNED_SEQ_NO) {
                     // unassigned in-sync replica
                     return fallback;
-                } else if (lcps.localCheckpoint == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT) {
+                } else if (cps.localCheckpoint == SequenceNumbersService.PRE_60_NODE_CHECKPOINT) {
                     // 5.x replica, ignore for global checkpoint calculation
                 } else {
-                    minLocalCheckpoint = Math.min(lcps.localCheckpoint, minLocalCheckpoint);
+                    minLocalCheckpoint = Math.min(cps.localCheckpoint, minLocalCheckpoint);
                 }
             }
         }
@@ -535,12 +656,14 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
      */
     private synchronized void updateGlobalCheckpointOnPrimary() {
         assert primaryMode;
-        final long computedGlobalCheckpoint = computeGlobalCheckpoint(pendingInSync, localCheckpoints.values(), globalCheckpoint);
+        final CheckpointState cps = checkpoints.get(shardAllocationId);
+        final long globalCheckpoint = cps.globalCheckpoint;
+        final long computedGlobalCheckpoint = computeGlobalCheckpoint(pendingInSync, checkpoints.values(), getGlobalCheckpoint());
         assert computedGlobalCheckpoint >= globalCheckpoint : "new global checkpoint [" + computedGlobalCheckpoint +
             "] is lower than previous one [" + globalCheckpoint + "]";
         if (globalCheckpoint != computedGlobalCheckpoint) {
             logger.trace("global checkpoint updated to [{}]", computedGlobalCheckpoint);
-            globalCheckpoint = computedGlobalCheckpoint;
+            cps.globalCheckpoint = computedGlobalCheckpoint;
         }
     }
 
@@ -553,13 +676,13 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         assert handoffInProgress == false;
         assert pendingInSync.isEmpty() : "relocation handoff started while there are still shard copies pending in-sync: " + pendingInSync;
         handoffInProgress = true;
-        // copy clusterStateVersion and localCheckpoints and return
-        // all the entries from localCheckpoints that are inSync: the reason we don't need to care about initializing non-insync entries
+        // copy clusterStateVersion and checkpoints and return
+        // all the entries from checkpoints that are inSync: the reason we don't need to care about initializing non-insync entries
         // is that they will have to undergo a recovery attempt on the relocation target, and will hence be supplied by the cluster state
         // update on the relocation target once relocation completes). We could alternatively also copy the map as-is (it’s safe), and it
         // would be cleaned up on the target by cluster state updates.
-        Map<String, LocalCheckpointState> localCheckpointsCopy = new HashMap<>();
-        for (Map.Entry<String, LocalCheckpointState> entry : localCheckpoints.entrySet()) {
+        Map<String, CheckpointState> localCheckpointsCopy = new HashMap<>();
+        for (Map.Entry<String, CheckpointState> entry : checkpoints.entrySet()) {
             localCheckpointsCopy.put(entry.getKey(), entry.getValue().copy());
         }
         assert invariant();
@@ -586,11 +709,19 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         assert handoffInProgress;
         primaryMode = false;
         handoffInProgress = false;
-        // forget all checkpoint information
-        localCheckpoints.values().stream().forEach(lcps -> {
-            if (lcps.localCheckpoint != SequenceNumbers.UNASSIGNED_SEQ_NO &&
-                lcps.localCheckpoint != SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT) {
-                lcps.localCheckpoint = SequenceNumbers.UNASSIGNED_SEQ_NO;
+        // forget all checkpoint information except for global checkpoint of current shard
+        checkpoints.entrySet().stream().forEach(e -> {
+            final CheckpointState cps = e.getValue();
+            if (cps.localCheckpoint != SequenceNumbers.UNASSIGNED_SEQ_NO &&
+                cps.localCheckpoint != SequenceNumbersService.PRE_60_NODE_CHECKPOINT) {
+                cps.localCheckpoint = SequenceNumbers.UNASSIGNED_SEQ_NO;
+            }
+            if (e.getKey().equals(shardAllocationId) == false) {
+                // don't throw global checkpoint information of current shard away
+                if (cps.globalCheckpoint != SequenceNumbers.UNASSIGNED_SEQ_NO &&
+                    cps.globalCheckpoint != SequenceNumbersService.PRE_60_NODE_CHECKPOINT) {
+                    cps.globalCheckpoint = SequenceNumbers.UNASSIGNED_SEQ_NO;
+                }
             }
         });
         assert invariant();
@@ -609,9 +740,9 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         primaryMode = true;
         // capture current state to possibly replay missed cluster state update
         appliedClusterStateVersion = primaryContext.clusterStateVersion();
-        localCheckpoints.clear();
-        for (Map.Entry<String, LocalCheckpointState> entry : primaryContext.localCheckpoints.entrySet()) {
-            localCheckpoints.put(entry.getKey(), entry.getValue().copy());
+        checkpoints.clear();
+        for (Map.Entry<String, CheckpointState> entry : primaryContext.checkpoints.entrySet()) {
+            checkpoints.put(entry.getKey(), entry.getValue().copy());
         }
         routingTable = primaryContext.getRoutingTable();
         replicationGroup = calculateReplicationGroup();
@@ -628,11 +759,11 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         final long lastAppliedClusterStateVersion = appliedClusterStateVersion;
         final Set<String> inSyncAllocationIds = new HashSet<>();
         final Set<String> pre60AllocationIds = new HashSet<>();
-        localCheckpoints.entrySet().forEach(entry -> {
+        checkpoints.entrySet().forEach(entry -> {
             if (entry.getValue().inSync) {
                 inSyncAllocationIds.add(entry.getKey());
             }
-            if (entry.getValue().getLocalCheckpoint() == SequenceNumbersService.PRE_60_NODE_LOCAL_CHECKPOINT) {
+            if (entry.getValue().getLocalCheckpoint() == SequenceNumbersService.PRE_60_NODE_CHECKPOINT) {
                 pre60AllocationIds.add(entry.getKey());
             }
         });
@@ -651,9 +782,9 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
     /**
      * Returns the local checkpoint information tracked for a specific shard. Used by tests.
      */
-    public synchronized LocalCheckpointState getTrackedLocalCheckpointForShard(String allocationId) {
+    public synchronized CheckpointState getTrackedLocalCheckpointForShard(String allocationId) {
         assert primaryMode;
-        return localCheckpoints.get(allocationId);
+        return checkpoints.get(allocationId);
     }
 
     /**
@@ -682,19 +813,19 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
     public static class PrimaryContext implements Writeable {
 
         private final long clusterStateVersion;
-        private final Map<String, LocalCheckpointState> localCheckpoints;
+        private final Map<String, CheckpointState> checkpoints;
         private final IndexShardRoutingTable routingTable;
 
-        public PrimaryContext(long clusterStateVersion, Map<String, LocalCheckpointState> localCheckpoints,
+        public PrimaryContext(long clusterStateVersion, Map<String, CheckpointState> checkpoints,
                               IndexShardRoutingTable routingTable) {
             this.clusterStateVersion = clusterStateVersion;
-            this.localCheckpoints = localCheckpoints;
+            this.checkpoints = checkpoints;
             this.routingTable = routingTable;
         }
 
         public PrimaryContext(StreamInput in) throws IOException {
             clusterStateVersion = in.readVLong();
-            localCheckpoints = in.readMap(StreamInput::readString, LocalCheckpointState::new);
+            checkpoints = in.readMap(StreamInput::readString, CheckpointState::new);
             routingTable = IndexShardRoutingTable.Builder.readFrom(in);
         }
 
@@ -702,8 +833,8 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
             return clusterStateVersion;
         }
 
-        public Map<String, LocalCheckpointState> getLocalCheckpoints() {
-            return localCheckpoints;
+        public Map<String, CheckpointState> getCheckpointStates() {
+            return checkpoints;
         }
 
         public IndexShardRoutingTable getRoutingTable() {
@@ -713,7 +844,7 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeVLong(clusterStateVersion);
-            out.writeMap(localCheckpoints, (streamOutput, s) -> out.writeString(s), (streamOutput, lcps) -> lcps.writeTo(out));
+            out.writeMap(checkpoints, (streamOutput, s) -> out.writeString(s), (streamOutput, cps) -> cps.writeTo(out));
             IndexShardRoutingTable.Builder.writeTo(routingTable, out);
         }
 
@@ -721,7 +852,7 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
         public String toString() {
             return "PrimaryContext{" +
                     "clusterStateVersion=" + clusterStateVersion +
-                    ", localCheckpoints=" + localCheckpoints +
+                    ", checkpoints=" + checkpoints +
                     ", routingTable=" + routingTable +
                     '}';
         }
@@ -740,8 +871,8 @@ public class GlobalCheckpointTracker extends AbstractIndexShardComponent {
 
         @Override
         public int hashCode() {
-            int result = (int) (clusterStateVersion ^ (clusterStateVersion >>> 32));
-            result = 31 * result + localCheckpoints.hashCode();
+            int result = Long.hashCode(clusterStateVersion);
+            result = 31 * result + checkpoints.hashCode();
             result = 31 * result + routingTable.hashCode();
             return result;
         }

+ 17 - 2
core/src/main/java/org/elasticsearch/index/seqno/SequenceNumbersService.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.index.seqno;
 
+import com.carrotsearch.hppc.ObjectLongMap;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.shard.AbstractIndexShardComponent;
@@ -35,7 +36,7 @@ public class SequenceNumbersService extends AbstractIndexShardComponent {
     /**
      * Represents a local checkpoint coming from a pre-6.0 node
      */
-    public static final long PRE_60_NODE_LOCAL_CHECKPOINT = -3L;
+    public static final long PRE_60_NODE_CHECKPOINT = -3L;
 
     private final LocalCheckpointTracker localCheckpointTracker;
     private final GlobalCheckpointTracker globalCheckpointTracker;
@@ -132,6 +133,20 @@ public class SequenceNumbersService extends AbstractIndexShardComponent {
         globalCheckpointTracker.updateLocalCheckpoint(allocationId, checkpoint);
     }
 
+    /**
+     * Update the local knowledge of the global checkpoint for the specified allocation ID.
+     *
+     * @param allocationId     the allocation ID to update the global checkpoint for
+     * @param globalCheckpoint the global checkpoint
+     */
+    public void updateGlobalCheckpointForShard(final String allocationId, final long globalCheckpoint) {
+        globalCheckpointTracker.updateGlobalCheckpointForShard(allocationId, globalCheckpoint);
+    }
+
+    public ObjectLongMap<String> getGlobalCheckpoints() {
+        return globalCheckpointTracker.getGlobalCheckpoints();
+    }
+
     /**
      * Called when the recovery process for a shard is ready to open the engine on the target shard.
      * See {@link GlobalCheckpointTracker#initiateTracking(String)} for details.
@@ -201,7 +216,7 @@ public class SequenceNumbersService extends AbstractIndexShardComponent {
      * Activates the global checkpoint tracker in primary mode (see {@link GlobalCheckpointTracker#primaryMode}.
      * Called on primary activation or promotion.
      */
-    public void activatePrimaryMode(final String allocationId, final long localCheckpoint) {
+    public void activatePrimaryMode(final long localCheckpoint) {
         globalCheckpointTracker.activatePrimaryMode(localCheckpoint);
     }
 

+ 23 - 4
core/src/main/java/org/elasticsearch/index/shard/IndexShard.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.index.shard;
 
+import com.carrotsearch.hppc.ObjectLongMap;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.index.CheckIndex;
 import org.apache.lucene.index.IndexCommit;
@@ -401,7 +402,7 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
                     final DiscoveryNode recoverySourceNode = recoveryState.getSourceNode();
                     if (currentRouting.isRelocationTarget() == false || recoverySourceNode.getVersion().before(Version.V_6_0_0_alpha1)) {
                         // there was no primary context hand-off in < 6.0.0, need to manually activate the shard
-                        getEngine().seqNoService().activatePrimaryMode(currentRouting.allocationId().getId(), getEngine().seqNoService().getLocalCheckpoint());
+                        getEngine().seqNoService().activatePrimaryMode(getEngine().seqNoService().getLocalCheckpoint());
                     }
                 }
 
@@ -498,7 +499,7 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
                             }
                         },
                         e -> failShard("exception during primary term transition", e));
-                    getEngine().seqNoService().activatePrimaryMode(currentRouting.allocationId().getId(), getEngine().seqNoService().getLocalCheckpoint());
+                    getEngine().seqNoService().activatePrimaryMode(getEngine().seqNoService().getLocalCheckpoint());
                     primaryTerm = newPrimaryTerm;
                     latch.countDown();
                 }
@@ -1673,6 +1674,18 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
         getEngine().seqNoService().updateLocalCheckpointForShard(allocationId, checkpoint);
     }
 
+    /**
+     * Update the local knowledge of the global checkpoint for the specified allocation ID.
+     *
+     * @param allocationId     the allocation ID to update the global checkpoint for
+     * @param globalCheckpoint the global checkpoint
+     */
+    public void updateGlobalCheckpointForShard(final String allocationId, final long globalCheckpoint) {
+        verifyPrimary();
+        verifyNotClosed();
+        getEngine().seqNoService().updateGlobalCheckpointForShard(allocationId, globalCheckpoint);
+    }
+
     /**
      * Waits for all operations up to the provided sequence number to complete.
      *
@@ -1735,6 +1748,12 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
         return getEngine().seqNoService().getGlobalCheckpoint();
     }
 
+    public ObjectLongMap<String> getGlobalCheckpoints() {
+        verifyPrimary();
+        verifyNotClosed();
+        return getEngine().seqNoService().getGlobalCheckpoints();
+    }
+
     /**
      * Returns the current replication group for the shard.
      *
@@ -1783,9 +1802,9 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
     public void activateWithPrimaryContext(final GlobalCheckpointTracker.PrimaryContext primaryContext) {
         verifyPrimary();
         assert shardRouting.isRelocationTarget() : "only relocation target can update allocation IDs from primary context: " + shardRouting;
-        assert primaryContext.getLocalCheckpoints().containsKey(routingEntry().allocationId().getId()) &&
+        assert primaryContext.getCheckpointStates().containsKey(routingEntry().allocationId().getId()) &&
             getEngine().seqNoService().getLocalCheckpoint() ==
-                primaryContext.getLocalCheckpoints().get(routingEntry().allocationId().getId()).getLocalCheckpoint();
+                primaryContext.getCheckpointStates().get(routingEntry().allocationId().getId()).getLocalCheckpoint();
         getEngine().seqNoService().activateWithPrimaryContext(primaryContext);
     }
 

+ 3 - 1
core/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java

@@ -475,7 +475,9 @@ public class RecoverySourceHandler {
          * the permit then the state of the shard will be relocated and this recovery will fail.
          */
         runUnderPrimaryPermit(() -> shard.markAllocationIdAsInSync(request.targetAllocationId(), targetLocalCheckpoint));
-        cancellableThreads.execute(() -> recoveryTarget.finalizeRecovery(shard.getGlobalCheckpoint()));
+        final long globalCheckpoint = shard.getGlobalCheckpoint();
+        cancellableThreads.execute(() -> recoveryTarget.finalizeRecovery(globalCheckpoint));
+        shard.updateGlobalCheckpointForShard(request.targetAllocationId(), globalCheckpoint);
 
         if (request.isPrimaryRelocation()) {
             logger.trace("performing relocation hand-off");

+ 23 - 5
core/src/test/java/org/elasticsearch/action/support/replication/ReplicationOperationTests.java

@@ -131,6 +131,7 @@ public class ReplicationOperationTests extends ESTestCase {
 
         assertThat(primary.knownLocalCheckpoints.remove(primaryShard.allocationId().getId()), equalTo(primary.localCheckpoint));
         assertThat(primary.knownLocalCheckpoints, equalTo(replicasProxy.generatedLocalCheckpoints));
+        assertThat(primary.knownGlobalCheckpoints, equalTo(replicasProxy.generatedGlobalCheckpoints));
     }
 
     public void testDemotedPrimary() throws Exception {
@@ -380,6 +381,7 @@ public class ReplicationOperationTests extends ESTestCase {
         final long globalCheckpoint;
         final Supplier<ClusterState> clusterStateSupplier;
         final Map<String, Long> knownLocalCheckpoints = new HashMap<>();
+        final Map<String, Long> knownGlobalCheckpoints = new HashMap<>();
 
         TestPrimary(ShardRouting routing, Supplier<ClusterState> clusterStateSupplier) {
             this.routing = routing;
@@ -434,6 +436,11 @@ public class ReplicationOperationTests extends ESTestCase {
             knownLocalCheckpoints.put(allocationId, checkpoint);
         }
 
+        @Override
+        public void updateGlobalCheckpointForShard(String allocationId, long globalCheckpoint) {
+            knownGlobalCheckpoints.put(allocationId, globalCheckpoint);
+        }
+
         @Override
         public long localCheckpoint() {
             return localCheckpoint;
@@ -455,15 +462,23 @@ public class ReplicationOperationTests extends ESTestCase {
 
     static class ReplicaResponse implements ReplicationOperation.ReplicaResponse {
         final long localCheckpoint;
+        final long globalCheckpoint;
 
-        ReplicaResponse(long localCheckpoint) {
+        ReplicaResponse(long localCheckpoint, long globalCheckpoint) {
             this.localCheckpoint = localCheckpoint;
+            this.globalCheckpoint = globalCheckpoint;
         }
 
         @Override
         public long localCheckpoint() {
             return localCheckpoint;
         }
+
+        @Override
+        public long globalCheckpoint() {
+            return globalCheckpoint;
+        }
+
     }
 
     static class TestReplicaProxy implements ReplicationOperation.Replicas<Request> {
@@ -474,6 +489,8 @@ public class ReplicationOperationTests extends ESTestCase {
 
         final Map<String, Long> generatedLocalCheckpoints = ConcurrentCollections.newConcurrentMap();
 
+        final Map<String, Long> generatedGlobalCheckpoints = ConcurrentCollections.newConcurrentMap();
+
         final Set<String> markedAsStaleCopies = ConcurrentCollections.newConcurrentSet();
 
         final long primaryTerm;
@@ -497,11 +514,12 @@ public class ReplicationOperationTests extends ESTestCase {
             if (opFailures.containsKey(replica)) {
                 listener.onFailure(opFailures.get(replica));
             } else {
-                final long checkpoint = random().nextLong();
+                final long generatedLocalCheckpoint = random().nextLong();
+                final long generatedGlobalCheckpoint = random().nextLong();
                 final String allocationId = replica.allocationId().getId();
-                Long existing = generatedLocalCheckpoints.put(allocationId, checkpoint);
-                assertNull(existing);
-                listener.onResponse(new ReplicaResponse(checkpoint));
+                assertNull(generatedLocalCheckpoints.put(allocationId, generatedLocalCheckpoint));
+                assertNull(generatedGlobalCheckpoints.put(allocationId, generatedGlobalCheckpoint));
+                listener.onResponse(new ReplicaResponse(generatedLocalCheckpoint, generatedGlobalCheckpoint));
             }
         }
 

+ 2 - 1
core/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

@@ -639,7 +639,8 @@ public class TransportReplicationActionTests extends ESTestCase {
         CapturingTransport.CapturedRequest[] captures = transport.getCapturedRequestsAndClear();
         assertThat(captures, arrayWithSize(1));
         if (randomBoolean()) {
-            final TransportReplicationAction.ReplicaResponse response = new TransportReplicationAction.ReplicaResponse(randomLong());
+            final TransportReplicationAction.ReplicaResponse response =
+                    new TransportReplicationAction.ReplicaResponse(randomLong(), randomLong());
             transport.handleResponse(captures[0].requestId, response);
             assertTrue(listener.isDone());
             assertThat(listener.get(), equalTo(response));

+ 2 - 1
core/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java

@@ -289,7 +289,8 @@ public class TransportWriteActionTests extends ESTestCase {
         CapturingTransport.CapturedRequest[] captures = transport.getCapturedRequestsAndClear();
         assertThat(captures, arrayWithSize(1));
         if (randomBoolean()) {
-            final TransportReplicationAction.ReplicaResponse response = new TransportReplicationAction.ReplicaResponse(randomLong());
+            final TransportReplicationAction.ReplicaResponse response =
+                    new TransportReplicationAction.ReplicaResponse(randomLong(), randomLong());
             transport.handleResponse(captures[0].requestId, response);
             assertTrue(listener.isDone());
             assertThat(listener.get(), equalTo(response));

+ 2 - 3
core/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java

@@ -82,7 +82,6 @@ import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.collect.Tuple;
@@ -2028,7 +2027,7 @@ public class InternalEngineTests extends ESTestCase {
         final Set<String> indexedIds = new HashSet<>();
         long localCheckpoint = SequenceNumbers.NO_OPS_PERFORMED;
         long replicaLocalCheckpoint = SequenceNumbers.NO_OPS_PERFORMED;
-        long globalCheckpoint = SequenceNumbers.UNASSIGNED_SEQ_NO;
+        final long globalCheckpoint;
         long maxSeqNo = SequenceNumbers.NO_OPS_PERFORMED;
         InternalEngine initialEngine = null;
 
@@ -2039,7 +2038,7 @@ public class InternalEngineTests extends ESTestCase {
             initialEngine.seqNoService().updateAllocationIdsFromMaster(1L, new HashSet<>(Arrays.asList(primary.allocationId().getId(),
                 replica.allocationId().getId())),
                 new IndexShardRoutingTable.Builder(shardId).addShard(primary).addShard(replica).build(), Collections.emptySet());
-            initialEngine.seqNoService().activatePrimaryMode(primary.allocationId().getId(), primarySeqNo);
+            initialEngine.seqNoService().activatePrimaryMode(primarySeqNo);
             for (int op = 0; op < opCount; op++) {
                 final String id;
                 // mostly index, sometimes delete

+ 6 - 1
core/src/test/java/org/elasticsearch/index/replication/ESIndexLevelReplicationTestCase.java

@@ -481,6 +481,11 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
                 replicationGroup.getPrimary().updateLocalCheckpointForShard(allocationId, checkpoint);
             }
 
+            @Override
+            public void updateGlobalCheckpointForShard(String allocationId, long globalCheckpoint) {
+                replicationGroup.getPrimary().updateGlobalCheckpointForShard(allocationId, globalCheckpoint);
+            }
+
             @Override
             public long localCheckpoint() {
                 return replicationGroup.getPrimary().getLocalCheckpoint();
@@ -518,7 +523,7 @@ public abstract class ESIndexLevelReplicationTestCase extends IndexShardTestCase
                                 try {
                                     performOnReplica(request, replica);
                                     releasable.close();
-                                    listener.onResponse(new ReplicaResponse(replica.getLocalCheckpoint()));
+                                    listener.onResponse(new ReplicaResponse(replica.getLocalCheckpoint(), replica.getGlobalCheckpoint()));
                                 } catch (final Exception e) {
                                     Releasables.closeWhileHandlingException(releasable);
                                     listener.onFailure(e);

+ 0 - 1
core/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java

@@ -293,7 +293,6 @@ public class RecoveryDuringReplicationTests extends ESIndexLevelReplicationTestC
 
             final IndexShard oldPrimary = shards.getPrimary();
             final IndexShard newPrimary = shards.getReplicas().get(0);
-            final IndexShard otherReplica = shards.getReplicas().get(1);
 
             // simulate docs that were inflight when primary failed
             final int extraDocs = randomIntBetween(0, 5);

+ 151 - 64
core/src/test/java/org/elasticsearch/index/seqno/GlobalCheckpointTrackerTests.java

@@ -25,7 +25,6 @@ import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.common.Randomness;
-import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -35,10 +34,10 @@ import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.IndexSettingsModule;
-import org.junit.Before;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -75,13 +74,24 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         return allocations;
     }
 
-    private static IndexShardRoutingTable routingTable(Set<AllocationId> initializingIds) {
+    private static IndexShardRoutingTable routingTable(final Set<AllocationId> initializingIds, final AllocationId primaryId) {
+        final ShardId shardId = new ShardId("test", "_na_", 0);
+        final ShardRouting primaryShard =
+                TestShardRouting.newShardRouting(shardId, randomAlphaOfLength(10), null, true, ShardRoutingState.STARTED, primaryId);
+        return routingTable(initializingIds, primaryShard);
+    }
+
+    private static IndexShardRoutingTable routingTable(final Set<AllocationId> initializingIds, final ShardRouting primaryShard) {
+        assert !initializingIds.contains(primaryShard.allocationId());
         ShardId shardId = new ShardId("test", "_na_", 0);
         IndexShardRoutingTable.Builder builder = new IndexShardRoutingTable.Builder(shardId);
         for (AllocationId initializingId : initializingIds) {
-            builder.addShard(TestShardRouting.newShardRouting(shardId, randomAlphaOfLength(10), null, false, ShardRoutingState.INITIALIZING,
-                initializingId));
+            builder.addShard(TestShardRouting.newShardRouting(
+                    shardId, randomAlphaOfLength(10), null, false, ShardRoutingState.INITIALIZING, initializingId));
         }
+
+        builder.addShard(primaryShard);
+
         return builder.build();
     }
 
@@ -104,7 +114,9 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         // it is however nice not to assume this on this level and check we do the right thing.
         final long minLocalCheckpoint = allocations.values().stream().min(Long::compare).orElse(UNASSIGNED_SEQ_NO);
 
-        final GlobalCheckpointTracker tracker = newTracker(active.iterator().next());
+
+        final AllocationId primaryId = active.iterator().next();
+        final GlobalCheckpointTracker tracker = newTracker(primaryId);
         assertThat(tracker.getGlobalCheckpoint(), equalTo(UNASSIGNED_SEQ_NO));
 
         logger.info("--> using allocations");
@@ -120,7 +132,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
             logger.info("  - [{}], local checkpoint [{}], [{}]", aId, allocations.get(aId), type);
         });
 
-        tracker.updateFromMaster(initialClusterStateVersion, ids(active), routingTable(initializing), emptySet());
+        tracker.updateFromMaster(initialClusterStateVersion, ids(active), routingTable(initializing, primaryId), emptySet());
         tracker.activatePrimaryMode(NO_OPS_PERFORMED);
         initializing.forEach(aId -> markAllocationIdAsInSyncQuietly(tracker, aId.getId(), NO_OPS_PERFORMED));
         allocations.keySet().forEach(aId -> tracker.updateLocalCheckpoint(aId.getId(), allocations.get(aId)));
@@ -140,12 +152,12 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         // first check that adding it without the master blessing doesn't change anything.
         tracker.updateLocalCheckpoint(extraId.getId(), minLocalCheckpointAfterUpdates + 1 + randomInt(4));
-        assertNull(tracker.localCheckpoints.get(extraId));
+        assertNull(tracker.checkpoints.get(extraId));
         expectThrows(IllegalStateException.class, () -> tracker.initiateTracking(extraId.getId()));
 
         Set<AllocationId> newInitializing = new HashSet<>(initializing);
         newInitializing.add(extraId);
-        tracker.updateFromMaster(initialClusterStateVersion + 1, ids(active), routingTable(newInitializing), emptySet());
+        tracker.updateFromMaster(initialClusterStateVersion + 1, ids(active), routingTable(newInitializing, primaryId), emptySet());
 
         tracker.initiateTracking(extraId.getId());
 
@@ -167,9 +179,9 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final Map<AllocationId, Long> assigned = new HashMap<>();
         assigned.putAll(active);
         assigned.putAll(initializing);
-        AllocationId primary = active.keySet().iterator().next();
-        final GlobalCheckpointTracker tracker = newTracker(primary);
-        tracker.updateFromMaster(randomNonNegativeLong(), ids(active.keySet()), routingTable(initializing.keySet()), emptySet());
+        AllocationId primaryId = active.keySet().iterator().next();
+        final GlobalCheckpointTracker tracker = newTracker(primaryId);
+        tracker.updateFromMaster(randomNonNegativeLong(), ids(active.keySet()), routingTable(initializing.keySet(), primaryId), emptySet());
         tracker.activatePrimaryMode(NO_OPS_PERFORMED);
         randomSubsetOf(initializing.keySet()).forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k.getId(), NO_OPS_PERFORMED));
         final AllocationId missingActiveID = randomFrom(active.keySet());
@@ -179,7 +191,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
                 .filter(e -> !e.getKey().equals(missingActiveID))
                 .forEach(e -> tracker.updateLocalCheckpoint(e.getKey().getId(), e.getValue()));
 
-        if (missingActiveID.equals(primary) == false) {
+        if (missingActiveID.equals(primaryId) == false) {
             assertThat(tracker.getGlobalCheckpoint(), equalTo(UNASSIGNED_SEQ_NO));
         }
         // now update all knowledge of all shards
@@ -192,9 +204,9 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final Map<AllocationId, Long> initializing = randomAllocationsWithLocalCheckpoints(2, 5);
         logger.info("active: {}, initializing: {}", active, initializing);
 
-        AllocationId primary = active.keySet().iterator().next();
-        final GlobalCheckpointTracker tracker = newTracker(primary);
-        tracker.updateFromMaster(randomNonNegativeLong(), ids(active.keySet()), routingTable(initializing.keySet()), emptySet());
+        AllocationId primaryId = active.keySet().iterator().next();
+        final GlobalCheckpointTracker tracker = newTracker(primaryId);
+        tracker.updateFromMaster(randomNonNegativeLong(), ids(active.keySet()), routingTable(initializing.keySet(), primaryId), emptySet());
         tracker.activatePrimaryMode(NO_OPS_PERFORMED);
         randomSubsetOf(randomIntBetween(1, initializing.size() - 1),
             initializing.keySet()).forEach(aId -> markAllocationIdAsInSyncQuietly(tracker, aId.getId(), NO_OPS_PERFORMED));
@@ -212,8 +224,9 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final Map<AllocationId, Long> active = randomAllocationsWithLocalCheckpoints(1, 5);
         final Map<AllocationId, Long> initializing = randomAllocationsWithLocalCheckpoints(1, 5);
         final Map<AllocationId, Long> nonApproved = randomAllocationsWithLocalCheckpoints(1, 5);
-        final GlobalCheckpointTracker tracker = newTracker(active.keySet().iterator().next());
-        tracker.updateFromMaster(randomNonNegativeLong(), ids(active.keySet()), routingTable(initializing.keySet()), emptySet());
+        final AllocationId primaryId = active.keySet().iterator().next();
+        final GlobalCheckpointTracker tracker = newTracker(primaryId);
+        tracker.updateFromMaster(randomNonNegativeLong(), ids(active.keySet()), routingTable(initializing.keySet(), primaryId), emptySet());
         tracker.activatePrimaryMode(NO_OPS_PERFORMED);
         initializing.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k.getId(), NO_OPS_PERFORMED));
         nonApproved.keySet().forEach(k ->
@@ -235,6 +248,10 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final Set<AllocationId> active = Sets.union(activeToStay.keySet(), activeToBeRemoved.keySet());
         final Set<AllocationId> initializing = Sets.union(initializingToStay.keySet(), initializingToBeRemoved.keySet());
         final Map<AllocationId, Long> allocations = new HashMap<>();
+        final AllocationId primaryId = active.iterator().next();
+        if (activeToBeRemoved.containsKey(primaryId)) {
+            activeToStay.put(primaryId, activeToBeRemoved.remove(primaryId));
+        }
         allocations.putAll(activeToStay);
         if (randomBoolean()) {
             allocations.putAll(activeToBeRemoved);
@@ -243,8 +260,8 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         if (randomBoolean()) {
             allocations.putAll(initializingToBeRemoved);
         }
-        final GlobalCheckpointTracker tracker = newTracker(active.iterator().next());
-        tracker.updateFromMaster(initialClusterStateVersion, ids(active), routingTable(initializing), emptySet());
+        final GlobalCheckpointTracker tracker = newTracker(primaryId);
+        tracker.updateFromMaster(initialClusterStateVersion, ids(active), routingTable(initializing, primaryId), emptySet());
         tracker.activatePrimaryMode(NO_OPS_PERFORMED);
         if (randomBoolean()) {
             initializingToStay.keySet().forEach(k -> markAllocationIdAsInSyncQuietly(tracker, k.getId(), NO_OPS_PERFORMED));
@@ -257,13 +274,19 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         // now remove shards
         if (randomBoolean()) {
-            tracker.updateFromMaster(initialClusterStateVersion + 1, ids(activeToStay.keySet()), routingTable(initializingToStay.keySet()),
-                emptySet());
+            tracker.updateFromMaster(
+                    initialClusterStateVersion + 1,
+                    ids(activeToStay.keySet()),
+                    routingTable(initializingToStay.keySet(), primaryId),
+                    emptySet());
             allocations.forEach((aid, ckp) -> tracker.updateLocalCheckpoint(aid.getId(), ckp + 10L));
         } else {
             allocations.forEach((aid, ckp) -> tracker.updateLocalCheckpoint(aid.getId(), ckp + 10L));
-            tracker.updateFromMaster(initialClusterStateVersion + 2, ids(activeToStay.keySet()), routingTable(initializingToStay.keySet()),
-                emptySet());
+            tracker.updateFromMaster(
+                    initialClusterStateVersion + 2,
+                    ids(activeToStay.keySet()),
+                    routingTable(initializingToStay.keySet(), primaryId),
+                    emptySet());
         }
 
         final long checkpoint = Stream.concat(activeToStay.values().stream(), initializingToStay.values().stream())
@@ -281,7 +304,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final AllocationId trackingAllocationId = AllocationId.newInitializing();
         final GlobalCheckpointTracker tracker = newTracker(inSyncAllocationId);
         tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(inSyncAllocationId.getId()),
-            routingTable(Collections.singleton(trackingAllocationId)), emptySet());
+            routingTable(Collections.singleton(trackingAllocationId), inSyncAllocationId), emptySet());
         tracker.activatePrimaryMode(globalCheckpoint);
         final Thread thread = new Thread(() -> {
             try {
@@ -337,7 +360,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final AllocationId trackingAllocationId = AllocationId.newInitializing();
         final GlobalCheckpointTracker tracker = newTracker(inSyncAllocationId);
         tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(inSyncAllocationId.getId()),
-            routingTable(Collections.singleton(trackingAllocationId)), emptySet());
+            routingTable(Collections.singleton(trackingAllocationId), inSyncAllocationId), emptySet());
         tracker.activatePrimaryMode(globalCheckpoint);
         final Thread thread = new Thread(() -> {
             try {
@@ -382,8 +405,8 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
                 randomActiveAndInitializingAllocationIds(numberOfActiveAllocationsIds, numberOfInitializingIds);
         final Set<AllocationId> activeAllocationIds = activeAndInitializingAllocationIds.v1();
         final Set<AllocationId> initializingIds = activeAndInitializingAllocationIds.v2();
-        IndexShardRoutingTable routingTable = routingTable(initializingIds);
         AllocationId primaryId = activeAllocationIds.iterator().next();
+        IndexShardRoutingTable routingTable = routingTable(initializingIds, primaryId);
         final GlobalCheckpointTracker tracker = newTracker(primaryId);
         tracker.updateFromMaster(initialClusterStateVersion, ids(activeAllocationIds), routingTable, emptySet());
         tracker.activatePrimaryMode(NO_OPS_PERFORMED);
@@ -408,14 +431,14 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         // now we will remove some allocation IDs from these and ensure that they propagate through
         final Set<AllocationId> removingActiveAllocationIds = new HashSet<>(randomSubsetOf(activeAllocationIds));
+        removingActiveAllocationIds.remove(primaryId);
         final Set<AllocationId> newActiveAllocationIds =
                 activeAllocationIds.stream().filter(a -> !removingActiveAllocationIds.contains(a)).collect(Collectors.toSet());
         final List<AllocationId> removingInitializingAllocationIds = randomSubsetOf(initializingIds);
         final Set<AllocationId> newInitializingAllocationIds =
                 initializingIds.stream().filter(a -> !removingInitializingAllocationIds.contains(a)).collect(Collectors.toSet());
-        routingTable = routingTable(newInitializingAllocationIds);
-        tracker.updateFromMaster(initialClusterStateVersion + 1, ids(newActiveAllocationIds), routingTable,
-            emptySet());
+        routingTable = routingTable(newInitializingAllocationIds, primaryId);
+        tracker.updateFromMaster(initialClusterStateVersion + 1, ids(newActiveAllocationIds), routingTable, emptySet());
         assertTrue(newActiveAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a.getId()).inSync));
         assertTrue(removingActiveAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a.getId()) == null));
         assertTrue(newInitializingAllocationIds.stream().noneMatch(a -> tracker.getTrackedLocalCheckpointForShard(a.getId()).inSync));
@@ -429,8 +452,11 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
          * than we have been using above ensures that we can not collide with a previous allocation ID
          */
         newInitializingAllocationIds.add(AllocationId.newInitializing());
-        tracker.updateFromMaster(initialClusterStateVersion + 2, ids(newActiveAllocationIds), routingTable(newInitializingAllocationIds),
-            emptySet());
+        tracker.updateFromMaster(
+                initialClusterStateVersion + 2,
+                ids(newActiveAllocationIds),
+                routingTable(newInitializingAllocationIds, primaryId),
+                emptySet());
         assertTrue(newActiveAllocationIds.stream().allMatch(a -> tracker.getTrackedLocalCheckpointForShard(a.getId()).inSync));
         assertTrue(
                 newActiveAllocationIds
@@ -473,8 +499,11 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         // using a different length than we have been using above ensures that we can not collide with a previous allocation ID
         final AllocationId newSyncingAllocationId = AllocationId.newInitializing();
         newInitializingAllocationIds.add(newSyncingAllocationId);
-        tracker.updateFromMaster(initialClusterStateVersion + 3, ids(newActiveAllocationIds), routingTable(newInitializingAllocationIds),
-            emptySet());
+        tracker.updateFromMaster(
+                initialClusterStateVersion + 3,
+                ids(newActiveAllocationIds),
+                routingTable(newInitializingAllocationIds, primaryId),
+                emptySet());
         final CyclicBarrier barrier = new CyclicBarrier(2);
         final Thread thread = new Thread(() -> {
             try {
@@ -508,8 +537,11 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
          * the in-sync set even if we receive a cluster state update that does not reflect this.
          *
          */
-        tracker.updateFromMaster(initialClusterStateVersion + 4, ids(newActiveAllocationIds), routingTable(newInitializingAllocationIds),
-            emptySet());
+        tracker.updateFromMaster(
+                initialClusterStateVersion + 4,
+                ids(newActiveAllocationIds),
+                routingTable(newInitializingAllocationIds, primaryId),
+                emptySet());
         assertTrue(tracker.getTrackedLocalCheckpointForShard(newSyncingAllocationId.getId()).inSync);
         assertFalse(tracker.pendingInSync.contains(newSyncingAllocationId.getId()));
     }
@@ -534,8 +566,11 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
 
         final int activeLocalCheckpoint = randomIntBetween(0, Integer.MAX_VALUE - 1);
         final GlobalCheckpointTracker tracker = newTracker(active);
-        tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(active.getId()),
-            routingTable(Collections.singleton(initializing)), emptySet());
+        tracker.updateFromMaster(
+                randomNonNegativeLong(),
+                Collections.singleton(active.getId()),
+                routingTable(Collections.singleton(initializing), active),
+                emptySet());
         tracker.activatePrimaryMode(activeLocalCheckpoint);
         final int nextActiveLocalCheckpoint = randomIntBetween(activeLocalCheckpoint + 1, Integer.MAX_VALUE);
         final Thread activeThread = new Thread(() -> {
@@ -583,20 +618,23 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final ShardId shardId = new ShardId("test", "_na_", 0);
 
         FakeClusterState clusterState = initialState();
+        final AllocationId primaryAllocationId = clusterState.routingTable.primaryShard().allocationId();
         GlobalCheckpointTracker oldPrimary =
-                new GlobalCheckpointTracker(shardId, randomFrom(ids(clusterState.inSyncIds)), indexSettings, UNASSIGNED_SEQ_NO);
+                new GlobalCheckpointTracker(shardId, primaryAllocationId.getId(), indexSettings, UNASSIGNED_SEQ_NO);
         GlobalCheckpointTracker newPrimary =
-                new GlobalCheckpointTracker(shardId, UUIDs.randomBase64UUID(random()), indexSettings, UNASSIGNED_SEQ_NO);
+                new GlobalCheckpointTracker(shardId, primaryAllocationId.getRelocationId(), indexSettings, UNASSIGNED_SEQ_NO);
+
+        Set<String> allocationIds = new HashSet<>(Arrays.asList(oldPrimary.shardAllocationId, newPrimary.shardAllocationId));
 
         clusterState.apply(oldPrimary);
         clusterState.apply(newPrimary);
 
-        activatePrimary(clusterState, oldPrimary);
+        activatePrimary(oldPrimary);
 
         final int numUpdates = randomInt(10);
         for (int i = 0; i < numUpdates; i++) {
             if (rarely()) {
-                clusterState = randomUpdateClusterState(clusterState);
+                clusterState = randomUpdateClusterState(allocationIds, clusterState);
                 clusterState.apply(oldPrimary);
                 clusterState.apply(newPrimary);
             }
@@ -608,12 +646,18 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
             }
         }
 
+        // simulate transferring the global checkpoint to the new primary after finalizing recovery before the handoff
+        markAllocationIdAsInSyncQuietly(
+                oldPrimary,
+                newPrimary.shardAllocationId,
+                Math.max(SequenceNumbers.NO_OPS_PERFORMED, oldPrimary.getGlobalCheckpoint() + randomInt(5)));
+        oldPrimary.updateGlobalCheckpointForShard(newPrimary.shardAllocationId, oldPrimary.getGlobalCheckpoint());
         GlobalCheckpointTracker.PrimaryContext primaryContext = oldPrimary.startRelocationHandoff();
 
         if (randomBoolean()) {
             // cluster state update after primary context handoff
             if (randomBoolean()) {
-                clusterState = randomUpdateClusterState(clusterState);
+                clusterState = randomUpdateClusterState(allocationIds, clusterState);
                 clusterState.apply(oldPrimary);
                 clusterState.apply(newPrimary);
             }
@@ -622,7 +666,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
             oldPrimary.abortRelocationHandoff();
 
             if (rarely()) {
-                clusterState = randomUpdateClusterState(clusterState);
+                clusterState = randomUpdateClusterState(allocationIds, clusterState);
                 clusterState.apply(oldPrimary);
                 clusterState.apply(newPrimary);
             }
@@ -642,11 +686,10 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         primaryContext.writeTo(output);
         StreamInput streamInput = output.bytes().streamInput();
         primaryContext = new GlobalCheckpointTracker.PrimaryContext(streamInput);
-
         switch (randomInt(3)) {
             case 0: {
                 // apply cluster state update on old primary while primary context is being transferred
-                clusterState = randomUpdateClusterState(clusterState);
+                clusterState = randomUpdateClusterState(allocationIds, clusterState);
                 clusterState.apply(oldPrimary);
                 // activate new primary
                 newPrimary.activateWithPrimaryContext(primaryContext);
@@ -656,7 +699,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
             }
             case 1: {
                 // apply cluster state update on new primary while primary context is being transferred
-                clusterState = randomUpdateClusterState(clusterState);
+                clusterState = randomUpdateClusterState(allocationIds, clusterState);
                 clusterState.apply(newPrimary);
                 // activate new primary
                 newPrimary.activateWithPrimaryContext(primaryContext);
@@ -666,7 +709,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
             }
             case 2: {
                 // apply cluster state update on both copies while primary context is being transferred
-                clusterState = randomUpdateClusterState(clusterState);
+                clusterState = randomUpdateClusterState(allocationIds, clusterState);
                 clusterState.apply(oldPrimary);
                 clusterState.apply(newPrimary);
                 newPrimary.activateWithPrimaryContext(primaryContext);
@@ -682,8 +725,32 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         assertTrue(oldPrimary.primaryMode);
         assertTrue(newPrimary.primaryMode);
         assertThat(newPrimary.appliedClusterStateVersion, equalTo(oldPrimary.appliedClusterStateVersion));
-        assertThat(newPrimary.localCheckpoints, equalTo(oldPrimary.localCheckpoints));
-        assertThat(newPrimary.globalCheckpoint, equalTo(oldPrimary.globalCheckpoint));
+        /*
+         * We can not assert on shared knowledge of the global checkpoint between the old primary and the new primary as the new primary
+         * will update its global checkpoint state without the old primary learning of it, and the old primary could have updated its
+         * global checkpoint state after the primary context was transferred.
+         */
+        Map<String, GlobalCheckpointTracker.CheckpointState> oldPrimaryCheckpointsCopy = new HashMap<>(oldPrimary.checkpoints);
+        oldPrimaryCheckpointsCopy.remove(oldPrimary.shardAllocationId);
+        oldPrimaryCheckpointsCopy.remove(newPrimary.shardAllocationId);
+        Map<String, GlobalCheckpointTracker.CheckpointState> newPrimaryCheckpointsCopy = new HashMap<>(newPrimary.checkpoints);
+        newPrimaryCheckpointsCopy.remove(oldPrimary.shardAllocationId);
+        newPrimaryCheckpointsCopy.remove(newPrimary.shardAllocationId);
+        assertThat(newPrimaryCheckpointsCopy, equalTo(oldPrimaryCheckpointsCopy));
+        // we can however assert that shared knowledge of the local checkpoint and in-sync status is equal
+        assertThat(
+                oldPrimary.checkpoints.get(oldPrimary.shardAllocationId).localCheckpoint,
+                equalTo(newPrimary.checkpoints.get(oldPrimary.shardAllocationId).localCheckpoint));
+        assertThat(
+                oldPrimary.checkpoints.get(newPrimary.shardAllocationId).localCheckpoint,
+                equalTo(newPrimary.checkpoints.get(newPrimary.shardAllocationId).localCheckpoint));
+        assertThat(
+                oldPrimary.checkpoints.get(oldPrimary.shardAllocationId).inSync,
+                equalTo(newPrimary.checkpoints.get(oldPrimary.shardAllocationId).inSync));
+        assertThat(
+                oldPrimary.checkpoints.get(newPrimary.shardAllocationId).inSync,
+                equalTo(newPrimary.checkpoints.get(newPrimary.shardAllocationId).inSync));
+        assertThat(newPrimary.getGlobalCheckpoint(), equalTo(oldPrimary.getGlobalCheckpoint()));
         assertThat(newPrimary.routingTable, equalTo(oldPrimary.routingTable));
         assertThat(newPrimary.replicationGroup, equalTo(oldPrimary.replicationGroup));
 
@@ -696,7 +763,7 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final AllocationId initializing = AllocationId.newInitializing();
         final GlobalCheckpointTracker tracker = newTracker(active);
         tracker.updateFromMaster(randomNonNegativeLong(), Collections.singleton(active.getId()),
-            routingTable(Collections.singleton(initializing)), emptySet());
+            routingTable(Collections.singleton(initializing), active), emptySet());
         tracker.activatePrimaryMode(NO_OPS_PERFORMED);
 
         expectThrows(IllegalStateException.class, () -> tracker.initiateTracking(randomAlphaOfLength(10)));
@@ -733,38 +800,58 @@ public class GlobalCheckpointTrackerTests extends ESTestCase {
         final int numberOfActiveAllocationsIds = randomIntBetween(1, 8);
         final int numberOfInitializingIds = randomIntBetween(0, 8);
         final Tuple<Set<AllocationId>, Set<AllocationId>> activeAndInitializingAllocationIds =
-            randomActiveAndInitializingAllocationIds(numberOfActiveAllocationsIds, numberOfInitializingIds);
+                randomActiveAndInitializingAllocationIds(numberOfActiveAllocationsIds, numberOfInitializingIds);
         final Set<AllocationId> activeAllocationIds = activeAndInitializingAllocationIds.v1();
         final Set<AllocationId> initializingAllocationIds = activeAndInitializingAllocationIds.v2();
-        return new FakeClusterState(initialClusterStateVersion, activeAllocationIds, routingTable(initializingAllocationIds));
+        final AllocationId primaryId = randomFrom(activeAllocationIds);
+        final AllocationId relocatingId = AllocationId.newRelocation(primaryId);
+        activeAllocationIds.remove(primaryId);
+        activeAllocationIds.add(relocatingId);
+        final ShardId shardId = new ShardId("test", "_na_", 0);
+        final ShardRouting primaryShard =
+                TestShardRouting.newShardRouting(
+                        shardId, randomAlphaOfLength(10), randomAlphaOfLength(10), true, ShardRoutingState.RELOCATING, relocatingId);
+
+        return new FakeClusterState(
+                initialClusterStateVersion,
+                activeAllocationIds,
+                routingTable(initializingAllocationIds, primaryShard));
     }
 
-    private static void activatePrimary(FakeClusterState clusterState, GlobalCheckpointTracker gcp) {
+    private static void activatePrimary(GlobalCheckpointTracker gcp) {
         gcp.activatePrimaryMode(randomIntBetween(Math.toIntExact(NO_OPS_PERFORMED), 10));
     }
 
     private static void randomLocalCheckpointUpdate(GlobalCheckpointTracker gcp) {
-        String allocationId = randomFrom(gcp.localCheckpoints.keySet());
-        long currentLocalCheckpoint = gcp.localCheckpoints.get(allocationId).getLocalCheckpoint();
+        String allocationId = randomFrom(gcp.checkpoints.keySet());
+        long currentLocalCheckpoint = gcp.checkpoints.get(allocationId).getLocalCheckpoint();
         gcp.updateLocalCheckpoint(allocationId, Math.max(SequenceNumbers.NO_OPS_PERFORMED, currentLocalCheckpoint + randomInt(5)));
     }
 
     private static void randomMarkInSync(GlobalCheckpointTracker gcp) {
-        String allocationId = randomFrom(gcp.localCheckpoints.keySet());
+        String allocationId = randomFrom(gcp.checkpoints.keySet());
         long newLocalCheckpoint = Math.max(NO_OPS_PERFORMED, gcp.getGlobalCheckpoint() + randomInt(5));
         markAllocationIdAsInSyncQuietly(gcp, allocationId, newLocalCheckpoint);
     }
 
-    private static FakeClusterState randomUpdateClusterState(FakeClusterState clusterState) {
-        final Set<AllocationId> initializingIdsToAdd = randomAllocationIdsExcludingExistingIds(clusterState.allIds(), randomInt(2));
+    private static FakeClusterState randomUpdateClusterState(Set<String> allocationIds, FakeClusterState clusterState) {
+        final Set<AllocationId> initializingIdsToAdd =
+                randomAllocationIdsExcludingExistingIds(exclude(clusterState.allIds(), allocationIds), randomInt(2));
         final Set<AllocationId> initializingIdsToRemove = new HashSet<>(
-            randomSubsetOf(randomInt(clusterState.initializingIds().size()), clusterState.initializingIds()));
+            exclude(randomSubsetOf(randomInt(clusterState.initializingIds().size()), clusterState.initializingIds()), allocationIds));
         final Set<AllocationId> inSyncIdsToRemove = new HashSet<>(
-            randomSubsetOf(randomInt(clusterState.inSyncIds.size()), clusterState.inSyncIds));
+            exclude(randomSubsetOf(randomInt(clusterState.inSyncIds.size()), clusterState.inSyncIds), allocationIds));
         final Set<AllocationId> remainingInSyncIds = Sets.difference(clusterState.inSyncIds, inSyncIdsToRemove);
-        return new FakeClusterState(clusterState.version + randomIntBetween(1, 5),
-            remainingInSyncIds.isEmpty() ? clusterState.inSyncIds : remainingInSyncIds,
-            routingTable(Sets.difference(Sets.union(clusterState.initializingIds(), initializingIdsToAdd), initializingIdsToRemove)));
+        return new FakeClusterState(
+                clusterState.version + randomIntBetween(1, 5),
+                remainingInSyncIds.isEmpty() ? clusterState.inSyncIds : remainingInSyncIds,
+                routingTable(
+                        Sets.difference(Sets.union(clusterState.initializingIds(), initializingIdsToAdd), initializingIdsToRemove),
+                        clusterState.routingTable.primaryShard()));
+    }
+
+    private static Set<AllocationId> exclude(Collection<AllocationId> allocationIds, Set<String> excludeIds) {
+        return allocationIds.stream().filter(aId -> !excludeIds.contains(aId.getId())).collect(Collectors.toSet());
     }
 
     private static Tuple<Set<AllocationId>, Set<AllocationId>> randomActiveAndInitializingAllocationIds(

+ 15 - 1
core/src/test/java/org/elasticsearch/recovery/RelocationIT.java

@@ -20,6 +20,7 @@
 package org.elasticsearch.recovery;
 
 import com.carrotsearch.hppc.IntHashSet;
+import com.carrotsearch.hppc.ObjectLongMap;
 import com.carrotsearch.hppc.procedures.IntProcedure;
 import org.apache.lucene.index.IndexFileNames;
 import org.apache.lucene.util.English;
@@ -34,6 +35,7 @@ import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.allocation.command.MoveAllocationCommand;
 import org.elasticsearch.cluster.routing.allocation.decider.EnableAllocationDecider;
@@ -44,12 +46,14 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.env.NodeEnvironment;
+import org.elasticsearch.index.IndexService;
 import org.elasticsearch.index.seqno.SeqNoStats;
 import org.elasticsearch.index.seqno.SequenceNumbers;
 import org.elasticsearch.index.shard.IndexEventListener;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.IndexShardState;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.recovery.PeerRecoveryTargetService;
 import org.elasticsearch.indices.recovery.RecoveryFileChunkRequest;
 import org.elasticsearch.plugins.Plugin;
@@ -118,8 +122,14 @@ public class RelocationIT extends ESIntegTestCase {
                     }
                     ShardStats primary = maybePrimary.get();
                     final SeqNoStats primarySeqNoStats = primary.getSeqNoStats();
-                    assertThat(primary.getShardRouting() + " should have set the global checkpoint",
+                    final ShardRouting primaryShardRouting = primary.getShardRouting();
+                    assertThat(primaryShardRouting + " should have set the global checkpoint",
                         primarySeqNoStats.getGlobalCheckpoint(), not(equalTo(SequenceNumbers.UNASSIGNED_SEQ_NO)));
+                    final DiscoveryNode node = clusterService().state().nodes().get(primaryShardRouting.currentNodeId());
+                    final IndicesService indicesService =
+                            internalCluster().getInstance(IndicesService.class, node.getName());
+                    final IndexShard indexShard = indicesService.getShardOrNull(primaryShardRouting.shardId());
+                    final ObjectLongMap<String> globalCheckpoints = indexShard.getGlobalCheckpoints();
                     for (ShardStats shardStats : indexShardStats) {
                         final SeqNoStats seqNoStats = shardStats.getSeqNoStats();
                         assertThat(shardStats.getShardRouting() + " local checkpoint mismatch",
@@ -128,6 +138,10 @@ public class RelocationIT extends ESIntegTestCase {
                             seqNoStats.getGlobalCheckpoint(), equalTo(primarySeqNoStats.getGlobalCheckpoint()));
                         assertThat(shardStats.getShardRouting() + " max seq no mismatch",
                             seqNoStats.getMaxSeqNo(), equalTo(primarySeqNoStats.getMaxSeqNo()));
+                        // the local knowledge on the primary of the global checkpoint equals the global checkpoint on the shard
+                        assertThat(
+                                seqNoStats.getGlobalCheckpoint(),
+                                equalTo(globalCheckpoints.get(shardStats.getShardRouting().allocationId().getId())));
                     }
                 }
             }