
Fix disk computation when initializing new shards (#102879)

Currently Elasticsearch checks whether there is enough disk space to initialize each shard individually. This makes it possible to initialize two shards on a node that has enough space for only one of them. This change takes into account all shards initialized within a given round of `BalancedShardsAllocator.Balancer#allocateUnassigned` to prevent this.
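
The essence of the change, as a minimal sketch with hypothetical names (the real logic lives in the disk threshold decider and works off `RoutingAllocation` state, not a standalone class): free space must be checked against the sum of everything already picked for a node in the current round, not against each shard in isolation.

    import java.util.HashMap;
    import java.util.Map;

    class RoundAwareDiskAccounting {
        // bytes already promised to shards picked earlier in the current allocation round, keyed by node id
        private final Map<String, Long> reservedThisRound = new HashMap<>();

        boolean canInitialize(String nodeId, long freeBytes, long shardSize, long watermarkBytes) {
            long reserved = reservedThisRound.getOrDefault(nodeId, 0L);
            // the node must stay above the watermark after accounting for all shards picked this round
            return freeBytes - reserved - shardSize >= watermarkBytes;
        }

        void reserve(String nodeId, long shardSize) {
            reservedThisRound.merge(nodeId, shardSize, Long::sum);
        }
    }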
Ievgen Degtiarenko · 1 year ago · commit 47dc5b67ce

+ 5 - 0
docs/changelog/102879.yaml

@@ -0,0 +1,5 @@
+pr: 102879
+summary: Fix disk computation when initializing new shards
+area: Allocation
+type: bug
+issues: []

+ 124 - 51
server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderIT.java

@@ -11,7 +11,6 @@ package org.elasticsearch.cluster.routing.allocation.decider;
 import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
 import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse;
 import org.elasticsearch.action.admin.indices.stats.ShardStats;
-import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.cluster.ClusterInfoService;
 import org.elasticsearch.cluster.ClusterInfoServiceUtils;
 import org.elasticsearch.cluster.DiskUsageIntegTestCase;
@@ -39,13 +38,16 @@ import org.hamcrest.Matcher;
 import org.hamcrest.TypeSafeMatcher;
 
 import java.util.Arrays;
+import java.util.Comparator;
 import java.util.HashSet;
-import java.util.Locale;
+import java.util.List;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 
-import static java.util.stream.Collectors.toMap;
 import static java.util.stream.Collectors.toSet;
+import static org.elasticsearch.cluster.routing.RoutingNodesHelper.numberOfShardsWithState;
+import static org.elasticsearch.cluster.routing.allocation.decider.EnableAllocationDecider.CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING;
 import static org.elasticsearch.index.store.Store.INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
 import static org.hamcrest.Matchers.empty;
@@ -74,26 +76,25 @@ public class DiskThresholdDeciderIT extends DiskUsageIntegTestCase {
         final String dataNodeName = internalCluster().startDataOnlyNode();
         ensureStableCluster(3);
 
-        final InternalClusterInfoService clusterInfoService = (InternalClusterInfoService) internalCluster().getCurrentMasterNodeInstance(
-            ClusterInfoService.class
-        );
-        internalCluster().getCurrentMasterNodeInstance(ClusterService.class)
-            .addListener(event -> ClusterInfoServiceUtils.refresh(clusterInfoService));
+        final InternalClusterInfoService clusterInfoService = getInternalClusterInfoService();
+        internalCluster().getCurrentMasterNodeInstance(ClusterService.class).addListener(event -> {
+            ClusterInfoServiceUtils.refresh(clusterInfoService);
+        });
 
         final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
 
-        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final String indexName = randomIdentifier();
         createIndex(indexName, indexSettings(6, 0).put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms").build());
-        var smallestShard = createReasonableSizedShards(indexName);
+        var shardSizes = createReasonableSizedShards(indexName);
 
         // reduce disk size of node 0 so that no shards fit below the high watermark, forcing all shards onto the other data node
         // (subtract the translog size since the disk threshold decider ignores this and may therefore move the shard back again)
-        getTestFileStore(dataNodeName).setTotalSpace(smallestShard.size + WATERMARK_BYTES - 1L);
+        getTestFileStore(dataNodeName).setTotalSpace(shardSizes.getSmallestShardSize() + WATERMARK_BYTES - 1L);
         assertBusyWithDiskUsageRefresh(dataNode0Id, indexName, empty());
 
         // increase disk size of node 0 to allow just enough room for one shard, and check that it's rebalanced back
-        getTestFileStore(dataNodeName).setTotalSpace(smallestShard.size + WATERMARK_BYTES);
-        assertBusyWithDiskUsageRefresh(dataNode0Id, indexName, new ContainsExactlyOneOf<>(smallestShard.shardIds));
+        getTestFileStore(dataNodeName).setTotalSpace(shardSizes.getSmallestShardSize() + WATERMARK_BYTES);
+        assertBusyWithDiskUsageRefresh(dataNode0Id, indexName, new ContainsExactlyOneOf<>(shardSizes.getSmallestShardIds()));
     }
 
     public void testRestoreSnapshotAllocationDoesNotExceedWatermark() throws Exception {
@@ -108,17 +109,20 @@ public class DiskThresholdDeciderIT extends DiskUsageIntegTestCase {
                 .setSettings(Settings.builder().put("location", randomRepoPath()).put("compress", randomBoolean()))
         );
 
-        final InternalClusterInfoService clusterInfoService = (InternalClusterInfoService) internalCluster().getCurrentMasterNodeInstance(
-            ClusterInfoService.class
-        );
-        internalCluster().getCurrentMasterNodeInstance(ClusterService.class)
-            .addListener(event -> ClusterInfoServiceUtils.refresh(clusterInfoService));
+        final AtomicBoolean allowRelocations = new AtomicBoolean(true);
+        final InternalClusterInfoService clusterInfoService = getInternalClusterInfoService();
+        internalCluster().getCurrentMasterNodeInstance(ClusterService.class).addListener(event -> {
+            ClusterInfoServiceUtils.refresh(clusterInfoService);
+            if (allowRelocations.get() == false) {
+                assertThat(numberOfShardsWithState(event.state().getRoutingNodes(), ShardRoutingState.RELOCATING), equalTo(0));
+            }
+        });
 
         final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
 
-        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final String indexName = randomIdentifier();
         createIndex(indexName, indexSettings(6, 0).put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms").build());
-        var smallestShard = createReasonableSizedShards(indexName);
+        var shardSizes = createReasonableSizedShards(indexName);
 
         final CreateSnapshotResponse createSnapshotResponse = clusterAdmin().prepareCreateSnapshot("repo", "snap")
             .setWaitForCompletion(true)
@@ -128,15 +132,13 @@ public class DiskThresholdDeciderIT extends DiskUsageIntegTestCase {
         assertThat(snapshotInfo.state(), is(SnapshotState.SUCCESS));
 
         assertAcked(indicesAdmin().prepareDelete(indexName).get());
+        updateClusterSettings(Settings.builder().put(CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey(), Rebalance.NONE.toString()));
+        allowRelocations.set(false);
 
         // reduce disk size of node 0 so that no shards fit below the low watermark, forcing shards to be assigned to the other data node
-        getTestFileStore(dataNodeName).setTotalSpace(smallestShard.size + WATERMARK_BYTES - 1L);
+        getTestFileStore(dataNodeName).setTotalSpace(shardSizes.getSmallestShardSize() + WATERMARK_BYTES - 1L);
         refreshDiskUsage();
 
-        updateClusterSettings(
-            Settings.builder().put(EnableAllocationDecider.CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey(), Rebalance.NONE.toString())
-        );
-
         final RestoreSnapshotResponse restoreSnapshotResponse = clusterAdmin().prepareRestoreSnapshot("repo", "snap")
             .setWaitForCompletion(true)
             .get();
@@ -144,13 +146,71 @@ public class DiskThresholdDeciderIT extends DiskUsageIntegTestCase {
         assertThat(restoreInfo.successfulShards(), is(snapshotInfo.totalShards()));
         assertThat(restoreInfo.failedShards(), is(0));
 
-        assertBusy(() -> assertThat(getShardIds(dataNode0Id, indexName), empty()));
+        assertThat(getShardIds(dataNode0Id, indexName), empty());
 
-        updateClusterSettings(Settings.builder().putNull(EnableAllocationDecider.CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey()));
+        allowRelocations.set(true);
+        updateClusterSettings(Settings.builder().putNull(CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey()));
 
         // increase disk size of node 0 to allow just enough room for one shard, and check that it's rebalanced back
-        getTestFileStore(dataNodeName).setTotalSpace(smallestShard.size + WATERMARK_BYTES);
-        assertBusyWithDiskUsageRefresh(dataNode0Id, indexName, new ContainsExactlyOneOf<>(smallestShard.shardIds));
+        getTestFileStore(dataNodeName).setTotalSpace(shardSizes.getSmallestShardSize() + WATERMARK_BYTES);
+        assertBusyWithDiskUsageRefresh(dataNode0Id, indexName, new ContainsExactlyOneOf<>(shardSizes.getSmallestShardIds()));
+    }
+
+    public void testRestoreSnapshotAllocationDoesNotExceedWatermarkWithMultipleShards() throws Exception {
+        internalCluster().startMasterOnlyNode();
+        internalCluster().startDataOnlyNode();
+        final String dataNodeName = internalCluster().startDataOnlyNode();
+        ensureStableCluster(3);
+
+        assertAcked(
+            clusterAdmin().preparePutRepository("repo")
+                .setType(FsRepository.TYPE)
+                .setSettings(Settings.builder().put("location", randomRepoPath()).put("compress", randomBoolean()))
+        );
+
+        final AtomicBoolean allowRelocations = new AtomicBoolean(true);
+        final InternalClusterInfoService clusterInfoService = getInternalClusterInfoService();
+        internalCluster().getCurrentMasterNodeInstance(ClusterService.class).addListener(event -> {
+            ClusterInfoServiceUtils.refresh(clusterInfoService);
+            if (allowRelocations.get() == false) {
+                assertThat(numberOfShardsWithState(event.state().getRoutingNodes(), ShardRoutingState.RELOCATING), equalTo(0));
+            }
+        });
+
+        final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
+
+        final String indexName = randomIdentifier();
+        createIndex(indexName, indexSettings(6, 0).put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms").build());
+        var shardSizes = createReasonableSizedShards(indexName);
+
+        final CreateSnapshotResponse createSnapshotResponse = clusterAdmin().prepareCreateSnapshot("repo", "snap")
+            .setWaitForCompletion(true)
+            .get();
+        final SnapshotInfo snapshotInfo = createSnapshotResponse.getSnapshotInfo();
+        assertThat(snapshotInfo.successfulShards(), is(snapshotInfo.totalShards()));
+        assertThat(snapshotInfo.state(), is(SnapshotState.SUCCESS));
+
+        assertAcked(indicesAdmin().prepareDelete(indexName).get());
+        updateClusterSettings(Settings.builder().put(CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey(), Rebalance.NONE.toString()));
+        allowRelocations.set(false);
+
+        // reduce disk size of node 0 so that only one of the two smallest shards can be allocated
+        var usableSpace = shardSizes.sizes().get(1).size();
+        getTestFileStore(dataNodeName).setTotalSpace(usableSpace + WATERMARK_BYTES + 1L);
+        refreshDiskUsage();
+
+        final RestoreSnapshotResponse restoreSnapshotResponse = clusterAdmin().prepareRestoreSnapshot("repo", "snap")
+            .setWaitForCompletion(true)
+            .get();
+        final RestoreInfo restoreInfo = restoreSnapshotResponse.getRestoreInfo();
+        assertThat(restoreInfo.successfulShards(), is(snapshotInfo.totalShards()));
+        assertThat(restoreInfo.failedShards(), is(0));
+
+        assertBusyWithDiskUsageRefresh(
+            dataNode0Id,
+            indexName,
+            new ContainsExactlyOneOf<>(shardSizes.getShardIdsWithSizeSmallerOrEqual(usableSpace))
+        );
     }
 
     private Set<ShardId> getShardIds(final String nodeId, final String indexName) {
@@ -178,13 +238,9 @@ public class DiskThresholdDeciderIT extends DiskUsageIntegTestCase {
     /**
      * Index documents until all the shards are at least WATERMARK_BYTES in size, and return their sizes sorted ascending
      */
-    private SmallestShards createReasonableSizedShards(final String indexName) throws InterruptedException {
+    private ShardSizes createReasonableSizedShards(final String indexName) throws InterruptedException {
         while (true) {
-            final IndexRequestBuilder[] indexRequestBuilders = new IndexRequestBuilder[scaledRandomIntBetween(100, 10000)];
-            for (int i = 0; i < indexRequestBuilders.length; i++) {
-                indexRequestBuilders[i] = prepareIndex(indexName).setSource("field", randomAlphaOfLength(10));
-            }
-            indexRandom(true, indexRequestBuilders);
+            indexRandom(true, indexName, scaledRandomIntBetween(100, 10000));
             forceMerge();
             refresh();
 
@@ -201,23 +257,36 @@ public class DiskThresholdDeciderIT extends DiskUsageIntegTestCase {
                 .orElseThrow(() -> new AssertionError("no shards"));
 
             if (smallestShardSize > WATERMARK_BYTES) {
-                var smallestShardIds = Arrays.stream(shardStates)
-                    .filter(it -> it.getStats().getStore().sizeInBytes() == smallestShardSize)
-                    .map(it -> removeIndexUUID(it.getShardRouting().shardId()))
-                    .collect(toSet());
-
-                logger.info(
-                    "Created shards with sizes {}",
-                    Arrays.stream(shardStates)
-                        .collect(toMap(it -> it.getShardRouting().shardId(), it -> it.getStats().getStore().sizeInBytes()))
-                );
-
-                return new SmallestShards(smallestShardSize, smallestShardIds);
+                var shardSizes = Arrays.stream(shardStates)
+                    .map(it -> new ShardSize(removeIndexUUID(it.getShardRouting().shardId()), it.getStats().getStore().sizeInBytes()))
+                    .sorted(Comparator.comparing(ShardSize::size))
+                    .toList();
+                logger.info("Created shards with sizes {}", shardSizes);
+                return new ShardSizes(shardSizes);
             }
         }
     }
 
-    private record SmallestShards(long size, Set<ShardId> shardIds) {}
+    private record ShardSizes(List<ShardSize> sizes) {
+
+        public long getSmallestShardSize() {
+            return sizes.get(0).size();
+        }
+
+        public Set<ShardId> getShardIdsWithSizeSmallerOrEqual(long size) {
+            return sizes.stream().filter(entry -> entry.size <= size).map(ShardSize::shardId).collect(toSet());
+        }
+
+        public Set<ShardId> getSmallestShardIds() {
+            return getShardIdsWithSizeSmallerOrEqual(getSmallestShardSize());
+        }
+
+        public Set<ShardId> getAllShardIds() {
+            return sizes.stream().map(ShardSize::shardId).collect(toSet());
+        }
+    }
+
+    private record ShardSize(ShardId shardId, long size) {}
 
     private static ShardId removeIndexUUID(ShardId shardId) {
         return ShardId.fromString(shardId.toString());
@@ -246,16 +315,20 @@ public class DiskThresholdDeciderIT extends DiskUsageIntegTestCase {
         );
     }
 
-    private void assertBusyWithDiskUsageRefresh(String nodeName, String indexName, Matcher<? super Set<ShardId>> matcher) throws Exception {
+    private void assertBusyWithDiskUsageRefresh(String nodeId, String indexName, Matcher<? super Set<ShardId>> matcher) throws Exception {
         assertBusy(() -> {
             // refresh the master's ClusterInfoService before checking the assigned shards because DiskThresholdMonitor might still
             // be processing a previous ClusterInfo update and will skip the new one (see DiskThresholdMonitor#onNewInfo(ClusterInfo)
             // and its internal checkInProgress flag)
             refreshDiskUsage();
 
-            final Set<ShardId> shardRoutings = getShardIds(nodeName, indexName);
+            final Set<ShardId> shardRoutings = getShardIds(nodeId, indexName);
             assertThat("Mismatching shard routings: " + shardRoutings, shardRoutings, matcher);
-        }, 30L, TimeUnit.SECONDS);
+        }, 5L, TimeUnit.SECONDS);
+    }
+
+    private InternalClusterInfoService getInternalClusterInfoService() {
+        return (InternalClusterInfoService) internalCluster().getCurrentMasterNodeInstance(ClusterInfoService.class);
     }
 
     private static final class ContainsExactlyOneOf<T> extends TypeSafeMatcher<Set<T>> {

+ 27 - 2
server/src/main/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimator.java

@@ -20,9 +20,13 @@ import java.util.Set;
 
 public class ExpectedShardSizeEstimator {
 
-    public static long getExpectedShardSize(ShardRouting shardRouting, long defaultSize, RoutingAllocation allocation) {
+    public static boolean shouldReserveSpaceForInitializingShard(ShardRouting shard, RoutingAllocation allocation) {
+        return shouldReserveSpaceForInitializingShard(shard, allocation.metadata());
+    }
+
+    public static long getExpectedShardSize(ShardRouting shard, long defaultSize, RoutingAllocation allocation) {
         return getExpectedShardSize(
-            shardRouting,
+            shard,
             defaultSize,
             allocation.clusterInfo(),
             allocation.snapshotShardSizeInfo(),
@@ -31,6 +35,27 @@ public class ExpectedShardSizeEstimator {
         );
     }
 
+    public static boolean shouldReserveSpaceForInitializingShard(ShardRouting shard, Metadata metadata) {
+        assert shard.initializing() : "Expected initializing shard, got: " + shard;
+        return switch (shard.recoverySource().getType()) {
+            // No need to reserve disk space when initializing a new empty shard
+            case EMPTY_STORE -> false;
+
+            // No need to reserve disk space if the shard is already allocated on the disk. Starting it is not going to use more.
+            case EXISTING_STORE -> false;
+
+            // Peer recovery requires downloading all segments locally to start the shard. Reserve disk space for this.
+            case PEER -> true;
+
+            // Snapshot restore (unless it is partial) requires downloading all segments locally from the blobstore to start the shard.
+            case SNAPSHOT -> metadata.getIndexSafe(shard.index()).isPartialSearchableSnapshot() == false;
+
+            // a shrink/split/clone operation clones existing locally placed shards using file system hard links,
+            // so no additional space is used until future merges
+            case LOCAL_SHARDS -> false;
+        };
+    }
+
     /**
      * Returns the expected shard size for the given shard, or the provided default value if not enough information is available
      * to estimate the shard's size.
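
A sketch of how callers are expected to combine the new helper with the existing estimate (this mirrors the `DiskThresholdDecider` change below; `node`, `allocation`, and `totalSize` are assumed to be in scope):

    for (ShardRouting routing : node.initializing()) {
        // only recoveries that actually download data (peer recovery, non-partial snapshot restore) reserve space
        if (ExpectedShardSizeEstimator.shouldReserveSpaceForInitializingShard(routing, allocation.metadata())) {
            totalSize += ExpectedShardSizeEstimator.getExpectedShardSize(routing, 0L, allocation);
        }
    }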

+ 2 - 0
server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java

@@ -34,6 +34,7 @@ import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.gateway.GatewayService;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -425,6 +426,7 @@ public class DiskThresholdMonitor {
             true,
             diskUsage.getPath(),
             info,
+            SnapshotShardSizeInfo.EMPTY,
             reroutedClusterState.metadata(),
             reroutedClusterState.routingTable(),
             0L

+ 18 - 23
server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java

@@ -32,6 +32,7 @@ import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import java.util.Map;
 
 import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.shouldReserveSpaceForInitializingShard;
 
 /**
  * The {@link DiskThresholdDecider} checks that the node a shard is potentially
@@ -117,6 +118,7 @@ public class DiskThresholdDecider extends AllocationDecider {
         boolean subtractShardsMovingAway,
         String dataPath,
         ClusterInfo clusterInfo,
+        SnapshotShardSizeInfo snapshotShardSizeInfo,
         Metadata metadata,
         RoutingTable routingTable,
         long sizeOfUnaccountableSearchableSnapshotShards
@@ -129,28 +131,18 @@ public class DiskThresholdDecider extends AllocationDecider {
 
         // Where reserved space is unavailable (e.g. stats are out-of-sync) compute a conservative estimate for initialising shards
         for (ShardRouting routing : node.initializing()) {
-            if (routing.relocatingNodeId() == null && metadata.getIndexSafe(routing.index()).isSearchableSnapshot() == false) {
-                // in practice the only initializing-but-not-relocating non-searchable-snapshot shards with a nonzero expected shard size
-                // will be ones created
-                // by a resize (shrink/split/clone) operation which we expect to happen using hard links, so they shouldn't be taking
-                // any additional space and can be ignored here
-                continue;
-            }
-            if (reservedSpace.containsShardId(routing.shardId())) {
-                continue;
-            }
-            final String actualPath = clusterInfo.getDataPath(routing);
-            // if we don't yet know the actual path of the incoming shard then conservatively assume it's going to the path with the least
-            // free space
-            if (actualPath == null || actualPath.equals(dataPath)) {
-                totalSize += getExpectedShardSize(
-                    routing,
-                    Math.max(routing.getExpectedShardSize(), 0L),
-                    clusterInfo,
-                    SnapshotShardSizeInfo.EMPTY,
-                    metadata,
-                    routingTable
-                );
+            // Space needs to be reserved only when initializing shards that are going to use additional space
+            // that is not yet accounted for by `reservedSpace` in case of lengthy recoveries
+            if (shouldReserveSpaceForInitializingShard(routing, metadata) && reservedSpace.containsShardId(routing.shardId()) == false) {
+                final String actualPath = clusterInfo.getDataPath(routing);
+                // if we don't yet know the actual path of the incoming shard then conservatively assume
+                // it's going to the path with the least free space
+                if (actualPath == null || actualPath.equals(dataPath)) {
+                    totalSize += Math.max(
+                        routing.getExpectedShardSize(),
+                        getExpectedShardSize(routing, 0L, clusterInfo, snapshotShardSizeInfo, metadata, routingTable)
+                    );
+                }
             }
         }
 
@@ -159,7 +151,7 @@ public class DiskThresholdDecider extends AllocationDecider {
         if (subtractShardsMovingAway) {
             for (ShardRouting routing : node.relocating()) {
                 if (dataPath.equals(clusterInfo.getDataPath(routing))) {
-                    totalSize -= getExpectedShardSize(routing, 0L, clusterInfo, SnapshotShardSizeInfo.EMPTY, metadata, routingTable);
+                    totalSize -= getExpectedShardSize(routing, 0L, clusterInfo, snapshotShardSizeInfo, metadata, routingTable);
                 }
             }
         }
@@ -204,6 +196,7 @@ public class DiskThresholdDecider extends AllocationDecider {
                 false,
                 usage.getPath(),
                 allocation.clusterInfo(),
+                allocation.snapshotShardSizeInfo(),
                 allocation.metadata(),
                 allocation.routingTable(),
                 allocation.unaccountedSearchableSnapshotSize(node)
@@ -412,6 +405,7 @@ public class DiskThresholdDecider extends AllocationDecider {
                 true,
                 usage.getPath(),
                 allocation.clusterInfo(),
+                allocation.snapshotShardSizeInfo(),
                 allocation.metadata(),
                 allocation.routingTable(),
                 allocation.unaccountedSearchableSnapshotSize(node)
@@ -491,6 +485,7 @@ public class DiskThresholdDecider extends AllocationDecider {
                 subtractLeavingShards,
                 usage.getPath(),
                 allocation.clusterInfo(),
+                allocation.snapshotShardSizeInfo(),
                 allocation.metadata(),
                 allocation.routingTable(),
                 allocation.unaccountedSearchableSnapshotSize(node)
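
The new `snapshotShardSizeInfo` parameter matters because a shard that is still initializing from a snapshot usually has no `ClusterInfo` entry yet, so the snapshot metadata is the only available size source. A sketch of the difference for such a shard (variables assumed to be in scope):

    // before: snapshot sizes were invisible here, so the estimate fell back to the 0L default
    long before = getExpectedShardSize(routing, 0L, clusterInfo, SnapshotShardSizeInfo.EMPTY, metadata, routingTable);
    // after: the size recorded for the shard in the snapshot is used
    long after = getExpectedShardSize(routing, 0L, clusterInfo, snapshotShardSizeInfo, metadata, routingTable);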

+ 57 - 6
server/src/test/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimatorTests.java

@@ -30,33 +30,54 @@ import java.util.Map;
 import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_RESIZE_SOURCE_NAME_KEY;
 import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_RESIZE_SOURCE_UUID_KEY;
 import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.shouldReserveSpaceForInitializingShard;
 import static org.elasticsearch.cluster.routing.TestShardRouting.newShardRouting;
+import static org.elasticsearch.index.IndexModule.INDEX_STORE_TYPE_SETTING;
+import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SEARCHABLE_SNAPSHOT_STORE_TYPE;
+import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SNAPSHOT_PARTIAL_SETTING;
 import static org.hamcrest.Matchers.equalTo;
 
 public class ExpectedShardSizeEstimatorTests extends ESAllocationTestCase {
 
     private final long defaultValue = randomLongBetween(-1, 0);
 
-    public void testShouldFallbackToDefaultValue() {
+    public void testShouldFallbackToDefaultExpectedShardSize() {
 
         var state = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata(index("my-index"))).build();
-        var shard = newShardRouting("my-index", 0, randomIdentifier(), true, ShardRoutingState.INITIALIZING);
+        var shard = newShardRouting(
+            new ShardId("my-index", "_na_", 0),
+            randomIdentifier(),
+            true,
+            ShardRoutingState.INITIALIZING,
+            randomFrom(RecoverySource.EmptyStoreRecoverySource.INSTANCE, RecoverySource.ExistingStoreRecoverySource.INSTANCE)
+        );
 
         var allocation = createRoutingAllocation(state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY);
 
         assertThat(getExpectedShardSize(shard, defaultValue, allocation), equalTo(defaultValue));
+        assertFalse(
+            "Should NOT reserve space for locally initializing primaries",
+            shouldReserveSpaceForInitializingShard(shard, allocation)
+        );
     }
 
     public void testShouldReadExpectedSizeFromClusterInfo() {
 
         var shardSize = randomLongBetween(100, 1000);
         var state = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata(index("my-index"))).build();
-        var shard = newShardRouting("my-index", 0, randomIdentifier(), true, ShardRoutingState.INITIALIZING);
+        var shard = newShardRouting(
+            new ShardId("my-index", "_na_", 0),
+            randomIdentifier(),
+            true,
+            ShardRoutingState.INITIALIZING,
+            RecoverySource.PeerRecoverySource.INSTANCE
+        );
 
         var clusterInfo = createClusterInfo(shard, shardSize);
         var allocation = createRoutingAllocation(state, clusterInfo, SnapshotShardSizeInfo.EMPTY);
 
         assertThat(getExpectedShardSize(shard, defaultValue, allocation), equalTo(shardSize));
+        assertTrue("Should reserve space for relocating shard", shouldReserveSpaceForInitializingShard(shard, allocation));
     }
 
     public void testShouldReadExpectedSizeFromPrimaryWhenAddingNewReplica() {
@@ -70,21 +91,39 @@ public class ExpectedShardSizeEstimatorTests extends ESAllocationTestCase {
         var allocation = createRoutingAllocation(state, clusterInfo, SnapshotShardSizeInfo.EMPTY);
 
         assertThat(getExpectedShardSize(replica, defaultValue, allocation), equalTo(shardSize));
+        assertTrue("Should reserve space for peer recovery", shouldReserveSpaceForInitializingShard(replica, allocation));
     }
 
     public void testShouldReadExpectedSizeWhenInitializingFromSnapshot() {
 
         var snapshotShardSize = randomLongBetween(100, 1000);
-        var state = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata(index("my-index"))).build();
+
+        var index = switch (randomIntBetween(0, 2)) {
+            // regular snapshot
+            case 0 -> index("my-index");
+            // searchable snapshot
+            case 1 -> index("my-index").settings(
+                indexSettings(IndexVersion.current(), 1, 0) //
+                    .put(INDEX_STORE_TYPE_SETTING.getKey(), SEARCHABLE_SNAPSHOT_STORE_TYPE) //
+            );
+            // partial searchable snapshot
+            case 2 -> index("my-index").settings(
+                indexSettings(IndexVersion.current(), 1, 0) //
+                    .put(INDEX_STORE_TYPE_SETTING.getKey(), SEARCHABLE_SNAPSHOT_STORE_TYPE) //
+                    .put(SNAPSHOT_PARTIAL_SETTING.getKey(), true) //
+            );
+            default -> throw new AssertionError("unexpected index type");
+        };
+        var state = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata(index)).build();
 
         var snapshot = new Snapshot("repository", new SnapshotId("snapshot-1", "na"));
         var indexId = new IndexId("my-index", "_na_");
 
         var shard = newShardRouting(
             new ShardId("my-index", "_na_", 0),
-            null,
+            randomIdentifier(),
             true,
-            ShardRoutingState.UNASSIGNED,
+            ShardRoutingState.INITIALIZING,
             new RecoverySource.SnapshotRecoverySource(randomUUID(), snapshot, IndexVersion.current(), indexId)
         );
 
@@ -94,6 +133,14 @@ public class ExpectedShardSizeEstimatorTests extends ESAllocationTestCase {
         var allocation = createRoutingAllocation(state, ClusterInfo.EMPTY, snapshotShardSizeInfo);
 
         assertThat(getExpectedShardSize(shard, defaultValue, allocation), equalTo(snapshotShardSize));
+        if (state.metadata().index("my-index").isPartialSearchableSnapshot() == false) {
+            assertTrue("Should reserve space for snapshot restore", shouldReserveSpaceForInitializingShard(shard, allocation));
+        } else {
+            assertFalse(
+                "Should NOT reserve space for partial searchable snapshot restore as they do not download all data during initialization",
+                shouldReserveSpaceForInitializingShard(shard, allocation)
+            );
+        }
     }
 
     public void testShouldReadSizeFromClonedShard() {
@@ -127,6 +174,10 @@ public class ExpectedShardSizeEstimatorTests extends ESAllocationTestCase {
         var allocation = createRoutingAllocation(state, clusterInfo, SnapshotShardSizeInfo.EMPTY);
 
         assertThat(getExpectedShardSize(target, defaultValue, allocation), equalTo(sourceShardSize));
+        assertFalse(
+            "Should NOT reserve space when using fs hardlink for clone/shrink/split",
+            shouldReserveSpaceForInitializingShard(target, state.metadata())
+        );
     }
 
     private static RoutingAllocation createRoutingAllocation(

+ 151 - 3
server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceComputerTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.DiskUsage;
 import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.cluster.RestoreInProgress;
 import org.elasticsearch.cluster.TestShardRoutingRoleStrategies;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
@@ -40,9 +41,14 @@ import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService.SnapshotShard;
+import org.elasticsearch.snapshots.Snapshot;
+import org.elasticsearch.snapshots.SnapshotId;
 import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.elasticsearch.test.MockLogAppender;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -61,6 +67,7 @@ import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 
 import static java.util.stream.Collectors.toMap;
+import static org.elasticsearch.cluster.ClusterInfo.shardIdentifierFromRouting;
 import static org.elasticsearch.cluster.routing.ShardRoutingState.INITIALIZING;
 import static org.elasticsearch.cluster.routing.ShardRoutingState.RELOCATING;
 import static org.elasticsearch.cluster.routing.ShardRoutingState.STARTED;
@@ -70,6 +77,7 @@ import static org.elasticsearch.common.settings.ClusterSettings.createBuiltInClu
 import static org.elasticsearch.test.MockLogAppender.assertThatLogger;
 import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.everyItem;
 import static org.hamcrest.Matchers.hasEntry;
@@ -623,7 +631,7 @@ public class DesiredBalanceComputerTests extends ESAllocationTestCase {
                 var thisShardSize = smallShardSizeDeviation(shardSize);
 
                 var primaryNodeId = pickAndRemoveRandomValueFrom(remainingNodeIds);
-                shardSizes.put(ClusterInfo.shardIdentifierFromRouting(shardId, true), thisShardSize);
+                shardSizes.put(shardIdentifierFromRouting(shardId, true), thisShardSize);
                 totalShardsSize += thisShardSize;
                 if (primaryNodeId != null) {
                     dataPath.put(new NodeAndShard(primaryNodeId, shardId), "/data");
@@ -642,7 +650,7 @@ public class DesiredBalanceComputerTests extends ESAllocationTestCase {
                 );
                 for (int replica = 0; replica < replicas; replica++) {
                     var replicaNodeId = primaryNodeId == null ? null : pickAndRemoveRandomValueFrom(remainingNodeIds);
-                    shardSizes.put(ClusterInfo.shardIdentifierFromRouting(shardId, false), thisShardSize);
+                    shardSizes.put(shardIdentifierFromRouting(shardId, false), thisShardSize);
                     totalShardsSize += thisShardSize;
                     if (replicaNodeId != null) {
                         dataPath.put(new NodeAndShard(replicaNodeId, shardId), "/data");
@@ -862,6 +870,146 @@ public class DesiredBalanceComputerTests extends ESAllocationTestCase {
         return new ClusterInfo(diskUsage, diskUsage, shardSizes, Map.of(), Map.of(), Map.of());
     }
 
+    public void testAccountForSizeOfAllInitializingShardsDuringAllocation() {
+
+        var snapshot = new Snapshot("repository", new SnapshotId("snapshot", randomUUID()));
+
+        var shardSizeInfo = Maps.<String, Long>newHashMapWithExpectedSize(5);
+        var snapshotShardSizes = Maps.<SnapshotShard, Long>newHashMapWithExpectedSize(5);
+
+        var routingTableBuilder = RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY);
+        // index-1 is allocated according to the desired balance
+        var indexMetadata1 = IndexMetadata.builder("index-1").settings(indexSettings(IndexVersion.current(), 2, 0)).build();
+        routingTableBuilder.add(
+            IndexRoutingTable.builder(indexMetadata1.getIndex())
+                .addShard(newShardRouting(shardIdFrom(indexMetadata1, 0), "node-1", true, STARTED))
+                .addShard(newShardRouting(shardIdFrom(indexMetadata1, 1), "node-2", true, STARTED))
+        );
+        shardSizeInfo.put(shardIdentifierFromRouting(shardIdFrom(indexMetadata1, 0), true), ByteSizeValue.ofGb(8).getBytes());
+        shardSizeInfo.put(shardIdentifierFromRouting(shardIdFrom(indexMetadata1, 1), true), ByteSizeValue.ofGb(8).getBytes());
+
+        // index-2 & index-3 are restored as new from snapshot
+        var indexMetadata2 = IndexMetadata.builder("index-2")
+            .settings(indexSettings(IndexVersion.current(), 1, 0).put(IndexMetadata.INDEX_PRIORITY_SETTING.getKey(), 2))
+            .build();
+        routingTableBuilder.addAsNewRestore(
+            indexMetadata2,
+            new RecoverySource.SnapshotRecoverySource("restore", snapshot, IndexVersion.current(), indexIdFrom(indexMetadata2)),
+            Set.of()
+        );
+        snapshotShardSizes.put(
+            new SnapshotShard(snapshot, indexIdFrom(indexMetadata2), shardIdFrom(indexMetadata2, 0)),
+            ByteSizeValue.ofGb(1).getBytes()
+        );
+
+        var indexMetadata3 = IndexMetadata.builder("index-3")
+            .settings(indexSettings(IndexVersion.current(), 2, 0).put(IndexMetadata.INDEX_PRIORITY_SETTING.getKey(), 1))
+            .build();
+        routingTableBuilder.addAsNewRestore(
+            indexMetadata3,
+            new RecoverySource.SnapshotRecoverySource("restore", snapshot, IndexVersion.current(), indexIdFrom(indexMetadata3)),
+            Set.of()
+        );
+        snapshotShardSizes.put(
+            new SnapshotShard(snapshot, indexIdFrom(indexMetadata3), shardIdFrom(indexMetadata3, 0)),
+            ByteSizeValue.ofMb(512).getBytes()
+        );
+        snapshotShardSizes.put(
+            new SnapshotShard(snapshot, indexIdFrom(indexMetadata3), shardIdFrom(indexMetadata3, 1)),
+            ByteSizeValue.ofMb(512).getBytes()
+        );
+
+        var clusterState = ClusterState.builder(ClusterName.DEFAULT)
+            .nodes(DiscoveryNodes.builder().add(newNode("node-1")).add(newNode("node-2")))
+            .metadata(Metadata.builder().put(indexMetadata1, false).put(indexMetadata2, false).put(indexMetadata3, false).build())
+            .routingTable(routingTableBuilder)
+            .customs(
+                Map.of(
+                    RestoreInProgress.TYPE,
+                    new RestoreInProgress.Builder().add(
+                        new RestoreInProgress.Entry(
+                            "restore",
+                            snapshot,
+                            RestoreInProgress.State.STARTED,
+                            randomBoolean(),
+                            List.of(indexMetadata2.getIndex().getName(), indexMetadata3.getIndex().getName()),
+                            Map.ofEntries(
+                                Map.entry(shardIdFrom(indexMetadata2, 0), new RestoreInProgress.ShardRestoreStatus(randomUUID())),
+                                Map.entry(shardIdFrom(indexMetadata3, 0), new RestoreInProgress.ShardRestoreStatus(randomUUID())),
+                                Map.entry(shardIdFrom(indexMetadata3, 1), new RestoreInProgress.ShardRestoreStatus(randomUUID()))
+                            )
+                        )
+                    ).build()
+                )
+            )
+            .build();
+
+        var clusterInfo = createClusterInfo(
+            List.of(
+                // node-1 has only enough space to allocate the single [index-2] shard
+                new DiskUsage("node-1", "data-1", "/data", ByteSizeValue.ofGb(10).getBytes(), ByteSizeValue.ofGb(2).getBytes()),
+                // node-2 has only enough space to allocate both shards of [index-3]
+                new DiskUsage("node-2", "data-2", "/data", ByteSizeValue.ofGb(10).getBytes(), ByteSizeValue.ofGb(2).getBytes())
+            ),
+            shardSizeInfo
+        );
+        var snapshotShardSizeInfo = new SnapshotShardSizeInfo(snapshotShardSizes);
+
+        var settings = Settings.EMPTY;
+        var allocation = new RoutingAllocation(
+            randomAllocationDeciders(settings, createBuiltInClusterSettings(settings)),
+            clusterState,
+            clusterInfo,
+            snapshotShardSizeInfo,
+            0L
+        );
+        var initialDesiredBalance = new DesiredBalance(
+            1,
+            Map.ofEntries(
+                Map.entry(shardIdFrom(indexMetadata1, 0), new ShardAssignment(Set.of("node-1"), 1, 0, 0)),
+                Map.entry(shardIdFrom(indexMetadata1, 1), new ShardAssignment(Set.of("node-2"), 1, 0, 0))
+            )
+        );
+        var nextDesiredBalance = createDesiredBalanceComputer(new BalancedShardsAllocator()).compute(
+            initialDesiredBalance,
+            new DesiredBalanceInput(2, allocation, List.of()),
+            queue(),
+            input -> true
+        );
+
+        assertThat(
+            nextDesiredBalance.assignments(),
+            anyOf(
+                equalTo(
+                    Map.ofEntries(
+                        Map.entry(shardIdFrom(indexMetadata1, 0), new ShardAssignment(Set.of("node-1"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata1, 1), new ShardAssignment(Set.of("node-2"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata2, 0), new ShardAssignment(Set.of("node-1"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata3, 0), new ShardAssignment(Set.of("node-2"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata3, 1), new ShardAssignment(Set.of("node-2"), 1, 0, 0))
+                    )
+                ),
+                equalTo(
+                    Map.ofEntries(
+                        Map.entry(shardIdFrom(indexMetadata1, 0), new ShardAssignment(Set.of("node-1"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata1, 1), new ShardAssignment(Set.of("node-2"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata2, 0), new ShardAssignment(Set.of("node-2"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata3, 0), new ShardAssignment(Set.of("node-1"), 1, 0, 0)),
+                        Map.entry(shardIdFrom(indexMetadata3, 1), new ShardAssignment(Set.of("node-1"), 1, 0, 0))
+                    )
+                )
+            )
+        );
+    }
+
+    private static IndexId indexIdFrom(IndexMetadata indexMetadata) {
+        return new IndexId(indexMetadata.getIndex().getName(), indexMetadata.getIndex().getUUID());
+    }
+
+    private static ShardId shardIdFrom(IndexMetadata indexMetadata, int shardId) {
+        return new ShardId(indexMetadata.getIndex(), shardId);
+    }
+
     public void testShouldLogComputationIteration() {
         checkIterationLogging(
             999,
@@ -943,7 +1091,7 @@ public class DesiredBalanceComputerTests extends ESAllocationTestCase {
     }
 
     private static Map.Entry<String, Long> indexSize(ClusterState clusterState, String name, long size, boolean primary) {
-        return Map.entry(ClusterInfo.shardIdentifierFromRouting(findShardId(clusterState, name), primary), size);
+        return Map.entry(shardIdentifierFromRouting(findShardId(clusterState, name), primary), size);
     }
 
     private static ShardId findShardId(ClusterState clusterState, String name) {

+ 48 - 76
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java

@@ -43,15 +43,15 @@ import org.elasticsearch.index.Index;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.shard.ShardId;
 
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 
 import static java.util.Collections.emptySet;
+import static org.elasticsearch.cluster.ClusterInfo.shardIdentifierFromRouting;
 import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
+import static org.elasticsearch.cluster.routing.TestShardRouting.newShardRouting;
 import static org.elasticsearch.index.IndexModule.INDEX_STORE_TYPE_SETTING;
 import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SEARCHABLE_SNAPSHOT_STORE_TYPE;
 import static org.hamcrest.Matchers.containsString;
@@ -512,93 +512,64 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
     }
 
     public void testTakesIntoAccountExpectedSizeForInitializingSearchableSnapshots() {
-        String mainIndexName = "test";
-        Index index = new Index(mainIndexName, "1234");
-        String anotherIndexName = "another_index";
-        Index anotherIndex = new Index(anotherIndexName, "5678");
-        Metadata metadata = Metadata.builder()
-            .put(
-                IndexMetadata.builder(mainIndexName)
-                    .settings(
-                        settings(IndexVersion.current()).put("index.uuid", "1234")
-                            .put(INDEX_STORE_TYPE_SETTING.getKey(), SEARCHABLE_SNAPSHOT_STORE_TYPE)
-                    )
-                    .numberOfShards(3)
-                    .numberOfReplicas(1)
-            )
-            .put(
-                IndexMetadata.builder(anotherIndexName)
-                    .settings(settings(IndexVersion.current()).put("index.uuid", "5678"))
-                    .numberOfShards(1)
-                    .numberOfReplicas(1)
-            )
+
+        var searchableSnapshotIndex = IndexMetadata.builder("searchable_snapshot")
+            .settings(indexSettings(IndexVersion.current(), 3, 0).put(INDEX_STORE_TYPE_SETTING.getKey(), SEARCHABLE_SNAPSHOT_STORE_TYPE))
             .build();
+        var regularIndex = IndexMetadata.builder("regular_index").settings(indexSettings(IndexVersion.current(), 1, 0)).build();
+
         String nodeId = "node1";
-        String anotherNodeId = "another_node";
-
-        List<ShardRouting> shards = new ArrayList<>();
-        int anotherNodeShardCounter = 0;
-        int nodeShardCounter = 0;
-        Map<String, Long> initializingShardSizes = new HashMap<>();
-        for (int i = 1; i <= 3; i++) {
-            int expectedSize = 10 * i;
-            shards.add(createShard(index, nodeId, nodeShardCounter++, expectedSize));
-            if (randomBoolean()) {
-                ShardRouting initializingShard = ShardRoutingHelper.initialize(
-                    ShardRouting.newUnassigned(
-                        new ShardId(index, nodeShardCounter++),
-                        true,
-                        EmptyStoreRecoverySource.INSTANCE,
-                        new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "foo"),
-                        ShardRouting.Role.DEFAULT
-                    ),
-                    nodeId
-                );
-                initializingShardSizes.put(ClusterInfo.shardIdentifierFromRouting(initializingShard), randomLongBetween(10, 50));
-                shards.add(initializingShard);
-            }
-            // randomly add shards for non-searchable snapshot index
-            if (randomBoolean()) {
-                for (int j = 0; j < randomIntBetween(1, 5); j++) {
-                    shards.add(createShard(anotherIndex, anotherNodeId, anotherNodeShardCounter++, expectedSize));
-                }
-            }
+        var knownShardSizes = new HashMap<String, Long>();
+        long unaccountedSearchableSnapshotSizes = 0;
+        long relocatingShardsSizes = 0;
+
+        var searchableSnapshotIndexRoutingTableBuilder = IndexRoutingTable.builder(searchableSnapshotIndex.getIndex());
+        for (int i = 0; i < searchableSnapshotIndex.getNumberOfShards(); i++) {
+            long expectedSize = randomLongBetween(10, 50);
+            // a searchable snapshot shard without a corresponding entry in the cluster info
+            ShardRouting startedShardWithExpectedSize = newShardRouting(
+                new ShardId(searchableSnapshotIndex.getIndex(), i),
+                nodeId,
+                true,
+                ShardRoutingState.STARTED,
+                expectedSize
+            );
+            searchableSnapshotIndexRoutingTableBuilder.addShard(startedShardWithExpectedSize);
+            unaccountedSearchableSnapshotSizes += expectedSize;
+        }
+        var regularIndexRoutingTableBuilder = IndexRoutingTable.builder(regularIndex.getIndex());
+        for (int i = 0; i < searchableSnapshotIndex.getNumberOfShards(); i++) {
+            var shardSize = randomLongBetween(10, 50);
+            // a shard relocating to this node
+            ShardRouting initializingShard = newShardRouting(
+                new ShardId(regularIndex.getIndex(), i),
+                nodeId,
+                true,
+                ShardRoutingState.INITIALIZING,
+                PeerRecoverySource.INSTANCE
+            );
+            regularIndexRoutingTableBuilder.addShard(initializingShard);
+            knownShardSizes.put(shardIdentifierFromRouting(initializingShard), shardSize);
+            relocatingShardsSizes += shardSize;
         }
 
-        DiscoveryNode node = DiscoveryNodeUtils.builder(nodeId).roles(emptySet()).build();
-        DiscoveryNode anotherNode = DiscoveryNodeUtils.builder(anotherNodeId).roles(emptySet()).build();
         ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT)
-            .metadata(metadata)
-            .routingTable(
-                RoutingTable.builder()
-                    .add(
-                        shards.stream()
-                            .filter(s -> s.getIndexName().equals(mainIndexName))
-                            .reduce(IndexRoutingTable.builder(index), IndexRoutingTable.Builder::addShard, (a, b) -> a)
-                    )
-                    .add(
-                        shards.stream()
-                            .filter(s -> s.getIndexName().equals(anotherIndexName))
-                            .reduce(IndexRoutingTable.builder(anotherIndex), IndexRoutingTable.Builder::addShard, (a, b) -> a)
-                    )
-                    .build()
-            )
-            .nodes(DiscoveryNodes.builder().add(node).add(anotherNode).build())
+            .metadata(Metadata.builder().put(searchableSnapshotIndex, false).put(regularIndex, false))
+            .routingTable(RoutingTable.builder().add(searchableSnapshotIndexRoutingTableBuilder).add(regularIndexRoutingTableBuilder))
+            .nodes(DiscoveryNodes.builder().add(newNode(nodeId)).build())
             .build();
+
         RoutingAllocation allocation = new RoutingAllocation(
             null,
             clusterState,
-            new DevNullClusterInfo(Map.of(), Map.of(), initializingShardSizes),
+            new DevNullClusterInfo(Map.of(), Map.of(), knownShardSizes),
             null,
             0
         );
-        long sizeOfUnaccountedShards = sizeOfUnaccountedShards(
-            allocation,
-            RoutingNodesHelper.routingNode(nodeId, node, shards.toArray(ShardRouting[]::new)),
-            false,
-            "/dev/null"
+        assertEquals(
+            unaccountedSearchableSnapshotSizes + relocatingShardsSizes,
+            sizeOfUnaccountedShards(allocation, clusterState.getRoutingNodes().node(nodeId), false, "/dev/null")
         );
-        assertEquals(60L + initializingShardSizes.values().stream().mapToLong(Long::longValue).sum(), sizeOfUnaccountedShards);
     }
 
     private ShardRouting createShard(Index index, String nodeId, int i, int expectedSize) {
@@ -620,6 +591,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
             subtractShardsMovingAway,
             dataPath,
             allocation.clusterInfo(),
+            allocation.snapshotShardSizeInfo(),
             allocation.metadata(),
             allocation.routingTable(),
             allocation.unaccountedSearchableSnapshotSize(node)

+ 11 - 1
test/framework/src/main/java/org/elasticsearch/cluster/routing/TestShardRouting.java

@@ -36,6 +36,16 @@ public class TestShardRouting {
     }
 
     public static ShardRouting newShardRouting(ShardId shardId, String currentNodeId, boolean primary, ShardRoutingState state) {
+        return newShardRouting(shardId, currentNodeId, primary, state, -1);
+    }
+
+    public static ShardRouting newShardRouting(
+        ShardId shardId,
+        String currentNodeId,
+        boolean primary,
+        ShardRoutingState state,
+        long expectedShardSize
+    ) {
         assertNotEquals(ShardRoutingState.RELOCATING, state);
         return new ShardRouting(
             shardId,
@@ -47,7 +57,7 @@ public class TestShardRouting {
             buildUnassignedInfo(state),
             buildRelocationFailureInfo(state),
             buildAllocationId(state),
-            -1,
+            expectedShardSize,
             ShardRouting.Role.DEFAULT
         );
     }
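
Example use of the new overload, as in the updated `DiskThresholdDeciderUnitTests` above (`index` is an assumed `Index` instance):

    ShardRouting started = TestShardRouting.newShardRouting(
        new ShardId(index, 0),           // shard to place
        "node-1",                        // current node id
        true,                            // primary
        ShardRoutingState.STARTED,
        ByteSizeValue.ofGb(1).getBytes() // expected shard size in bytes
    );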