Browse Source

Make SnapshotsInProgress Diffable (#89619)

This is very important for #77466. Profiling showed that serializing snapshots-in-progress when there's a few snapshots with high shard count running takes a significant amount of CPU and heap for sending the full data structure over an over.
This PR adds diffing in the simplest way I could think of on top of the existing data structure.

closes #88732
Armin Braun 3 years ago
parent
commit
b69d1bd15a

+ 6 - 0
docs/changelog/89619.yaml

@@ -0,0 +1,6 @@
+pr: 89619
+summary: Make `SnapshotsInProgress` Diffable
+area: Snapshot/Restore
+type: enhancement
+issues:
+ - 88732

+ 404 - 18
server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java

@@ -56,13 +56,13 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
     public static final String ABORTED_FAILURE_TEXT = "Snapshot was aborted by deletion";
 
     // keyed by repository name
-    private final Map<String, List<Entry>> entries;
+    private final Map<String, ByRepo> entries;
 
     public SnapshotsInProgress(StreamInput in) throws IOException {
         this(collectByRepo(in));
     }
 
-    private static Map<String, List<Entry>> collectByRepo(StreamInput in) throws IOException {
+    private static Map<String, ByRepo> collectByRepo(StreamInput in) throws IOException {
         final int count = in.readVInt();
         if (count == 0) {
             return Map.of();
@@ -72,13 +72,14 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
             final Entry entry = Entry.readFrom(in);
             entriesByRepo.computeIfAbsent(entry.repository(), repo -> new ArrayList<>()).add(entry);
         }
+        final Map<String, ByRepo> res = Maps.newMapWithExpectedSize(entriesByRepo.size());
         for (Map.Entry<String, List<Entry>> entryForRepo : entriesByRepo.entrySet()) {
-            entryForRepo.setValue(List.copyOf(entryForRepo.getValue()));
+            res.put(entryForRepo.getKey(), new ByRepo(entryForRepo.getValue()));
         }
-        return entriesByRepo;
+        return res;
     }
 
-    private SnapshotsInProgress(Map<String, List<Entry>> entries) {
+    private SnapshotsInProgress(Map<String, ByRepo> entries) {
         this.entries = Map.copyOf(entries);
         assert assertConsistentEntries(this.entries);
     }
@@ -87,26 +88,26 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
         if (updatedEntries.equals(forRepo(repository))) {
             return this;
         }
-        final Map<String, List<Entry>> copy = new HashMap<>(this.entries);
+        final Map<String, ByRepo> copy = new HashMap<>(this.entries);
         if (updatedEntries.isEmpty()) {
             copy.remove(repository);
             if (copy.isEmpty()) {
                 return EMPTY;
             }
         } else {
-            copy.put(repository, List.copyOf(updatedEntries));
+            copy.put(repository, new ByRepo(updatedEntries));
         }
         return new SnapshotsInProgress(copy);
     }
 
     public SnapshotsInProgress withAddedEntry(Entry entry) {
-        final List<Entry> forRepo = new ArrayList<>(entries.getOrDefault(entry.repository(), List.of()));
+        final List<Entry> forRepo = new ArrayList<>(forRepo(entry.repository()));
         forRepo.add(entry);
         return withUpdatedEntriesForRepo(entry.repository(), forRepo);
     }
 
     public List<Entry> forRepo(String repository) {
-        return entries.getOrDefault(repository, List.of());
+        return entries.getOrDefault(repository, ByRepo.EMPTY).entries;
     }
 
     public boolean isEmpty() {
@@ -115,18 +116,18 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
 
     public int count() {
         int count = 0;
-        for (List<Entry> list : entries.values()) {
-            count += list.size();
+        for (ByRepo byRepo : entries.values()) {
+            count += byRepo.entries.size();
         }
         return count;
     }
 
-    public Collection<List<Entry>> entriesByRepo() {
-        return entries.values();
+    public Iterable<List<Entry>> entriesByRepo() {
+        return () -> entries.values().stream().map(byRepo -> byRepo.entries).iterator();
     }
 
     public Stream<Entry> asStream() {
-        return entries.values().stream().flatMap(Collection::stream);
+        return entries.values().stream().flatMap(t -> t.entries.stream());
     }
 
     @Nullable
@@ -187,10 +188,20 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
         return Version.CURRENT.minimumCompatibilityVersion();
     }
 
+    private static final Version DIFFABLE_VERSION = Version.V_8_5_0;
+
     public static NamedDiff<Custom> readDiffFrom(StreamInput in) throws IOException {
+        if (in.getVersion().onOrAfter(DIFFABLE_VERSION)) {
+            return new SnapshotInProgressDiff(in);
+        }
         return readDiffFrom(Custom.class, TYPE, in);
     }
 
+    @Override
+    public Diff<Custom> diff(Custom previousState) {
+        return new SnapshotInProgressDiff((SnapshotsInProgress) previousState, this);
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeVInt(count());
@@ -318,11 +329,11 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
         return false;
     }
 
-    private static boolean assertConsistentEntries(Map<String, List<Entry>> entries) {
-        for (Map.Entry<String, List<Entry>> repoEntries : entries.entrySet()) {
+    private static boolean assertConsistentEntries(Map<String, ByRepo> entries) {
+        for (Map.Entry<String, ByRepo> repoEntries : entries.entrySet()) {
             final Set<Tuple<String, Integer>> assignedShards = new HashSet<>();
             final Set<Tuple<String, Integer>> queuedShards = new HashSet<>();
-            final List<Entry> entriesForRepository = repoEntries.getValue();
+            final List<Entry> entriesForRepository = repoEntries.getValue().entries;
             final String repository = repoEntries.getKey();
             assert entriesForRepository.isEmpty() == false : "found empty list of snapshots for " + repository + " in " + entries;
             for (Entry entry : entriesForRepository) {
@@ -637,7 +648,7 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
         }
     }
 
-    public static class Entry implements Writeable, ToXContent, RepositoryOperation {
+    public static class Entry implements Writeable, ToXContent, RepositoryOperation, Diffable<Entry> {
         private final State state;
         private final Snapshot snapshot;
         private final boolean includeGlobalState;
@@ -1312,5 +1323,380 @@ public class SnapshotsInProgress extends AbstractNamedDiffable<Custom> implement
         public boolean isFragment() {
             return false;
         }
+
+        @Override
+        public Diff<Entry> diff(Entry previousState) {
+            return new EntryDiff(previousState, this);
+        }
+    }
+
+    private static final class EntryDiff implements Diff<Entry> {
+
+        private static final DiffableUtils.NonDiffableValueSerializer<String, IndexId> INDEX_ID_VALUE_SERIALIZER =
+            new DiffableUtils.NonDiffableValueSerializer<>() {
+                @Override
+                public void write(IndexId value, StreamOutput out) throws IOException {
+                    out.writeString(value.getId());
+                }
+
+                @Override
+                public IndexId read(StreamInput in, String key) throws IOException {
+                    return new IndexId(key, in.readString());
+                }
+            };
+
+        private static final DiffableUtils.NonDiffableValueSerializer<?, ShardSnapshotStatus> SHARD_SNAPSHOT_STATUS_VALUE_SERIALIZER =
+            new DiffableUtils.NonDiffableValueSerializer<>() {
+                @Override
+                public void write(ShardSnapshotStatus value, StreamOutput out) throws IOException {
+                    value.writeTo(out);
+                }
+
+                @Override
+                public ShardSnapshotStatus read(StreamInput in, Object key) throws IOException {
+                    return ShardSnapshotStatus.readFrom(in);
+                }
+            };
+
+        private static final DiffableUtils.KeySerializer<ShardId> SHARD_ID_KEY_SERIALIZER = new DiffableUtils.KeySerializer<>() {
+            @Override
+            public void writeKey(ShardId key, StreamOutput out) throws IOException {
+                key.writeTo(out);
+            }
+
+            @Override
+            public ShardId readKey(StreamInput in) throws IOException {
+                return new ShardId(in);
+            }
+        };
+
+        private static final DiffableUtils.KeySerializer<RepositoryShardId> REPO_SHARD_ID_KEY_SERIALIZER =
+            new DiffableUtils.KeySerializer<>() {
+                @Override
+                public void writeKey(RepositoryShardId key, StreamOutput out) throws IOException {
+                    key.writeTo(out);
+                }
+
+                @Override
+                public RepositoryShardId readKey(StreamInput in) throws IOException {
+                    return new RepositoryShardId(in);
+                }
+            };
+
+        private final DiffableUtils.MapDiff<String, IndexId, Map<String, IndexId>> indexByIndexNameDiff;
+
+        private final DiffableUtils.MapDiff<ShardId, ShardSnapshotStatus, Map<ShardId, ShardSnapshotStatus>> shardsByShardIdDiff;
+
+        @Nullable
+        private final DiffableUtils.MapDiff<
+            RepositoryShardId,
+            ShardSnapshotStatus,
+            Map<RepositoryShardId, ShardSnapshotStatus>> shardsByRepoShardIdDiff;
+
+        @Nullable
+        private final List<String> updatedDataStreams;
+
+        @Nullable
+        private final String updatedFailure;
+
+        private final long updatedRepositoryStateId;
+
+        private final State updatedState;
+
+        @SuppressWarnings("unchecked")
+        EntryDiff(StreamInput in) throws IOException {
+            this.indexByIndexNameDiff = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), INDEX_ID_VALUE_SERIALIZER);
+            this.updatedState = State.fromValue(in.readByte());
+            this.updatedRepositoryStateId = in.readLong();
+            this.updatedDataStreams = in.readOptionalStringList();
+            this.updatedFailure = in.readOptionalString();
+            this.shardsByShardIdDiff = DiffableUtils.readJdkMapDiff(
+                in,
+                SHARD_ID_KEY_SERIALIZER,
+                (DiffableUtils.ValueSerializer<ShardId, ShardSnapshotStatus>) SHARD_SNAPSHOT_STATUS_VALUE_SERIALIZER
+            );
+            shardsByRepoShardIdDiff = in.readOptionalWriteable(
+                i -> DiffableUtils.readJdkMapDiff(
+                    i,
+                    REPO_SHARD_ID_KEY_SERIALIZER,
+                    (DiffableUtils.ValueSerializer<RepositoryShardId, ShardSnapshotStatus>) SHARD_SNAPSHOT_STATUS_VALUE_SERIALIZER
+                )
+            );
+        }
+
+        @SuppressWarnings("unchecked")
+        EntryDiff(Entry before, Entry after) {
+            try {
+                verifyDiffable(before, after);
+            } catch (Exception e) {
+                final IllegalArgumentException ex = new IllegalArgumentException("Cannot diff [" + before + "] and [" + after + "]");
+                assert false : ex;
+                throw ex;
+            }
+            this.indexByIndexNameDiff = DiffableUtils.diff(
+                before.indices,
+                after.indices,
+                DiffableUtils.getStringKeySerializer(),
+                INDEX_ID_VALUE_SERIALIZER
+            );
+            this.updatedDataStreams = before.dataStreams.equals(after.dataStreams) ? null : after.dataStreams;
+            this.updatedState = after.state;
+            this.updatedRepositoryStateId = after.repositoryStateId;
+            this.updatedFailure = after.failure;
+            this.shardsByShardIdDiff = DiffableUtils.diff(
+                before.shards,
+                after.shards,
+                SHARD_ID_KEY_SERIALIZER,
+                (DiffableUtils.ValueSerializer<ShardId, ShardSnapshotStatus>) SHARD_SNAPSHOT_STATUS_VALUE_SERIALIZER
+            );
+            if (before.isClone()) {
+                this.shardsByRepoShardIdDiff = DiffableUtils.diff(
+                    before.shardStatusByRepoShardId,
+                    after.shardStatusByRepoShardId,
+                    REPO_SHARD_ID_KEY_SERIALIZER,
+                    (DiffableUtils.ValueSerializer<RepositoryShardId, ShardSnapshotStatus>) SHARD_SNAPSHOT_STATUS_VALUE_SERIALIZER
+                );
+            } else {
+                this.shardsByRepoShardIdDiff = null;
+            }
+        }
+
+        private static void verifyDiffable(Entry before, Entry after) {
+            if (before.snapshot().equals(after.snapshot()) == false) {
+                throw new IllegalArgumentException("snapshot changed from [" + before.snapshot() + "] to [" + after.snapshot() + "]");
+            }
+            if (before.startTime() != after.startTime()) {
+                throw new IllegalArgumentException("start time changed from [" + before.startTime() + "] to [" + after.startTime() + "]");
+            }
+            if (Objects.equals(before.source(), after.source()) == false) {
+                throw new IllegalArgumentException("source changed from [" + before.source() + "] to [" + after.source() + "]");
+            }
+            if (before.includeGlobalState() != after.includeGlobalState()) {
+                throw new IllegalArgumentException(
+                    "include global state changed from [" + before.includeGlobalState() + "] to [" + after.includeGlobalState() + "]"
+                );
+            }
+            if (before.partial() != after.partial()) {
+                throw new IllegalArgumentException("partial changed from [" + before.partial() + "] to [" + after.partial() + "]");
+            }
+            if (before.featureStates().equals(after.featureStates()) == false) {
+                throw new IllegalArgumentException(
+                    "feature states changed from " + before.featureStates() + " to " + after.featureStates()
+                );
+            }
+            if (Objects.equals(before.userMetadata(), after.userMetadata()) == false) {
+                throw new IllegalArgumentException("user metadata changed from " + before.userMetadata() + " to " + after.userMetadata());
+            }
+            if (before.version().equals(after.version()) == false) {
+                throw new IllegalArgumentException("version changed from " + before.version() + " to " + after.version());
+            }
+        }
+
+        @Override
+        public Entry apply(Entry part) {
+            final var updatedIndices = indexByIndexNameDiff.apply(part.indices);
+            final var updatedStateByShard = shardsByShardIdDiff.apply(part.shards);
+            if (part.isClone() == false && updatedIndices == part.indices && updatedStateByShard == part.shards) {
+                // fast path for normal snapshots that avoid rebuilding the by-repo-id map if nothing changed about shard status
+                return new Entry(
+                    part.snapshot,
+                    part.includeGlobalState,
+                    part.partial,
+                    updatedState,
+                    updatedIndices,
+                    updatedDataStreams == null ? part.dataStreams : updatedDataStreams,
+                    part.featureStates,
+                    part.startTime,
+                    updatedRepositoryStateId,
+                    updatedStateByShard,
+                    updatedFailure,
+                    part.userMetadata,
+                    part.version,
+                    null,
+                    part.shardStatusByRepoShardId,
+                    part.snapshotIndices
+                );
+            }
+            if (part.isClone()) {
+                return Entry.createClone(
+                    part.snapshot,
+                    updatedState,
+                    updatedIndices,
+                    part.startTime,
+                    updatedRepositoryStateId,
+                    updatedFailure,
+                    part.version,
+                    part.source,
+                    shardsByRepoShardIdDiff.apply(part.shardStatusByRepoShardId)
+                );
+            }
+            return Entry.snapshot(
+                part.snapshot,
+                part.includeGlobalState,
+                part.partial,
+                updatedState,
+                updatedIndices,
+                updatedDataStreams == null ? part.dataStreams : updatedDataStreams,
+                part.featureStates,
+                part.startTime,
+                updatedRepositoryStateId,
+                updatedStateByShard,
+                updatedFailure,
+                part.userMetadata,
+                part.version
+            );
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            this.indexByIndexNameDiff.writeTo(out);
+            out.writeByte(this.updatedState.value());
+            out.writeLong(this.updatedRepositoryStateId);
+            out.writeOptionalStringCollection(updatedDataStreams);
+            out.writeOptionalString(updatedFailure);
+            shardsByShardIdDiff.writeTo(out);
+            out.writeOptionalWriteable(shardsByRepoShardIdDiff);
+        }
+    }
+
+    private static final class SnapshotInProgressDiff implements NamedDiff<Custom> {
+
+        private final SnapshotsInProgress after;
+
+        private final DiffableUtils.MapDiff<String, ByRepo, Map<String, ByRepo>> mapDiff;
+
+        SnapshotInProgressDiff(SnapshotsInProgress before, SnapshotsInProgress after) {
+            this.mapDiff = DiffableUtils.diff(before.entries, after.entries, DiffableUtils.getStringKeySerializer());
+            this.after = after;
+        }
+
+        SnapshotInProgressDiff(StreamInput in) throws IOException {
+            this.mapDiff = DiffableUtils.readJdkMapDiff(
+                in,
+                DiffableUtils.getStringKeySerializer(),
+                i -> new ByRepo(i.readList(Entry::readFrom)),
+                i -> new ByRepo.ByRepoDiff(
+                    DiffableUtils.readJdkMapDiff(i, DiffableUtils.getStringKeySerializer(), Entry::readFrom, EntryDiff::new),
+                    DiffableUtils.readJdkMapDiff(i, DiffableUtils.getStringKeySerializer(), ByRepo.INT_DIFF_VALUE_SERIALIZER)
+                )
+            );
+            this.after = null;
+        }
+
+        @Override
+        public SnapshotsInProgress apply(Custom part) {
+            return new SnapshotsInProgress(mapDiff.apply(((SnapshotsInProgress) part).entries));
+        }
+
+        @Override
+        public Version getMinimalSupportedVersion() {
+            return Version.CURRENT.minimumCompatibilityVersion();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return TYPE;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            assert after != null : "should only write instances that were diffed from this node's state";
+            if (out.getVersion().onOrAfter(DIFFABLE_VERSION)) {
+                mapDiff.writeTo(out);
+            } else {
+                new SimpleDiffable.CompleteDiff<>(after).writeTo(out);
+            }
+        }
+    }
+
+    /**
+     * Wrapper for the list of snapshots per repository to allow for diffing changes in individual entries as well as position changes
+     * of entries in the list.
+     *
+     * @param entries all snapshots executing for a single repository
+     */
+    private record ByRepo(List<Entry> entries) implements Diffable<ByRepo> {
+
+        static final ByRepo EMPTY = new ByRepo(List.of());
+        private static final DiffableUtils.NonDiffableValueSerializer<String, Integer> INT_DIFF_VALUE_SERIALIZER =
+            new DiffableUtils.NonDiffableValueSerializer<>() {
+                @Override
+                public void write(Integer value, StreamOutput out) throws IOException {
+                    out.writeVInt(value);
+                }
+
+                @Override
+                public Integer read(StreamInput in, String key) throws IOException {
+                    return in.readVInt();
+                }
+            };
+
+        private ByRepo(List<Entry> entries) {
+            this.entries = List.copyOf(entries);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeList(entries);
+        }
+
+        @Override
+        public Diff<ByRepo> diff(ByRepo previousState) {
+            return new ByRepoDiff(
+                DiffableUtils.diff(toMapByUUID(previousState), toMapByUUID(this), DiffableUtils.getStringKeySerializer()),
+                DiffableUtils.diff(
+                    toPositionMap(previousState),
+                    toPositionMap(this),
+                    DiffableUtils.getStringKeySerializer(),
+                    INT_DIFF_VALUE_SERIALIZER
+                )
+            );
+        }
+
+        public static Map<String, Integer> toPositionMap(ByRepo part) {
+            final Map<String, Integer> res = Maps.newMapWithExpectedSize(part.entries.size());
+            for (int i = 0; i < part.entries.size(); i++) {
+                final String snapshotUUID = part.entries.get(i).snapshot().getSnapshotId().getUUID();
+                assert res.containsKey(snapshotUUID) == false;
+                res.put(snapshotUUID, i);
+            }
+            return res;
+        }
+
+        public static Map<String, Entry> toMapByUUID(ByRepo part) {
+            final Map<String, Entry> res = Maps.newMapWithExpectedSize(part.entries.size());
+            for (Entry entry : part.entries) {
+                final String snapshotUUID = entry.snapshot().getSnapshotId().getUUID();
+                assert res.containsKey(snapshotUUID) == false;
+                res.put(snapshotUUID, entry);
+            }
+            return res;
+        }
+
+        /**
+         * @param diffBySnapshotUUID diff of a map of snapshot UUID to snapshot entry
+         * @param positionDiff diff of a map with snapshot UUID keys and positions in {@link ByRepo#entries} as values. Used to efficiently
+         *                     diff an entry moving to another index in the list
+         */
+        private record ByRepoDiff(
+            DiffableUtils.MapDiff<String, Entry, Map<String, Entry>> diffBySnapshotUUID,
+            DiffableUtils.MapDiff<String, Integer, Map<String, Integer>> positionDiff
+        ) implements Diff<ByRepo> {
+
+            @Override
+            public ByRepo apply(ByRepo part) {
+                final var updated = diffBySnapshotUUID.apply(toMapByUUID(part));
+                final var updatedPositions = positionDiff.apply(toPositionMap(part));
+                final Entry[] arr = new Entry[updated.size()];
+                updatedPositions.forEach((uuid, position) -> arr[position] = updated.get(uuid));
+                return new ByRepo(List.of(arr));
+            }
+
+            @Override
+            public void writeTo(StreamOutput out) throws IOException {
+                diffBySnapshotUUID.writeTo(out);
+                positionDiff.writeTo(out);
+            }
+        }
     }
 }

+ 64 - 51
server/src/test/java/org/elasticsearch/snapshots/SnapshotsInProgressSerializationTests.java

@@ -53,13 +53,13 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
         int numberOfSnapshots = randomInt(10);
         SnapshotsInProgress snapshotsInProgress = SnapshotsInProgress.EMPTY;
         for (int i = 0; i < numberOfSnapshots; i++) {
-            snapshotsInProgress.withAddedEntry(randomSnapshot());
+            snapshotsInProgress = snapshotsInProgress.withAddedEntry(randomSnapshot());
         }
         return snapshotsInProgress;
     }
 
     private Entry randomSnapshot() {
-        Snapshot snapshot = new Snapshot(randomAlphaOfLength(10), new SnapshotId(randomAlphaOfLength(10), randomAlphaOfLength(10)));
+        Snapshot snapshot = new Snapshot("repo-" + randomInt(5), new SnapshotId(randomAlphaOfLength(10), randomAlphaOfLength(10)));
         boolean includeGlobalState = randomBoolean();
         boolean partial = randomBoolean();
         int numberOfIndices = randomIntBetween(0, 10);
@@ -143,13 +143,16 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
         if (randomBoolean()) {
             // modify some elements
             for (List<Entry> perRepoEntries : updatedInstance.entriesByRepo()) {
-                final List<Entry> entries = new ArrayList<>(perRepoEntries);
+                List<Entry> entries = new ArrayList<>(perRepoEntries);
                 for (int i = 0; i < entries.size(); i++) {
                     if (randomBoolean()) {
                         final Entry entry = entries.get(i);
-                        entries.set(i, mutateEntry(entry));
+                        entries.set(i, mutateEntryWithLegalChange(entry));
                     }
                 }
+                if (randomBoolean()) {
+                    entries = shuffledList(entries);
+                }
                 updatedInstance = updatedInstance.withUpdatedEntriesForRepo(perRepoEntries.get(0).repository(), entries);
             }
         }
@@ -187,7 +190,7 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
     }
 
     private Entry mutateEntry(Entry entry) {
-        switch (randomInt(8)) {
+        switch (randomInt(5)) {
             case 0 -> {
                 boolean includeGlobalState = entry.includeGlobalState() == false;
                 return Entry.snapshot(
@@ -225,16 +228,16 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
                 );
             }
             case 2 -> {
-                List<String> dataStreams = Stream.concat(entry.dataStreams().stream(), Stream.of(randomAlphaOfLength(10))).toList();
+                long startTime = randomValueOtherThan(entry.startTime(), ESTestCase::randomLong);
                 return Entry.snapshot(
                     entry.snapshot(),
                     entry.includeGlobalState(),
                     entry.partial(),
                     entry.state(),
                     entry.indices(),
-                    dataStreams,
+                    entry.dataStreams(),
                     entry.featureStates(),
-                    entry.startTime(),
+                    startTime,
                     entry.repositoryStateId(),
                     entry.shards(),
                     entry.failure(),
@@ -243,7 +246,13 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
                 );
             }
             case 3 -> {
-                long startTime = randomValueOtherThan(entry.startTime(), ESTestCase::randomLong);
+                Map<String, Object> userMetadata = entry.userMetadata() != null ? new HashMap<>(entry.userMetadata()) : new HashMap<>();
+                String key = randomAlphaOfLengthBetween(2, 10);
+                if (userMetadata.containsKey(key)) {
+                    userMetadata.remove(key);
+                } else {
+                    userMetadata.put(key, randomAlphaOfLengthBetween(2, 10));
+                }
                 return Entry.snapshot(
                     entry.snapshot(),
                     entry.includeGlobalState(),
@@ -252,16 +261,20 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
                     entry.indices(),
                     entry.dataStreams(),
                     entry.featureStates(),
-                    startTime,
+                    entry.startTime(),
                     entry.repositoryStateId(),
                     entry.shards(),
                     entry.failure(),
-                    entry.userMetadata(),
+                    userMetadata,
                     entry.version()
                 );
             }
             case 4 -> {
-                long repositoryStateId = randomValueOtherThan(entry.startTime(), ESTestCase::randomLong);
+                List<SnapshotFeatureInfo> featureStates = randomList(
+                    1,
+                    5,
+                    () -> randomValueOtherThanMany(entry.featureStates()::contains, SnapshotFeatureInfoTests::randomSnapshotFeatureInfo)
+                );
                 return Entry.snapshot(
                     entry.snapshot(),
                     entry.includeGlobalState(),
@@ -269,9 +282,9 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
                     entry.state(),
                     entry.indices(),
                     entry.dataStreams(),
-                    entry.featureStates(),
+                    featureStates,
                     entry.startTime(),
-                    repositoryStateId,
+                    entry.repositoryStateId(),
                     entry.shards(),
                     entry.failure(),
                     entry.userMetadata(),
@@ -279,57 +292,53 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
                 );
             }
             case 5 -> {
-                String failure = randomValueOtherThan(entry.failure(), () -> randomAlphaOfLengthBetween(2, 10));
+                return mutateEntryWithLegalChange(entry);
+            }
+            default -> throw new IllegalArgumentException("invalid randomization case");
+        }
+    }
+
+    // mutates an entry with a change that could occur as part of a cluster state update and is thus diffable
+    private Entry mutateEntryWithLegalChange(Entry entry) {
+        switch (randomInt(3)) {
+            case 0 -> {
+                List<String> dataStreams = Stream.concat(entry.dataStreams().stream(), Stream.of(randomAlphaOfLength(10))).toList();
                 return Entry.snapshot(
                     entry.snapshot(),
                     entry.includeGlobalState(),
                     entry.partial(),
                     entry.state(),
                     entry.indices(),
-                    entry.dataStreams(),
+                    dataStreams,
                     entry.featureStates(),
                     entry.startTime(),
                     entry.repositoryStateId(),
                     entry.shards(),
-                    failure,
+                    entry.failure(),
                     entry.userMetadata(),
                     entry.version()
                 );
             }
-            case 6 -> {
-                Map<String, IndexId> indices = new HashMap<>(entry.indices());
-                IndexId indexId = new IndexId(randomAlphaOfLength(10), randomAlphaOfLength(10));
-                indices.put(indexId.getName(), indexId);
-                Map<ShardId, SnapshotsInProgress.ShardSnapshotStatus> shards = new HashMap<>(entry.shards());
-                Index index = new Index(indexId.getName(), randomAlphaOfLength(10));
-                int shardsCount = randomIntBetween(1, 10);
-                for (int j = 0; j < shardsCount; j++) {
-                    shards.put(new ShardId(index, j), randomShardSnapshotStatus(randomAlphaOfLength(10)));
-                }
+            case 1 -> {
+                long repositoryStateId = randomValueOtherThan(entry.repositoryStateId(), ESTestCase::randomLong);
                 return Entry.snapshot(
                     entry.snapshot(),
                     entry.includeGlobalState(),
                     entry.partial(),
-                    randomState(shards),
-                    indices,
+                    entry.state(),
+                    entry.indices(),
                     entry.dataStreams(),
                     entry.featureStates(),
                     entry.startTime(),
-                    entry.repositoryStateId(),
-                    shards,
+                    repositoryStateId,
+                    entry.shards(),
                     entry.failure(),
                     entry.userMetadata(),
                     entry.version()
                 );
             }
-            case 7 -> {
-                Map<String, Object> userMetadata = entry.userMetadata() != null ? new HashMap<>(entry.userMetadata()) : new HashMap<>();
-                String key = randomAlphaOfLengthBetween(2, 10);
-                if (userMetadata.containsKey(key)) {
-                    userMetadata.remove(key);
-                } else {
-                    userMetadata.put(key, randomAlphaOfLengthBetween(2, 10));
-                }
+            case 2 -> {
+                String failure = randomValueOtherThan(entry.failure(), () -> randomAlphaOfLengthBetween(2, 10));
                 return Entry.snapshot(
                     entry.snapshot(),
                     entry.includeGlobalState(),
@@ -341,28 +350,32 @@ public class SnapshotsInProgressSerializationTests extends SimpleDiffableWireSer
                     entry.startTime(),
                     entry.repositoryStateId(),
                     entry.shards(),
-                    entry.failure(),
-                    userMetadata,
+                    failure,
+                    entry.userMetadata(),
                     entry.version()
                 );
             }
-            case 8 -> {
-                List<SnapshotFeatureInfo> featureStates = randomList(
-                    1,
-                    5,
-                    () -> randomValueOtherThanMany(entry.featureStates()::contains, SnapshotFeatureInfoTests::randomSnapshotFeatureInfo)
-                );
+            case 3 -> {
+                Map<String, IndexId> indices = new HashMap<>(entry.indices());
+                IndexId indexId = new IndexId(randomAlphaOfLength(10), randomAlphaOfLength(10));
+                indices.put(indexId.getName(), indexId);
+                Map<ShardId, SnapshotsInProgress.ShardSnapshotStatus> shards = new HashMap<>(entry.shards());
+                Index index = new Index(indexId.getName(), randomAlphaOfLength(10));
+                int shardsCount = randomIntBetween(1, 10);
+                for (int j = 0; j < shardsCount; j++) {
+                    shards.put(new ShardId(index, j), randomShardSnapshotStatus(randomAlphaOfLength(10)));
+                }
                 return Entry.snapshot(
                     entry.snapshot(),
                     entry.includeGlobalState(),
                     entry.partial(),
-                    entry.state(),
-                    entry.indices(),
+                    randomState(shards),
+                    indices,
                     entry.dataStreams(),
-                    featureStates,
+                    entry.featureStates(),
                     entry.startTime(),
                     entry.repositoryStateId(),
-                    entry.shards(),
+                    shards,
                     entry.failure(),
                     entry.userMetadata(),
                     entry.version()