Extract shard size estimation into a separate component (#101854)

This commit is a prerequisite for future improvements to the initializing shard
size estimation. It moves the estimation logic into a separate component and
covers it with additional tests.
Ievgen Degtiarenko · commit cb6a570c49
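
In practice the refactor is a call-site swap: the estimation helpers move out of
DiskThresholdDecider into the new ExpectedShardSizeEstimator, and callers switch
to a static import. A minimal sketch of the change, using only calls that appear
in the hunks below:

    import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;

    // before: DiskThresholdDecider.getExpectedShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE, allocation)
    // after:
    long shardSize = getExpectedShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE, allocation);

The method signatures are unchanged; only the owning class moves, so behavior at
the call sites stays the same.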

+ 86 - 0
server/src/main/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimator.java

@@ -0,0 +1,86 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.routing;
+
+import org.elasticsearch.cluster.ClusterInfo;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+
+import java.util.Set;
+
+public class ExpectedShardSizeEstimator {
+
+    public static long getExpectedShardSize(ShardRouting shardRouting, long defaultSize, RoutingAllocation allocation) {
+        return getExpectedShardSize(
+            shardRouting,
+            defaultSize,
+            allocation.clusterInfo(),
+            allocation.snapshotShardSizeInfo(),
+            allocation.metadata(),
+            allocation.routingTable()
+        );
+    }
+
+    /**
+     * Returns the expected shard size for the given shard, or the provided default value if not enough information is available
+     * to estimate the shard size.
+     */
+    public static long getExpectedShardSize(
+        ShardRouting shard,
+        long defaultValue,
+        ClusterInfo clusterInfo,
+        SnapshotShardSizeInfo snapshotShardSizeInfo,
+        Metadata metadata,
+        RoutingTable routingTable
+    ) {
+        final IndexMetadata indexMetadata = metadata.getIndexSafe(shard.index());
+        if (indexMetadata.getResizeSourceIndex() != null
+            && shard.active() == false
+            && shard.recoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS) {
+            return getExpectedSizeOfResizedShard(shard, defaultValue, indexMetadata, clusterInfo, metadata, routingTable);
+        } else if (shard.unassigned() && shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+            return snapshotShardSizeInfo.getShardSize(shard, defaultValue);
+        } else {
+            return clusterInfo.getShardSize(shard, defaultValue);
+        }
+    }
+
+    private static long getExpectedSizeOfResizedShard(
+        ShardRouting shard,
+        long defaultValue,
+        IndexMetadata indexMetadata,
+        ClusterInfo clusterInfo,
+        Metadata metadata,
+        RoutingTable routingTable
+    ) {
+        // in the shrink index case we sum up the source index shards since we basically make a copy of the shard in the worst case
+        long targetShardSize = 0;
+        final Index mergeSourceIndex = indexMetadata.getResizeSourceIndex();
+        final IndexMetadata sourceIndexMetadata = metadata.index(mergeSourceIndex);
+        if (sourceIndexMetadata != null) {
+            final Set<ShardId> shardIds = IndexMetadata.selectRecoverFromShards(
+                shard.id(),
+                sourceIndexMetadata,
+                indexMetadata.getNumberOfShards()
+            );
+            final IndexRoutingTable indexRoutingTable = routingTable.index(mergeSourceIndex.getName());
+            for (int i = 0; i < indexRoutingTable.size(); i++) {
+                IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i);
+                if (shardIds.contains(shardRoutingTable.shardId())) {
+                    targetShardSize += clusterInfo.getShardSize(shardRoutingTable.primaryShard(), 0);
+                }
+            }
+        }
+        return targetShardSize == 0 ? defaultValue : targetShardSize;
+    }
+}

+ 3 - 11
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java

@@ -31,7 +31,6 @@ import org.elasticsearch.cluster.routing.allocation.WriteLoadForecaster;
 import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision.Type;
-import org.elasticsearch.cluster.routing.allocation.decider.DiskThresholdDecider;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.logging.DeprecationCategory;
 import org.elasticsearch.common.logging.DeprecationLogger;
@@ -57,6 +56,7 @@ import java.util.function.BiFunction;
 import java.util.stream.StreamSupport;
 
 import static org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata.Type.REPLACE;
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
 import static org.elasticsearch.cluster.routing.ShardRoutingState.RELOCATING;
 import static org.elasticsearch.common.settings.ClusterSettings.createBuiltInClusterSettings;
 
@@ -1037,11 +1037,7 @@ public class BalancedShardsAllocator implements ShardsAllocator {
                             logger.trace("Assigned shard [{}] to [{}]", shard, minNode.getNodeId());
                         }
 
-                        final long shardSize = DiskThresholdDecider.getExpectedShardSize(
-                            shard,
-                            ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
-                            allocation
-                        );
+                        final long shardSize = getExpectedShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE, allocation);
                         shard = routingNodes.initializeShard(shard, minNode.getNodeId(), null, shardSize, allocation.changes());
                         minNode.addShard(shard);
                         if (shard.primary() == false) {
@@ -1064,11 +1060,7 @@ public class BalancedShardsAllocator implements ShardsAllocator {
                         if (minNode != null) {
                             // throttle decision scenario
                             assert allocationDecision.getAllocationStatus() == AllocationStatus.DECIDERS_THROTTLED;
-                            final long shardSize = DiskThresholdDecider.getExpectedShardSize(
-                                shard,
-                                ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
-                                allocation
-                            );
+                            final long shardSize = getExpectedShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE, allocation);
                             minNode.addShard(shard.initialize(minNode.getNodeId(), null, shardSize));
                         } else {
                             if (logger.isTraceEnabled()) {

+ 2 - 9
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/DesiredBalanceReconciler.java

@@ -21,7 +21,6 @@ import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.UnassignedInfo.AllocationStatus;
 import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
-import org.elasticsearch.cluster.routing.allocation.decider.DiskThresholdDecider;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Setting;
@@ -40,6 +39,7 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import static org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata.Type.REPLACE;
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
 
 /**
  * Given the current allocation of shards and the desired balance, performs the next (legal) shard movements towards the goal.
@@ -271,14 +271,7 @@ public class DesiredBalanceReconciler {
                             switch (decision.type()) {
                                 case YES -> {
                                     logger.debug("Assigning shard [{}] to {} [{}]", shard, nodeIdsIterator.source, nodeId);
-                                    final long shardSize = DiskThresholdDecider.getExpectedShardSize(
-                                        shard,
-                                        ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
-                                        allocation.clusterInfo(),
-                                        allocation.snapshotShardSizeInfo(),
-                                        allocation.metadata(),
-                                        allocation.routingTable()
-                                    );
+                                    long shardSize = getExpectedShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE, allocation);
                                     routingNodes.initializeShard(shard, nodeId, null, shardSize, allocation.changes());
                                     allocationOrdering.recordAllocation(nodeId);
                                     if (shard.primary() == false) {

+ 2 - 61
server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java

@@ -15,8 +15,6 @@ import org.elasticsearch.cluster.ClusterInfo;
 import org.elasticsearch.cluster.DiskUsage;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
-import org.elasticsearch.cluster.routing.IndexRoutingTable;
-import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.RoutingTable;
@@ -29,12 +27,10 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsException;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.index.Index;
-import org.elasticsearch.index.shard.ShardId;
-import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 import java.util.Map;
-import java.util.Set;
+
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
 
 /**
  * The {@link DiskThresholdDecider} checks that the node a shard is potentially
@@ -541,61 +537,6 @@ public class DiskThresholdDecider extends AllocationDecider {
         return null;
     }
 
-    public static long getExpectedShardSize(ShardRouting shardRouting, long defaultSize, RoutingAllocation allocation) {
-        return DiskThresholdDecider.getExpectedShardSize(
-            shardRouting,
-            defaultSize,
-            allocation.clusterInfo(),
-            allocation.snapshotShardSizeInfo(),
-            allocation.metadata(),
-            allocation.routingTable()
-        );
-    }
-
-    /**
-     * Returns the expected shard size for the given shard or the default value provided if not enough information are available
-     * to estimate the shards size.
-     */
-    public static long getExpectedShardSize(
-        ShardRouting shard,
-        long defaultValue,
-        ClusterInfo clusterInfo,
-        SnapshotShardSizeInfo snapshotShardSizeInfo,
-        Metadata metadata,
-        RoutingTable routingTable
-    ) {
-        final IndexMetadata indexMetadata = metadata.getIndexSafe(shard.index());
-        if (indexMetadata.getResizeSourceIndex() != null
-            && shard.active() == false
-            && shard.recoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS) {
-            // in the shrink index case we sum up the source index shards since we basically make a copy of the shard in
-            // the worst case
-            long targetShardSize = 0;
-            final Index mergeSourceIndex = indexMetadata.getResizeSourceIndex();
-            final IndexMetadata sourceIndexMeta = metadata.index(mergeSourceIndex);
-            if (sourceIndexMeta != null) {
-                final Set<ShardId> shardIds = IndexMetadata.selectRecoverFromShards(
-                    shard.id(),
-                    sourceIndexMeta,
-                    indexMetadata.getNumberOfShards()
-                );
-                final IndexRoutingTable indexRoutingTable = routingTable.index(mergeSourceIndex.getName());
-                for (int i = 0; i < indexRoutingTable.size(); i++) {
-                    IndexShardRoutingTable shardRoutingTable = indexRoutingTable.shard(i);
-                    if (shardIds.contains(shardRoutingTable.shardId())) {
-                        targetShardSize += clusterInfo.getShardSize(shardRoutingTable.primaryShard(), 0);
-                    }
-                }
-            }
-            return targetShardSize == 0 ? defaultValue : targetShardSize;
-        } else {
-            if (shard.unassigned() && shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
-                return snapshotShardSizeInfo.getShardSize(shard, defaultValue);
-            }
-            return clusterInfo.getShardSize(shard, defaultValue);
-        }
-    }
-
     record DiskUsageWithRelocations(DiskUsage diskUsage, long relocatingShardSize) {
 
         double getFreeDiskAsPercentage() {

+ 149 - 0
server/src/test/java/org/elasticsearch/cluster/routing/ExpectedShardSizeEstimatorTests.java

@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.routing;
+
+import org.elasticsearch.cluster.ClusterInfo;
+import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
+import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
+import org.elasticsearch.snapshots.Snapshot;
+import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_RESIZE_SOURCE_NAME_KEY;
+import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_RESIZE_SOURCE_UUID_KEY;
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
+import static org.elasticsearch.cluster.routing.TestShardRouting.newShardRouting;
+import static org.hamcrest.Matchers.equalTo;
+
+public class ExpectedShardSizeEstimatorTests extends ESAllocationTestCase {
+
+    private final long defaultValue = randomLongBetween(-1, 0);
+
+    public void testShouldFallbackToDefaultValue() {
+
+        var state = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata(index("my-index"))).build();
+        var shard = newShardRouting("my-index", 0, randomIdentifier(), true, ShardRoutingState.INITIALIZING);
+
+        var allocation = createRoutingAllocation(state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY);
+
+        assertThat(getExpectedShardSize(shard, defaultValue, allocation), equalTo(defaultValue));
+    }
+
+    public void testShouldReadExpectedSizeFromClusterInfo() {
+
+        var shardSize = randomLongBetween(100, 1000);
+        var state = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata(index("my-index"))).build();
+        var shard = newShardRouting("my-index", 0, randomIdentifier(), true, ShardRoutingState.INITIALIZING);
+
+        var clusterInfo = createClusterInfo(shard, shardSize);
+        var allocation = createRoutingAllocation(state, clusterInfo, SnapshotShardSizeInfo.EMPTY);
+
+        assertThat(getExpectedShardSize(shard, defaultValue, allocation), equalTo(shardSize));
+    }
+
+    public void testShouldReadExpectedSizeWhenInitializingFromSnapshot() {
+
+        var snapshotShardSize = randomLongBetween(100, 1000);
+        var state = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata(index("my-index"))).build();
+
+        var snapshot = new Snapshot("repository", new SnapshotId("snapshot-1", "na"));
+        var indexId = new IndexId("my-index", "_na_");
+
+        var shard = newShardRouting(
+            new ShardId("my-index", "_na_", 0),
+            null,
+            true,
+            ShardRoutingState.UNASSIGNED,
+            new RecoverySource.SnapshotRecoverySource(randomUUID(), snapshot, IndexVersion.current(), indexId)
+        );
+
+        var snapshotShardSizeInfo = new SnapshotShardSizeInfo(
+            Map.of(new InternalSnapshotsInfoService.SnapshotShard(snapshot, indexId, shard.shardId()), snapshotShardSize)
+        );
+        var allocation = createRoutingAllocation(state, ClusterInfo.EMPTY, snapshotShardSizeInfo);
+
+        assertThat(getExpectedShardSize(shard, defaultValue, allocation), equalTo(snapshotShardSize));
+    }
+
+    public void testShouldReadSizeFromClonedShard() {
+
+        var sourceShardSize = randomLongBetween(100, 1000);
+        var source = newShardRouting(new ShardId("source", "_na_", 0), randomIdentifier(), true, ShardRoutingState.STARTED);
+        var target = newShardRouting(
+            new ShardId("target", "_na_", 0),
+            randomIdentifier(),
+            true,
+            ShardRoutingState.INITIALIZING,
+            RecoverySource.LocalShardsRecoverySource.INSTANCE
+        );
+
+        var state = ClusterState.builder(ClusterName.DEFAULT)
+            .metadata(
+                metadata(
+                    IndexMetadata.builder("source").settings(indexSettings(IndexVersion.current(), 2, 0)),
+                    IndexMetadata.builder("target")
+                        .settings(
+                            indexSettings(IndexVersion.current(), 1, 0) //
+                                .put(INDEX_RESIZE_SOURCE_NAME_KEY, "source") //
+                                .put(INDEX_RESIZE_SOURCE_UUID_KEY, "_na_")
+                        )
+                )
+            )
+            .routingTable(RoutingTable.builder().add(IndexRoutingTable.builder(source.index()).addShard(source)))
+            .build();
+
+        var clusterInfo = createClusterInfo(source, sourceShardSize);
+        var allocation = createRoutingAllocation(state, clusterInfo, SnapshotShardSizeInfo.EMPTY);
+
+        assertThat(getExpectedShardSize(target, defaultValue, allocation), equalTo(sourceShardSize));
+    }
+
+    private static RoutingAllocation createRoutingAllocation(
+        ClusterState state,
+        ClusterInfo clusterInfo,
+        SnapshotShardSizeInfo snapshotShardSizeInfo
+    ) {
+        return new RoutingAllocation(new AllocationDeciders(List.of()), state, clusterInfo, snapshotShardSizeInfo, 0);
+    }
+
+    private static IndexMetadata.Builder index(String name) {
+        return IndexMetadata.builder(name).settings(indexSettings(IndexVersion.current(), 1, 0));
+    }
+
+    private static Metadata metadata(IndexMetadata.Builder... indices) {
+        var builder = Metadata.builder();
+        for (IndexMetadata.Builder index : indices) {
+            builder.put(index.build(), false);
+        }
+        return builder.build();
+    }
+
+    private static ClusterInfo createClusterInfo(ShardRouting shard, Long size) {
+        return new ClusterInfo(
+            Map.of(),
+            Map.of(),
+            Map.of(ClusterInfo.shardIdentifierFromRouting(shard), size),
+            Map.of(),
+            Map.of(),
+            Map.of()
+        );
+    }
+}

+ 148 - 21
server/src/test/java/org/elasticsearch/cluster/routing/allocation/ExpectedShardSizeAllocationTests.java

@@ -12,37 +12,170 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterInfo;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.DiskUsage;
 import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.cluster.RestoreInProgress;
 import org.elasticsearch.cluster.TestShardRoutingRoleStrategies;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingTable;
+import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.allocation.command.AllocationCommands;
 import org.elasticsearch.cluster.routing.allocation.command.MoveAllocationCommand;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
+import org.elasticsearch.snapshots.Snapshot;
+import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
+import java.util.Collection;
+import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.Function;
 
+import static java.util.stream.Collectors.toMap;
 import static org.elasticsearch.cluster.routing.RoutingNodesHelper.shardsWithState;
+import static org.elasticsearch.cluster.routing.ShardRoutingState.INITIALIZING;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.not;
 
 public class ExpectedShardSizeAllocationTests extends ESAllocationTestCase {
+
+    public void testAllocateToCorrectNodeUsingShardSizeFromClusterInfo() {
+
+        var indexMetadata = IndexMetadata.builder("test").settings(indexSettings(IndexVersion.current(), 1, 0)).build();
+
+        var clusterState = ClusterState.builder(ClusterName.DEFAULT)
+            .nodes(DiscoveryNodes.builder().add(newNode("node-1")).add(newNode("node-2")).add(newNode("node-3")))
+            .metadata(Metadata.builder().put(indexMetadata, false))
+            .routingTable(RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY).addAsNew(indexMetadata))
+            .build();
+        var dataNodeIds = clusterState.nodes().getDataNodes().keySet();
+
+        long shardSize = ByteSizeValue.ofGb(1).getBytes();
+        long diskSize = ByteSizeValue.ofGb(5).getBytes();
+        long headRoom = diskSize / 10;
+        var expectedNodeId = randomFrom(dataNodeIds);
+        var clusterInfo = createClusterInfo(
+            createDiskUsage(
+                dataNodeIds,
+                nodeId -> createDiskUsage(nodeId, diskSize, headRoom + shardSize + (Objects.equals(nodeId, expectedNodeId) ? +1 : -1))
+            ),
+            Map.of(ClusterInfo.shardIdentifierFromRouting(new ShardId(indexMetadata.getIndex(), 0), true), shardSize)
+        );
+
+        AllocationService service = createAllocationService(Settings.EMPTY, () -> clusterInfo);
+        clusterState = service.reroute(clusterState, "reroute", ActionListener.noop());
+
+        assertThatShard(
+            clusterState.routingTable().index(indexMetadata.getIndex()).shard(0).primaryShard(),
+            INITIALIZING,
+            expectedNodeId,
+            shardSize
+        );
+    }
+
+    public void testAllocateToCorrectNodeAccordingToSnapshotShardInfo() {
+
+        var snapshot = new Snapshot("repository", new SnapshotId("snapshot-1", "na"));
+        var indexId = new IndexId("my-index", "_na_");
+        var restoreId = "restore-id";
+
+        var indexMetadata = IndexMetadata.builder("test")
+            .settings(indexSettings(IndexVersion.current(), 1, 0))
+            .putInSyncAllocationIds(0, Set.of(randomUUID()))
+            .build();
+
+        var clusterState = ClusterState.builder(ClusterName.DEFAULT)
+            .nodes(DiscoveryNodes.builder().add(newNode("node-1")).add(newNode("node-2")).add(newNode("node-3")))
+            .metadata(Metadata.builder().put(indexMetadata, false))
+            .routingTable(
+                RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY)
+                    .addAsRestore(
+                        indexMetadata,
+                        new RecoverySource.SnapshotRecoverySource(restoreId, snapshot, IndexVersion.current(), indexId)
+                    )
+            )
+            .customs(
+                Map.of(
+                    RestoreInProgress.TYPE,
+                    new RestoreInProgress.Builder().add(
+                        new RestoreInProgress.Entry(
+                            restoreId,
+                            snapshot,
+                            RestoreInProgress.State.STARTED,
+                            false,
+                            List.of(indexMetadata.getIndex().getName()),
+                            Map.of(new ShardId(indexMetadata.getIndex(), 0), new RestoreInProgress.ShardRestoreStatus(randomIdentifier()))
+                        )
+                    ).build()
+                )
+            )
+            .build();
+        var dataNodeIds = clusterState.nodes().getDataNodes().keySet();
+
+        long shardSize = ByteSizeValue.ofGb(1).getBytes();
+        long diskSize = ByteSizeValue.ofGb(5).getBytes();
+        long headRoom = diskSize / 10;
+        var expectedNodeId = randomFrom(dataNodeIds);
+        var clusterInfo = createClusterInfo(
+            createDiskUsage(
+                dataNodeIds,
+                nodeId -> createDiskUsage(nodeId, diskSize, headRoom + shardSize + (Objects.equals(nodeId, expectedNodeId) ? +1 : -1))
+            ),
+            Map.of()
+        );
+        var snapshotShardSizeInfo = new SnapshotShardSizeInfo(
+            Map.of(new InternalSnapshotsInfoService.SnapshotShard(snapshot, indexId, new ShardId(indexMetadata.getIndex(), 0)), shardSize)
+        );
+
+        AllocationService service = createAllocationService(
+            Settings.EMPTY,
+            new TestGatewayAllocator(),
+            () -> clusterInfo,
+            () -> snapshotShardSizeInfo
+        );
+        clusterState = service.reroute(clusterState, "reroute", ActionListener.noop());
+
+        assertThatShard(
+            clusterState.routingTable().index(indexMetadata.getIndex()).shard(0).primaryShard(),
+            INITIALIZING,
+            expectedNodeId,
+            shardSize
+        );
+    }
+
+    private static void assertThatShard(ShardRouting shard, ShardRoutingState state, String nodeId, long expectedShardSize) {
+        assertThat(shard.state(), equalTo(state));
+        assertThat(shard.currentNodeId(), equalTo(nodeId));
+        assertThat(shard.getExpectedShardSize(), equalTo(expectedShardSize));
+    }
+
+    private static Map<String, DiskUsage> createDiskUsage(Collection<String> nodeIds, Function<String, DiskUsage> diskUsageCreator) {
+        return nodeIds.stream().collect(toMap(Function.identity(), diskUsageCreator));
+    }
+
+    private static DiskUsage createDiskUsage(String nodeId, long totalBytes, long freeBytes) {
+        return new DiskUsage(nodeId, nodeId, "/data", totalBytes, freeBytes);
+    }
+
     public void testInitializingHasExpectedSize() {
         final long byteSize = randomIntBetween(0, Integer.MAX_VALUE);
         final ClusterInfo clusterInfo = createClusterInfoWith(new ShardId("test", "_na_", 0), byteSize);
         AllocationService strategy = createAllocationService(Settings.EMPTY, () -> clusterInfo);
 
         logger.info("Building initial routing table");
-        var indexMetadata = IndexMetadata.builder("test")
-            .settings(settings(IndexVersion.current()))
-            .numberOfShards(1)
-            .numberOfReplicas(1)
-            .build();
+        var indexMetadata = IndexMetadata.builder("test").settings(indexSettings(IndexVersion.current(), 1, 1)).build();
 
         ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT)
             .metadata(Metadata.builder().put(indexMetadata, false))
@@ -52,11 +185,8 @@ public class ExpectedShardSizeAllocationTests extends ESAllocationTestCase {
         logger.info("Adding one node and performing rerouting");
         clusterState = strategy.reroute(clusterState, "reroute", ActionListener.noop());
 
-        assertEquals(1, clusterState.getRoutingNodes().node("node1").numberOfShardsWithState(ShardRoutingState.INITIALIZING));
-        assertEquals(
-            byteSize,
-            shardsWithState(clusterState.getRoutingNodes(), ShardRoutingState.INITIALIZING).get(0).getExpectedShardSize()
-        );
+        assertEquals(1, clusterState.getRoutingNodes().node("node1").numberOfShardsWithState(INITIALIZING));
+        assertEquals(byteSize, shardsWithState(clusterState.getRoutingNodes(), INITIALIZING).get(0).getExpectedShardSize());
         logger.info("Start the primary shard");
         clusterState = startInitializingShardsAndReroute(strategy, clusterState);
 
@@ -67,11 +197,8 @@ public class ExpectedShardSizeAllocationTests extends ESAllocationTestCase {
         clusterState = ClusterState.builder(clusterState).nodes(DiscoveryNodes.builder(clusterState.nodes()).add(newNode("node2"))).build();
         clusterState = strategy.reroute(clusterState, "reroute", ActionListener.noop());
 
-        assertEquals(1, clusterState.getRoutingNodes().node("node2").numberOfShardsWithState(ShardRoutingState.INITIALIZING));
-        assertEquals(
-            byteSize,
-            shardsWithState(clusterState.getRoutingNodes(), ShardRoutingState.INITIALIZING).get(0).getExpectedShardSize()
-        );
+        assertEquals(1, clusterState.getRoutingNodes().node("node2").numberOfShardsWithState(INITIALIZING));
+        assertEquals(byteSize, shardsWithState(clusterState.getRoutingNodes(), INITIALIZING).get(0).getExpectedShardSize());
     }
 
     public void testExpectedSizeOnMove() {
@@ -79,11 +206,7 @@ public class ExpectedShardSizeAllocationTests extends ESAllocationTestCase {
         final ClusterInfo clusterInfo = createClusterInfoWith(new ShardId("test", "_na_", 0), byteSize);
         final AllocationService allocation = createAllocationService(Settings.EMPTY, () -> clusterInfo);
         logger.info("creating an index with 1 shard, no replica");
-        var indexMetadata = IndexMetadata.builder("test")
-            .settings(settings(IndexVersion.current()))
-            .numberOfShards(1)
-            .numberOfReplicas(0)
-            .build();
+        var indexMetadata = IndexMetadata.builder("test").settings(indexSettings(IndexVersion.current(), 1, 0)).build();
         ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT)
             .metadata(Metadata.builder().put(indexMetadata, false))
             .routingTable(RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY).addAsNew(indexMetadata))
@@ -111,7 +234,7 @@ public class ExpectedShardSizeAllocationTests extends ESAllocationTestCase {
         assertThat(commandsResult.clusterState(), not(equalTo(clusterState)));
         clusterState = commandsResult.clusterState();
         assertEquals(clusterState.getRoutingNodes().node(existingNodeId).iterator().next().state(), ShardRoutingState.RELOCATING);
-        assertEquals(clusterState.getRoutingNodes().node(toNodeId).iterator().next().state(), ShardRoutingState.INITIALIZING);
+        assertEquals(clusterState.getRoutingNodes().node(toNodeId).iterator().next().state(), INITIALIZING);
 
         assertEquals(clusterState.getRoutingNodes().node(existingNodeId).iterator().next().getExpectedShardSize(), byteSize);
         assertEquals(clusterState.getRoutingNodes().node(toNodeId).iterator().next().getExpectedShardSize(), byteSize);
@@ -137,4 +260,8 @@ public class ExpectedShardSizeAllocationTests extends ESAllocationTestCase {
             Map.of()
         );
     }
+
+    private static ClusterInfo createClusterInfo(Map<String, DiskUsage> diskUsage, Map<String, Long> shardSizes) {
+        return new ClusterInfo(diskUsage, diskUsage, shardSizes, Map.of(), Map.of(), Map.of());
+    }
 }

+ 14 - 13
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java

@@ -51,6 +51,7 @@ import java.util.List;
 import java.util.Map;
 
 import static java.util.Collections.emptySet;
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
 import static org.elasticsearch.index.IndexModule.INDEX_STORE_TYPE_SETTING;
 import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SEARCHABLE_SNAPSHOT_STORE_TYPE;
 import static org.hamcrest.Matchers.containsString;
@@ -459,9 +460,9 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         test_2 = ShardRoutingHelper.initialize(test_2, "node1");
         test_2 = ShardRoutingHelper.moveToStarted(test_2);
 
-        assertEquals(1000L, DiskThresholdDecider.getExpectedShardSize(test_2, 0L, allocation));
-        assertEquals(100L, DiskThresholdDecider.getExpectedShardSize(test_1, 0L, allocation));
-        assertEquals(10L, DiskThresholdDecider.getExpectedShardSize(test_0, 0L, allocation));
+        assertEquals(1000L, getExpectedShardSize(test_2, 0L, allocation));
+        assertEquals(100L, getExpectedShardSize(test_1, 0L, allocation));
+        assertEquals(10L, getExpectedShardSize(test_0, 0L, allocation));
 
         RoutingNode node = RoutingNodesHelper.routingNode(
             "node1",
@@ -484,7 +485,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         );
         test_3 = ShardRoutingHelper.initialize(test_3, "node1");
         test_3 = ShardRoutingHelper.moveToStarted(test_3);
-        assertEquals(0L, DiskThresholdDecider.getExpectedShardSize(test_3, 0L, allocation));
+        assertEquals(0L, getExpectedShardSize(test_3, 0L, allocation));
 
         boolean primary = randomBoolean();
         ShardRouting other_0 = ShardRouting.newUnassigned(
@@ -725,10 +726,10 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
             ShardRouting.Role.DEFAULT
         );
         test_3 = ShardRoutingHelper.initialize(test_3, "node1");
-        assertEquals(500L, DiskThresholdDecider.getExpectedShardSize(test_3, 0L, allocation));
-        assertEquals(500L, DiskThresholdDecider.getExpectedShardSize(test_2, 0L, allocation));
-        assertEquals(100L, DiskThresholdDecider.getExpectedShardSize(test_1, 0L, allocation));
-        assertEquals(10L, DiskThresholdDecider.getExpectedShardSize(test_0, 0L, allocation));
+        assertEquals(500L, getExpectedShardSize(test_3, 0L, allocation));
+        assertEquals(500L, getExpectedShardSize(test_2, 0L, allocation));
+        assertEquals(100L, getExpectedShardSize(test_1, 0L, allocation));
+        assertEquals(10L, getExpectedShardSize(test_0, 0L, allocation));
 
         ShardRouting target = ShardRouting.newUnassigned(
             new ShardId(new Index("target", "5678"), 0),
@@ -737,7 +738,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
             new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "foo"),
             ShardRouting.Role.DEFAULT
         );
-        assertEquals(1110L, DiskThresholdDecider.getExpectedShardSize(target, 0L, allocation));
+        assertEquals(1110L, getExpectedShardSize(target, 0L, allocation));
 
         ShardRouting target2 = ShardRouting.newUnassigned(
             new ShardId(new Index("target2", "9101112"), 0),
@@ -746,7 +747,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
             new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "foo"),
             ShardRouting.Role.DEFAULT
         );
-        assertEquals(110L, DiskThresholdDecider.getExpectedShardSize(target2, 0L, allocation));
+        assertEquals(110L, getExpectedShardSize(target2, 0L, allocation));
 
         target2 = ShardRouting.newUnassigned(
             new ShardId(new Index("target2", "9101112"), 1),
@@ -755,7 +756,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
             new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "foo"),
             ShardRouting.Role.DEFAULT
         );
-        assertEquals(1000L, DiskThresholdDecider.getExpectedShardSize(target2, 0L, allocation));
+        assertEquals(1000L, getExpectedShardSize(target2, 0L, allocation));
 
         // check that the DiskThresholdDecider still works even if the source index has been deleted
         ClusterState clusterStateWithMissingSourceIndex = ClusterState.builder(clusterState)
@@ -765,8 +766,8 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
 
         allocationService.reroute(clusterState, "foo", ActionListener.noop());
         RoutingAllocation allocationWithMissingSourceIndex = new RoutingAllocation(null, clusterStateWithMissingSourceIndex, info, null, 0);
-        assertEquals(42L, DiskThresholdDecider.getExpectedShardSize(target, 42L, allocationWithMissingSourceIndex));
-        assertEquals(42L, DiskThresholdDecider.getExpectedShardSize(target2, 42L, allocationWithMissingSourceIndex));
+        assertEquals(42L, getExpectedShardSize(target, 42L, allocationWithMissingSourceIndex));
+        assertEquals(42L, getExpectedShardSize(target2, 42L, allocationWithMissingSourceIndex));
     }
 
     public void testDiskUsageWithRelocations() {

+ 2 - 1
x-pack/plugin/autoscaling/src/main/java/org/elasticsearch/xpack/autoscaling/storage/ReactiveStorageDeciderService.java

@@ -22,6 +22,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeFilters;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
@@ -670,7 +671,7 @@ public class ReactiveStorageDeciderService implements AutoscalingDeciderService
         }
 
         private long getExpectedShardSize(ShardRouting shard) {
-            return DiskThresholdDecider.getExpectedShardSize(shard, 0L, info, shardSizeInfo, state.metadata(), state.routingTable());
+            return ExpectedShardSizeEstimator.getExpectedShardSize(shard, 0L, info, shardSizeInfo, state.metadata(), state.routingTable());
         }
 
         long unmovableSize(String nodeId, Collection<ShardRouting> shards) {

+ 2 - 2
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/allocation/SearchableSnapshotAllocator.java

@@ -30,7 +30,6 @@ import org.elasticsearch.cluster.routing.allocation.FailedShard;
 import org.elasticsearch.cluster.routing.allocation.NodeAllocationResult;
 import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
-import org.elasticsearch.cluster.routing.allocation.decider.DiskThresholdDecider;
 import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
@@ -61,6 +60,7 @@ import java.util.Set;
 import java.util.concurrent.ConcurrentMap;
 
 import static java.util.stream.Collectors.toSet;
+import static org.elasticsearch.cluster.routing.ExpectedShardSizeEstimator.getExpectedShardSize;
 import static org.elasticsearch.gateway.ReplicaShardAllocator.augmentExplanationsWithStoreInfo;
 import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SNAPSHOT_PARTIAL_SETTING;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_INDEX_ID_SETTING;
@@ -198,7 +198,7 @@ public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
                 unassignedAllocationHandler.initialize(
                     allocateUnassignedDecision.getTargetNode().getId(),
                     allocateUnassignedDecision.getAllocationId(),
-                    DiskThresholdDecider.getExpectedShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE, allocation),
+                    getExpectedShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE, allocation),
                     allocation.changes()
                 );
             } else {