浏览代码

Determine shard size before allocating shards recovering from snapshots (#61906)

Determines the size of shards that are recovering from snapshots 
before allocating them. This ensures, during shard allocation, that 
the node selected as the recovery target has enough free disk space 
for the recovery. This applies to regular restores, CCR bootstrap 
from remote, as well as mounting searchable snapshots.

The InternalSnapshotsInfoService is responsible for fetching snapshot 
shard sizes from repositories. It provides a getShardSize() method 
to other components of the system that can be used to retrieve the 
latest known shard size. If the latest snapshot shard size retrieval 
failed, getShardSize() returns 
ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE. While 
we'd like a better way to handle such failures, returning this value 
allows us to keep the existing behavior for now.

Note that this PR does not address an issue (which we already have 
today) where a replica is being allocated without knowing how much 
disk space is being used by the primary.
Yannick Welsch 5 年之前
父节点
当前提交
2afec0d916
共有 54 个文件被更改,包括 1328 次插入180 次删除
  1. 3 1
      benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/Allocators.java
  2. 83 9
      server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderIT.java
  3. 6 2
      server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java
  4. 0 2
      server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java
  5. 3 2
      server/src/main/java/org/elasticsearch/cluster/ClusterModule.java
  6. 13 9
      server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java
  7. 9 1
      server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java
  8. 3 2
      server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
  9. 13 6
      server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java
  10. 2 0
      server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java
  11. 14 2
      server/src/main/java/org/elasticsearch/gateway/BaseGatewayShardAllocator.java
  12. 11 0
      server/src/main/java/org/elasticsearch/gateway/PrimaryShardAllocator.java
  13. 12 3
      server/src/main/java/org/elasticsearch/node/Node.java
  14. 31 0
      server/src/main/java/org/elasticsearch/snapshots/EmptySnapshotsInfoService.java
  15. 386 0
      server/src/main/java/org/elasticsearch/snapshots/InternalSnapshotsInfoService.java
  16. 53 0
      server/src/main/java/org/elasticsearch/snapshots/SnapshotShardSizeInfo.java
  17. 25 0
      server/src/main/java/org/elasticsearch/snapshots/SnapshotsInfoService.java
  18. 3 3
      server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java
  19. 3 1
      server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java
  20. 9 4
      server/src/test/java/org/elasticsearch/action/admin/indices/shrink/TransportResizeActionTests.java
  21. 6 6
      server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java
  22. 1 1
      server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java
  23. 7 3
      server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java
  24. 3 2
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationCommandsTests.java
  25. 4 3
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationServiceTests.java
  26. 2 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalanceConfigurationTests.java
  27. 2 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalancedSingleShardTests.java
  28. 3 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/DecisionsImpactOnClusterHealthTests.java
  29. 7 5
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java
  30. 25 10
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/NodeVersionAllocationDeciderTests.java
  31. 3 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/RandomAllocationDeciderTests.java
  32. 7 5
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/ResizeAllocationDeciderTests.java
  33. 2 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/SameShardRoutingTests.java
  34. 45 12
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/ThrottlingAllocationTests.java
  35. 1 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDecidersTests.java
  36. 21 18
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java
  37. 7 7
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java
  38. 3 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/EnableAllocationShortCircuitTests.java
  39. 5 3
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/FilterAllocationDeciderTests.java
  40. 1 1
      server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java
  41. 3 1
      server/src/test/java/org/elasticsearch/gateway/GatewayServiceTests.java
  42. 35 9
      server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java
  43. 5 2
      server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java
  44. 2 1
      server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java
  45. 350 0
      server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java
  46. 13 8
      server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
  47. 35 6
      test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java
  48. 2 1
      test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java
  49. 21 2
      x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java
  50. 2 1
      x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/allocation/CcrPrimaryFollowerAllocationDeciderTests.java
  51. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/AllocationRoutedStep.java
  52. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SetSingleNodeAllocateStep.java
  53. 15 17
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java
  54. 6 0
      x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocator.java

+ 3 - 1
benchmarks/src/main/java/org/elasticsearch/benchmark/routing/allocation/Allocators.java

@@ -35,6 +35,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.gateway.GatewayAllocator;
 import org.elasticsearch.gateway.GatewayAllocator;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 
 
 import java.util.Collection;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Collections;
@@ -79,7 +80,8 @@ public final class Allocators {
             defaultAllocationDeciders(settings, clusterSettings),
             defaultAllocationDeciders(settings, clusterSettings),
             NoopGatewayAllocator.INSTANCE,
             NoopGatewayAllocator.INSTANCE,
             new BalancedShardsAllocator(settings),
             new BalancedShardsAllocator(settings),
-            EmptyClusterInfoService.INSTANCE
+            EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE
         );
         );
     }
     }
 
 

+ 83 - 9
server/src/internalClusterTest/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderIT.java

@@ -23,6 +23,8 @@ import org.apache.lucene.mockfile.FilterFileStore;
 import org.apache.lucene.mockfile.FilterFileSystemProvider;
 import org.apache.lucene.mockfile.FilterFileSystemProvider;
 import org.apache.lucene.mockfile.FilterPath;
 import org.apache.lucene.mockfile.FilterPath;
 import org.apache.lucene.util.Constants;
 import org.apache.lucene.util.Constants;
+import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
+import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse;
 import org.elasticsearch.action.admin.indices.stats.ShardStats;
 import org.elasticsearch.action.admin.indices.stats.ShardStats;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.cluster.ClusterInfoService;
 import org.elasticsearch.cluster.ClusterInfoService;
@@ -32,6 +34,7 @@ import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.allocation.DiskThresholdSettings;
 import org.elasticsearch.cluster.routing.allocation.DiskThresholdSettings;
+import org.elasticsearch.cluster.routing.allocation.decider.EnableAllocationDecider.Rebalance;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.io.PathUtils;
 import org.elasticsearch.common.io.PathUtils;
@@ -44,6 +47,10 @@ import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.monitor.fs.FsService;
 import org.elasticsearch.monitor.fs.FsService;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.repositories.fs.FsRepository;
+import org.elasticsearch.snapshots.RestoreInfo;
+import org.elasticsearch.snapshots.SnapshotInfo;
+import org.elasticsearch.snapshots.SnapshotState;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.InternalSettingsPlugin;
 import org.elasticsearch.test.InternalSettingsPlugin;
 import org.junit.After;
 import org.junit.After;
@@ -62,6 +69,7 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.List;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Collectors;
@@ -141,29 +149,95 @@ public class DiskThresholdDeciderIT extends ESIntegTestCase {
         final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
         final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
         final Path dataNode0Path = internalCluster().getInstance(Environment.class, dataNodeName).dataFiles()[0];
         final Path dataNode0Path = internalCluster().getInstance(Environment.class, dataNodeName).dataFiles()[0];
 
 
-        createIndex("test", Settings.builder()
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        createIndex(indexName, Settings.builder()
                 .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
                 .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
                 .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 6)
                 .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 6)
                 .put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms")
                 .put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms")
                 .build());
                 .build());
-        final long minShardSize = createReasonableSizedShards();
+        final long minShardSize = createReasonableSizedShards(indexName);
 
 
         // reduce disk size of node 0 so that no shards fit below the high watermark, forcing all shards onto the other data node
         // reduce disk size of node 0 so that no shards fit below the high watermark, forcing all shards onto the other data node
         // (subtract the translog size since the disk threshold decider ignores this and may therefore move the shard back again)
         // (subtract the translog size since the disk threshold decider ignores this and may therefore move the shard back again)
         fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES - 1L);
         fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES - 1L);
         refreshDiskUsage();
         refreshDiskUsage();
-        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id), empty()));
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), empty()));
 
 
         // increase disk size of node 0 to allow just enough room for one shard, and check that it's rebalanced back
         // increase disk size of node 0 to allow just enough room for one shard, and check that it's rebalanced back
         fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES + 1L);
         fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES + 1L);
         refreshDiskUsage();
         refreshDiskUsage();
-        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id), hasSize(1)));
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), hasSize(1)));
     }
     }
 
 
-    private Set<ShardRouting> getShardRoutings(String nodeId) {
+    public void testRestoreSnapshotAllocationDoesNotExceedWatermark() throws Exception {
+        internalCluster().startMasterOnlyNode();
+        internalCluster().startDataOnlyNode();
+        final String dataNodeName = internalCluster().startDataOnlyNode();
+        ensureStableCluster(3);
+
+        assertAcked(client().admin().cluster().preparePutRepository("repo")
+            .setType(FsRepository.TYPE)
+            .setSettings(Settings.builder()
+                .put("location", randomRepoPath())
+                .put("compress", randomBoolean())));
+
+        final InternalClusterInfoService clusterInfoService
+            = (InternalClusterInfoService) internalCluster().getMasterNodeInstance(ClusterInfoService.class);
+        internalCluster().getMasterNodeInstance(ClusterService.class).addListener(event -> clusterInfoService.refresh());
+
+        final String dataNode0Id = internalCluster().getInstance(NodeEnvironment.class, dataNodeName).nodeId();
+        final Path dataNode0Path = internalCluster().getInstance(Environment.class, dataNodeName).dataFiles()[0];
+
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        createIndex(indexName, Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 6)
+            .put(INDEX_STORE_STATS_REFRESH_INTERVAL_SETTING.getKey(), "0ms")
+            .build());
+        final long minShardSize = createReasonableSizedShards(indexName);
+
+        final CreateSnapshotResponse createSnapshotResponse = client().admin().cluster().prepareCreateSnapshot("repo", "snap")
+            .setWaitForCompletion(true).get();
+        final SnapshotInfo snapshotInfo = createSnapshotResponse.getSnapshotInfo();
+        assertThat(snapshotInfo.successfulShards(), is(snapshotInfo.totalShards()));
+        assertThat(snapshotInfo.state(), is(SnapshotState.SUCCESS));
+
+        assertAcked(client().admin().indices().prepareDelete(indexName).get());
+
+        // reduce disk size of node 0 so that no shards fit below the low watermark, forcing shards to be assigned to the other data node
+        fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES - 1L);
+        refreshDiskUsage();
+
+        assertAcked(client().admin().cluster().prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .put(EnableAllocationDecider.CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey(), Rebalance.NONE.toString())
+                .build())
+            .get());
+
+        final RestoreSnapshotResponse restoreSnapshotResponse = client().admin().cluster().prepareRestoreSnapshot("repo", "snap")
+            .setWaitForCompletion(true).get();
+        final RestoreInfo restoreInfo = restoreSnapshotResponse.getRestoreInfo();
+        assertThat(restoreInfo.successfulShards(), is(snapshotInfo.totalShards()));
+        assertThat(restoreInfo.failedShards(), is(0));
+
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), empty()));
+
+        assertAcked(client().admin().cluster().prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .putNull(EnableAllocationDecider.CLUSTER_ROUTING_REBALANCE_ENABLE_SETTING.getKey())
+                .build())
+            .get());
+
+        // increase disk size of node 0 to allow just enough room for one shard, and check that it's rebalanced back
+        fileSystemProvider.getTestFileStore(dataNode0Path).setTotalSpace(minShardSize + WATERMARK_BYTES + 1L);
+        refreshDiskUsage();
+        assertBusy(() -> assertThat(getShardRoutings(dataNode0Id, indexName), hasSize(1)));
+    }
+
+    private Set<ShardRouting> getShardRoutings(final String nodeId, final String indexName) {
         final Set<ShardRouting> shardRoutings = new HashSet<>();
         final Set<ShardRouting> shardRoutings = new HashSet<>();
         for (IndexShardRoutingTable indexShardRoutingTable : client().admin().cluster().prepareState().clear().setRoutingTable(true)
         for (IndexShardRoutingTable indexShardRoutingTable : client().admin().cluster().prepareState().clear().setRoutingTable(true)
-                .get().getState().getRoutingTable().index("test")) {
+                .get().getState().getRoutingTable().index(indexName)) {
             for (ShardRouting shard : indexShardRoutingTable.shards()) {
             for (ShardRouting shard : indexShardRoutingTable.shards()) {
                 assertThat(shard.state(), equalTo(ShardRoutingState.STARTED));
                 assertThat(shard.state(), equalTo(ShardRoutingState.STARTED));
                 if (shard.currentNodeId().equals(nodeId)) {
                 if (shard.currentNodeId().equals(nodeId)) {
@@ -177,17 +251,17 @@ public class DiskThresholdDeciderIT extends ESIntegTestCase {
     /**
     /**
      * Index documents until all the shards are at least WATERMARK_BYTES in size, and return the size of the smallest shard
      * Index documents until all the shards are at least WATERMARK_BYTES in size, and return the size of the smallest shard
      */
      */
-    private long createReasonableSizedShards() throws InterruptedException {
+    private long createReasonableSizedShards(final String indexName) throws InterruptedException {
         while (true) {
         while (true) {
             final IndexRequestBuilder[] indexRequestBuilders = new IndexRequestBuilder[scaledRandomIntBetween(100, 10000)];
             final IndexRequestBuilder[] indexRequestBuilders = new IndexRequestBuilder[scaledRandomIntBetween(100, 10000)];
             for (int i = 0; i < indexRequestBuilders.length; i++) {
             for (int i = 0; i < indexRequestBuilders.length; i++) {
-                indexRequestBuilders[i] = client().prepareIndex("test").setSource("field", randomAlphaOfLength(10));
+                indexRequestBuilders[i] = client().prepareIndex(indexName).setSource("field", randomAlphaOfLength(10));
             }
             }
             indexRandom(true, indexRequestBuilders);
             indexRandom(true, indexRequestBuilders);
             forceMerge();
             forceMerge();
             refresh();
             refresh();
 
 
-            final ShardStats[] shardStatses = client().admin().indices().prepareStats("test")
+            final ShardStats[] shardStatses = client().admin().indices().prepareStats(indexName)
                     .clear().setStore(true).setTranslog(true).get().getShards();
                     .clear().setStore(true).setTranslog(true).get().getShards();
             final long[] shardSizes = new long[shardStatses.length];
             final long[] shardSizes = new long[shardStatses.length];
             for (ShardStats shardStats : shardStatses) {
             for (ShardStats shardStats : shardStatses) {

+ 6 - 2
server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java

@@ -42,6 +42,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.transport.TransportService;
@@ -59,6 +60,7 @@ public class TransportClusterAllocationExplainAction
     private static final Logger logger = LogManager.getLogger(TransportClusterAllocationExplainAction.class);
     private static final Logger logger = LogManager.getLogger(TransportClusterAllocationExplainAction.class);
 
 
     private final ClusterInfoService clusterInfoService;
     private final ClusterInfoService clusterInfoService;
+    private final SnapshotsInfoService snapshotsInfoService;
     private final AllocationDeciders allocationDeciders;
     private final AllocationDeciders allocationDeciders;
     private final ShardsAllocator shardAllocator;
     private final ShardsAllocator shardAllocator;
     private final AllocationService allocationService;
     private final AllocationService allocationService;
@@ -67,11 +69,13 @@ public class TransportClusterAllocationExplainAction
     public TransportClusterAllocationExplainAction(TransportService transportService, ClusterService clusterService,
     public TransportClusterAllocationExplainAction(TransportService transportService, ClusterService clusterService,
                                                    ThreadPool threadPool, ActionFilters actionFilters,
                                                    ThreadPool threadPool, ActionFilters actionFilters,
                                                    IndexNameExpressionResolver indexNameExpressionResolver,
                                                    IndexNameExpressionResolver indexNameExpressionResolver,
-                                                   ClusterInfoService clusterInfoService, AllocationDeciders allocationDeciders,
+                                                   ClusterInfoService clusterInfoService, SnapshotsInfoService snapshotsInfoService,
+                                                   AllocationDeciders allocationDeciders,
                                                    ShardsAllocator shardAllocator, AllocationService allocationService) {
                                                    ShardsAllocator shardAllocator, AllocationService allocationService) {
         super(ClusterAllocationExplainAction.NAME, transportService, clusterService, threadPool, actionFilters,
         super(ClusterAllocationExplainAction.NAME, transportService, clusterService, threadPool, actionFilters,
             ClusterAllocationExplainRequest::new, indexNameExpressionResolver);
             ClusterAllocationExplainRequest::new, indexNameExpressionResolver);
         this.clusterInfoService = clusterInfoService;
         this.clusterInfoService = clusterInfoService;
+        this.snapshotsInfoService = snapshotsInfoService;
         this.allocationDeciders = allocationDeciders;
         this.allocationDeciders = allocationDeciders;
         this.shardAllocator = shardAllocator;
         this.shardAllocator = shardAllocator;
         this.allocationService = allocationService;
         this.allocationService = allocationService;
@@ -98,7 +102,7 @@ public class TransportClusterAllocationExplainAction
         final RoutingNodes routingNodes = state.getRoutingNodes();
         final RoutingNodes routingNodes = state.getRoutingNodes();
         final ClusterInfo clusterInfo = clusterInfoService.getClusterInfo();
         final ClusterInfo clusterInfo = clusterInfoService.getClusterInfo();
         final RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, state,
         final RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, state,
-                clusterInfo, System.nanoTime());
+                clusterInfo, snapshotsInfoService.snapshotShardSizes(), System.nanoTime());
 
 
         ShardRouting shardRouting = findShardToExplain(request, allocation);
         ShardRouting shardRouting = findShardToExplain(request, allocation);
         logger.debug("explaining the allocation for [{}], found shard [{}]", request, shardRouting);
         logger.debug("explaining the allocation for [{}], found shard [{}]", request, shardRouting);

+ 0 - 2
server/src/main/java/org/elasticsearch/cluster/ClusterInfo.java

@@ -22,7 +22,6 @@ package org.elasticsearch.cluster;
 import com.carrotsearch.hppc.ObjectHashSet;
 import com.carrotsearch.hppc.ObjectHashSet;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
 import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
 import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
-
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -33,7 +32,6 @@ import org.elasticsearch.common.xcontent.ToXContentFragment;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.store.StoreStats;
 import org.elasticsearch.index.store.StoreStats;
-
 import java.io.IOException;
 import java.io.IOException;
 import java.util.Map;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Objects;

+ 3 - 2
server/src/main/java/org/elasticsearch/cluster/ClusterModule.java

@@ -75,6 +75,7 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.persistent.PersistentTasksNodeService;
 import org.elasticsearch.persistent.PersistentTasksNodeService;
 import org.elasticsearch.plugins.ClusterPlugin;
 import org.elasticsearch.plugins.ClusterPlugin;
 import org.elasticsearch.script.ScriptMetadata;
 import org.elasticsearch.script.ScriptMetadata;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskResultsService;
 import org.elasticsearch.tasks.TaskResultsService;
 
 
@@ -107,14 +108,14 @@ public class ClusterModule extends AbstractModule {
     final ShardsAllocator shardsAllocator;
     final ShardsAllocator shardsAllocator;
 
 
     public ClusterModule(Settings settings, ClusterService clusterService, List<ClusterPlugin> clusterPlugins,
     public ClusterModule(Settings settings, ClusterService clusterService, List<ClusterPlugin> clusterPlugins,
-                         ClusterInfoService clusterInfoService) {
+                         ClusterInfoService clusterInfoService, SnapshotsInfoService snapshotsInfoService) {
         this.clusterPlugins = clusterPlugins;
         this.clusterPlugins = clusterPlugins;
         this.deciderList = createAllocationDeciders(settings, clusterService.getClusterSettings(), clusterPlugins);
         this.deciderList = createAllocationDeciders(settings, clusterService.getClusterSettings(), clusterPlugins);
         this.allocationDeciders = new AllocationDeciders(deciderList);
         this.allocationDeciders = new AllocationDeciders(deciderList);
         this.shardsAllocator = createShardsAllocator(settings, clusterService.getClusterSettings(), clusterPlugins);
         this.shardsAllocator = createShardsAllocator(settings, clusterService.getClusterSettings(), clusterPlugins);
         this.clusterService = clusterService;
         this.clusterService = clusterService;
         this.indexNameExpressionResolver = new IndexNameExpressionResolver();
         this.indexNameExpressionResolver = new IndexNameExpressionResolver();
-        this.allocationService = new AllocationService(allocationDeciders, shardsAllocator, clusterInfoService);
+        this.allocationService = new AllocationService(allocationDeciders, shardsAllocator, clusterInfoService, snapshotsInfoService);
     }
     }
 
 
     public static List<Entry> getNamedWriteables() {
     public static List<Entry> getNamedWriteables() {

+ 13 - 9
server/src/main/java/org/elasticsearch/cluster/routing/allocation/AllocationService.java

@@ -45,6 +45,7 @@ import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.logging.ESLogMessage;
 import org.elasticsearch.common.logging.ESLogMessage;
 import org.elasticsearch.gateway.GatewayAllocator;
 import org.elasticsearch.gateway.GatewayAllocator;
 import org.elasticsearch.gateway.PriorityComparator;
 import org.elasticsearch.gateway.PriorityComparator;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Collections;
@@ -75,19 +76,22 @@ public class AllocationService {
     private Map<String, ExistingShardsAllocator> existingShardsAllocators;
     private Map<String, ExistingShardsAllocator> existingShardsAllocators;
     private final ShardsAllocator shardsAllocator;
     private final ShardsAllocator shardsAllocator;
     private final ClusterInfoService clusterInfoService;
     private final ClusterInfoService clusterInfoService;
+    private SnapshotsInfoService snapshotsInfoService;
 
 
     // only for tests that use the GatewayAllocator as the unique ExistingShardsAllocator
     // only for tests that use the GatewayAllocator as the unique ExistingShardsAllocator
     public AllocationService(AllocationDeciders allocationDeciders, GatewayAllocator gatewayAllocator,
     public AllocationService(AllocationDeciders allocationDeciders, GatewayAllocator gatewayAllocator,
-                             ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService) {
-        this(allocationDeciders, shardsAllocator, clusterInfoService);
+                             ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService,
+                             SnapshotsInfoService snapshotsInfoService) {
+        this(allocationDeciders, shardsAllocator, clusterInfoService, snapshotsInfoService);
         setExistingShardsAllocators(Collections.singletonMap(GatewayAllocator.ALLOCATOR_NAME, gatewayAllocator));
         setExistingShardsAllocators(Collections.singletonMap(GatewayAllocator.ALLOCATOR_NAME, gatewayAllocator));
     }
     }
 
 
     public AllocationService(AllocationDeciders allocationDeciders, ShardsAllocator shardsAllocator,
     public AllocationService(AllocationDeciders allocationDeciders, ShardsAllocator shardsAllocator,
-                             ClusterInfoService clusterInfoService) {
+                             ClusterInfoService clusterInfoService, SnapshotsInfoService snapshotsInfoService) {
         this.allocationDeciders = allocationDeciders;
         this.allocationDeciders = allocationDeciders;
         this.shardsAllocator = shardsAllocator;
         this.shardsAllocator = shardsAllocator;
         this.clusterInfoService = clusterInfoService;
         this.clusterInfoService = clusterInfoService;
+        this.snapshotsInfoService = snapshotsInfoService;
     }
     }
 
 
     /**
     /**
@@ -114,7 +118,7 @@ public class AllocationService {
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         routingNodes.unassigned().shuffle();
         routingNodes.unassigned().shuffle();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         // as starting a primary relocation target can reinitialize replica shards, start replicas first
         // as starting a primary relocation target can reinitialize replica shards, start replicas first
         startedShards = new ArrayList<>(startedShards);
         startedShards = new ArrayList<>(startedShards);
         startedShards.sort(Comparator.comparing(ShardRouting::primary));
         startedShards.sort(Comparator.comparing(ShardRouting::primary));
@@ -193,7 +197,7 @@ public class AllocationService {
         routingNodes.unassigned().shuffle();
         routingNodes.unassigned().shuffle();
         long currentNanoTime = currentNanoTime();
         long currentNanoTime = currentNanoTime();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, tmpState,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, tmpState,
-            clusterInfoService.getClusterInfo(), currentNanoTime);
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime);
 
 
         for (FailedShard failedShardEntry : failedShards) {
         for (FailedShard failedShardEntry : failedShards) {
             ShardRouting shardToFail = failedShardEntry.getRoutingEntry();
             ShardRouting shardToFail = failedShardEntry.getRoutingEntry();
@@ -247,7 +251,7 @@ public class AllocationService {
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         routingNodes.unassigned().shuffle();
         routingNodes.unassigned().shuffle();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
 
 
         // first, clear from the shards any node id they used to belong to that is now dead
         // first, clear from the shards any node id they used to belong to that is now dead
         disassociateDeadNodes(allocation);
         disassociateDeadNodes(allocation);
@@ -268,7 +272,7 @@ public class AllocationService {
      */
      */
     public ClusterState adaptAutoExpandReplicas(ClusterState clusterState) {
     public ClusterState adaptAutoExpandReplicas(ClusterState clusterState) {
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, clusterState.getRoutingNodes(), clusterState,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, clusterState.getRoutingNodes(), clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         final Map<Integer, List<String>> autoExpandReplicaChanges =
         final Map<Integer, List<String>> autoExpandReplicaChanges =
             AutoExpandReplicas.getAutoExpandReplicaChanges(clusterState.metadata(), allocation);
             AutoExpandReplicas.getAutoExpandReplicaChanges(clusterState.metadata(), allocation);
         if (autoExpandReplicaChanges.isEmpty()) {
         if (autoExpandReplicaChanges.isEmpty()) {
@@ -362,7 +366,7 @@ public class AllocationService {
         // a consistent result of the effect the commands have on the routing
         // a consistent result of the effect the commands have on the routing
         // this allows systems to dry run the commands, see the resulting cluster state, and act on it
         // this allows systems to dry run the commands, see the resulting cluster state, and act on it
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, clusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         // don't short circuit deciders, we want a full explanation
         // don't short circuit deciders, we want a full explanation
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         // we ignore disable allocation, because commands are explicit
         // we ignore disable allocation, because commands are explicit
@@ -393,7 +397,7 @@ public class AllocationService {
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         // shuffle the unassigned nodes, just so we won't have things like poison failed shards
         routingNodes.unassigned().shuffle();
         routingNodes.unassigned().shuffle();
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, fixedClusterState,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, routingNodes, fixedClusterState,
-            clusterInfoService.getClusterInfo(), currentNanoTime());
+            clusterInfoService.getClusterInfo(), snapshotsInfoService.snapshotShardSizes(), currentNanoTime());
         reroute(allocation);
         reroute(allocation);
         if (fixedClusterState == clusterState && allocation.routingNodesChanged() == false) {
         if (fixedClusterState == clusterState && allocation.routingNodesChanged() == false) {
             return clusterState;
             return clusterState;

+ 9 - 1
server/src/main/java/org/elasticsearch/cluster/routing/allocation/RoutingAllocation.java

@@ -33,6 +33,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.snapshots.RestoreService.RestoreInProgressUpdater;
 import org.elasticsearch.snapshots.RestoreService.RestoreInProgressUpdater;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
@@ -62,6 +63,8 @@ public class RoutingAllocation {
 
 
     private final ClusterInfo clusterInfo;
     private final ClusterInfo clusterInfo;
 
 
+    private final SnapshotShardSizeInfo shardSizeInfo;
+
     private Map<ShardId, Set<String>> ignoredShardToNodes = null;
     private Map<ShardId, Set<String>> ignoredShardToNodes = null;
 
 
     private boolean ignoreDisable = false;
     private boolean ignoreDisable = false;
@@ -88,7 +91,7 @@ public class RoutingAllocation {
      * @param currentNanoTime the nano time to use for all delay allocation calculation (typically {@link System#nanoTime()})
      * @param currentNanoTime the nano time to use for all delay allocation calculation (typically {@link System#nanoTime()})
      */
      */
     public RoutingAllocation(AllocationDeciders deciders, RoutingNodes routingNodes, ClusterState clusterState, ClusterInfo clusterInfo,
     public RoutingAllocation(AllocationDeciders deciders, RoutingNodes routingNodes, ClusterState clusterState, ClusterInfo clusterInfo,
-                             long currentNanoTime) {
+                             SnapshotShardSizeInfo shardSizeInfo, long currentNanoTime) {
         this.deciders = deciders;
         this.deciders = deciders;
         this.routingNodes = routingNodes;
         this.routingNodes = routingNodes;
         this.metadata = clusterState.metadata();
         this.metadata = clusterState.metadata();
@@ -96,6 +99,7 @@ public class RoutingAllocation {
         this.nodes = clusterState.nodes();
         this.nodes = clusterState.nodes();
         this.customs = clusterState.customs();
         this.customs = clusterState.customs();
         this.clusterInfo = clusterInfo;
         this.clusterInfo = clusterInfo;
+        this.shardSizeInfo = shardSizeInfo;
         this.currentNanoTime = currentNanoTime;
         this.currentNanoTime = currentNanoTime;
     }
     }
 
 
@@ -148,6 +152,10 @@ public class RoutingAllocation {
         return clusterInfo;
         return clusterInfo;
     }
     }
 
 
+    public SnapshotShardSizeInfo snapshotShardSizeInfo() {
+        return shardSizeInfo;
+    }
+
     public <T extends ClusterState.Custom> T custom(String key) {
     public <T extends ClusterState.Custom> T custom(String key) {
         return (T)customs.get(key);
         return (T)customs.get(key);
     }
     }

+ 3 - 2
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java

@@ -810,7 +810,7 @@ public class BalancedShardsAllocator implements ShardsAllocator {
 
 
                         final long shardSize = DiskThresholdDecider.getExpectedShardSize(shard,
                         final long shardSize = DiskThresholdDecider.getExpectedShardSize(shard,
                             ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
                             ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
-                            allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+                            allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(), allocation.routingTable());
                         shard = routingNodes.initializeShard(shard, minNode.getNodeId(), null, shardSize, allocation.changes());
                         shard = routingNodes.initializeShard(shard, minNode.getNodeId(), null, shardSize, allocation.changes());
                         minNode.addShard(shard);
                         minNode.addShard(shard);
                         if (!shard.primary()) {
                         if (!shard.primary()) {
@@ -832,7 +832,8 @@ public class BalancedShardsAllocator implements ShardsAllocator {
                             assert allocationDecision.getAllocationStatus() == AllocationStatus.DECIDERS_THROTTLED;
                             assert allocationDecision.getAllocationStatus() == AllocationStatus.DECIDERS_THROTTLED;
                             final long shardSize = DiskThresholdDecider.getExpectedShardSize(shard,
                             final long shardSize = DiskThresholdDecider.getExpectedShardSize(shard,
                                 ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
                                 ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE,
-                                allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+                                allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(),
+                                allocation.routingTable());
                             minNode.addShard(shard.initialize(minNode.getNodeId(), null, shardSize));
                             minNode.addShard(shard.initialize(minNode.getNodeId(), null, shardSize));
                         } else {
                         } else {
                             if (logger.isTraceEnabled()) {
                             if (logger.isTraceEnabled()) {

+ 13 - 6
server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java

@@ -43,6 +43,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 
 import java.util.List;
 import java.util.List;
 import java.util.Set;
 import java.util.Set;
@@ -120,7 +121,7 @@ public class DiskThresholdDecider extends AllocationDecider {
             // if we don't yet know the actual path of the incoming shard then conservatively assume it's going to the path with the least
             // if we don't yet know the actual path of the incoming shard then conservatively assume it's going to the path with the least
             // free space
             // free space
             if (actualPath == null || actualPath.equals(dataPath)) {
             if (actualPath == null || actualPath.equals(dataPath)) {
-                totalSize += getExpectedShardSize(routing, 0L, clusterInfo, metadata, routingTable);
+                totalSize += getExpectedShardSize(routing, 0L, clusterInfo, null, metadata, routingTable);
             }
             }
         }
         }
 
 
@@ -132,7 +133,7 @@ public class DiskThresholdDecider extends AllocationDecider {
                     actualPath = clusterInfo.getDataPath(routing.cancelRelocation());
                     actualPath = clusterInfo.getDataPath(routing.cancelRelocation());
                 }
                 }
                 if (dataPath.equals(actualPath)) {
                 if (dataPath.equals(actualPath)) {
-                    totalSize -= getExpectedShardSize(routing, 0L, clusterInfo, metadata, routingTable);
+                    totalSize -= getExpectedShardSize(routing, 0L, clusterInfo, null, metadata, routingTable);
                 }
                 }
             }
             }
         }
         }
@@ -153,7 +154,7 @@ public class DiskThresholdDecider extends AllocationDecider {
         final double usedDiskThresholdLow = 100.0 - diskThresholdSettings.getFreeDiskThresholdLow();
         final double usedDiskThresholdLow = 100.0 - diskThresholdSettings.getFreeDiskThresholdLow();
         final double usedDiskThresholdHigh = 100.0 - diskThresholdSettings.getFreeDiskThresholdHigh();
         final double usedDiskThresholdHigh = 100.0 - diskThresholdSettings.getFreeDiskThresholdHigh();
 
 
-        // subtractLeavingShards is passed as false here, because they still use disk space, and therefore should we should be extra careful
+        // subtractLeavingShards is passed as false here, because they still use disk space, and therefore we should be extra careful
         // and take the size into account
         // and take the size into account
         final DiskUsageWithRelocations usage = getDiskUsage(node, allocation, usages, false);
         final DiskUsageWithRelocations usage = getDiskUsage(node, allocation, usages, false);
         // First, check that the node is currently over the low watermark
         // First, check that the node is currently over the low watermark
@@ -270,7 +271,7 @@ public class DiskThresholdDecider extends AllocationDecider {
 
 
         // Secondly, check that allocating the shard to this node doesn't put it above the high watermark
         // Secondly, check that allocating the shard to this node doesn't put it above the high watermark
         final long shardSize = getExpectedShardSize(shardRouting, 0L,
         final long shardSize = getExpectedShardSize(shardRouting, 0L,
-            allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+            allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(), allocation.routingTable());
         assert shardSize >= 0 : shardSize;
         assert shardSize >= 0 : shardSize;
         double freeSpaceAfterShard = freeDiskPercentageAfterShardAssigned(usage, shardSize);
         double freeSpaceAfterShard = freeDiskPercentageAfterShardAssigned(usage, shardSize);
         long freeBytesAfterShard = freeBytes - shardSize;
         long freeBytesAfterShard = freeBytes - shardSize;
@@ -461,8 +462,9 @@ public class DiskThresholdDecider extends AllocationDecider {
      * Returns the expected shard size for the given shard, or the provided default value if not enough information is available
      * Returns the expected shard size for the given shard, or the provided default value if not enough information is available
      * to estimate the shard's size.
      * to estimate the shard's size.
      */
      */
-    public static long getExpectedShardSize(ShardRouting shard, long defaultValue, ClusterInfo clusterInfo, Metadata metadata,
-                                            RoutingTable routingTable) {
+    public static long getExpectedShardSize(ShardRouting shard, long defaultValue, ClusterInfo clusterInfo,
+                                            SnapshotShardSizeInfo snapshotShardSizeInfo,
+                                            Metadata metadata, RoutingTable routingTable) {
         final IndexMetadata indexMetadata = metadata.getIndexSafe(shard.index());
         final IndexMetadata indexMetadata = metadata.getIndexSafe(shard.index());
         if (indexMetadata.getResizeSourceIndex() != null && shard.active() == false &&
         if (indexMetadata.getResizeSourceIndex() != null && shard.active() == false &&
             shard.recoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS) {
             shard.recoverySource().getType() == RecoverySource.Type.LOCAL_SHARDS) {
@@ -482,6 +484,11 @@ public class DiskThresholdDecider extends AllocationDecider {
             }
             }
             return targetShardSize == 0 ? defaultValue : targetShardSize;
             return targetShardSize == 0 ? defaultValue : targetShardSize;
         } else {
         } else {
+            if (shard.unassigned() && shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+                final Long shardSize = snapshotShardSizeInfo.getShardSize(shard);
+                assert shardSize != null : "no shard size provided for " + shard;
+                return shardSize;
+            }
             return clusterInfo.getShardSize(shard, defaultValue);
             return clusterInfo.getShardSize(shard, defaultValue);
         }
         }
     }
     }

+ 2 - 0
server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java

@@ -107,6 +107,7 @@ import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
 import org.elasticsearch.search.aggregations.MultiBucketConsumerService;
 import org.elasticsearch.search.fetch.subphase.highlight.FastVectorHighlighter;
 import org.elasticsearch.search.fetch.subphase.highlight.FastVectorHighlighter;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.SnapshotsService;
 import org.elasticsearch.snapshots.SnapshotsService;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.ProxyConnectionStrategy;
 import org.elasticsearch.transport.ProxyConnectionStrategy;
@@ -235,6 +236,7 @@ public final class ClusterSettings extends AbstractScopedSettings {
             SameShardAllocationDecider.CLUSTER_ROUTING_ALLOCATION_SAME_HOST_SETTING,
             SameShardAllocationDecider.CLUSTER_ROUTING_ALLOCATION_SAME_HOST_SETTING,
             InternalClusterInfoService.INTERNAL_CLUSTER_INFO_UPDATE_INTERVAL_SETTING,
             InternalClusterInfoService.INTERNAL_CLUSTER_INFO_UPDATE_INTERVAL_SETTING,
             InternalClusterInfoService.INTERNAL_CLUSTER_INFO_TIMEOUT_SETTING,
             InternalClusterInfoService.INTERNAL_CLUSTER_INFO_TIMEOUT_SETTING,
+            InternalSnapshotsInfoService.INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING,
             DestructiveOperations.REQUIRES_NAME_SETTING,
             DestructiveOperations.REQUIRES_NAME_SETTING,
             NoMasterBlockService.NO_MASTER_BLOCK_SETTING,
             NoMasterBlockService.NO_MASTER_BLOCK_SETTING,
             GatewayService.EXPECTED_DATA_NODES_SETTING,
             GatewayService.EXPECTED_DATA_NODES_SETTING,

+ 14 - 2
server/src/main/java/org/elasticsearch/gateway/BaseGatewayShardAllocator.java

@@ -21,6 +21,7 @@ package org.elasticsearch.gateway;
 
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
@@ -64,14 +65,25 @@ public abstract class BaseGatewayShardAllocator {
         if (allocateUnassignedDecision.getAllocationDecision() == AllocationDecision.YES) {
         if (allocateUnassignedDecision.getAllocationDecision() == AllocationDecision.YES) {
             unassignedAllocationHandler.initialize(allocateUnassignedDecision.getTargetNode().getId(),
             unassignedAllocationHandler.initialize(allocateUnassignedDecision.getTargetNode().getId(),
                 allocateUnassignedDecision.getAllocationId(),
                 allocateUnassignedDecision.getAllocationId(),
-                shardRouting.primary() ? ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE :
-                                         allocation.clusterInfo().getShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE),
+                getExpectedShardSize(shardRouting, allocation),
                 allocation.changes());
                 allocation.changes());
         } else {
         } else {
             unassignedAllocationHandler.removeAndIgnore(allocateUnassignedDecision.getAllocationStatus(), allocation.changes());
             unassignedAllocationHandler.removeAndIgnore(allocateUnassignedDecision.getAllocationStatus(), allocation.changes());
         }
         }
     }
     }
 
 
+    protected long getExpectedShardSize(ShardRouting shardRouting, RoutingAllocation allocation) {
+        if (shardRouting.primary()) {
+            if (shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+                return allocation.snapshotShardSizeInfo().getShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE);
+            } else {
+                return ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE;
+            }
+        } else {
+            return allocation.clusterInfo().getShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE);
+        }
+    }
+
     /**
     /**
      * Make a decision on the allocation of an unassigned shard.  This method is used by
      * Make a decision on the allocation of an unassigned shard.  This method is used by
      * {@link #allocateUnassigned(ShardRouting, RoutingAllocation, ExistingShardsAllocator.UnassignedAllocationHandler)} to make decisions
      * {@link #allocateUnassigned(ShardRouting, RoutingAllocation, ExistingShardsAllocator.UnassignedAllocationHandler)} to make decisions

+ 11 - 0
server/src/main/java/org/elasticsearch/gateway/PrimaryShardAllocator.java

@@ -27,6 +27,7 @@ import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.RoutingNodes;
 import org.elasticsearch.cluster.routing.RoutingNodes;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.UnassignedInfo.AllocationStatus;
 import org.elasticsearch.cluster.routing.UnassignedInfo.AllocationStatus;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.NodeAllocationResult;
 import org.elasticsearch.cluster.routing.allocation.NodeAllocationResult;
@@ -83,6 +84,16 @@ public abstract class PrimaryShardAllocator extends BaseGatewayShardAllocator {
         }
         }
 
 
         final boolean explain = allocation.debugDecision();
         final boolean explain = allocation.debugDecision();
+
+        if (unassignedShard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT &&
+            allocation.snapshotShardSizeInfo().getShardSize(unassignedShard) == null) {
+            List<NodeAllocationResult> nodeDecisions = null;
+            if (explain) {
+                nodeDecisions = buildDecisionsForAllNodes(unassignedShard, allocation);
+            }
+            return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, nodeDecisions);
+        }
+
         final FetchResult<NodeGatewayStartedShards> shardState = fetchData(unassignedShard, allocation);
         final FetchResult<NodeGatewayStartedShards> shardState = fetchData(unassignedShard, allocation);
         if (shardState.hasData() == false) {
         if (shardState.hasData() == false) {
             allocation.setHasPendingAsyncFetch();
             allocation.setHasPendingAsyncFetch();

+ 12 - 3
server/src/main/java/org/elasticsearch/node/Node.java

@@ -157,8 +157,10 @@ import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.aggregations.support.AggregationUsageService;
 import org.elasticsearch.search.aggregations.support.AggregationUsageService;
 import org.elasticsearch.search.fetch.FetchPhase;
 import org.elasticsearch.search.fetch.FetchPhase;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.RestoreService;
 import org.elasticsearch.snapshots.RestoreService;
 import org.elasticsearch.snapshots.SnapshotShardsService;
 import org.elasticsearch.snapshots.SnapshotShardsService;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.snapshots.SnapshotsService;
 import org.elasticsearch.snapshots.SnapshotsService;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancellationService;
 import org.elasticsearch.tasks.TaskCancellationService;
@@ -394,6 +396,7 @@ public class Node implements Closeable {
             final IngestService ingestService = new IngestService(clusterService, threadPool, this.environment,
             final IngestService ingestService = new IngestService(clusterService, threadPool, this.environment,
                 scriptService, analysisModule.getAnalysisRegistry(),
                 scriptService, analysisModule.getAnalysisRegistry(),
                 pluginsService.filterPlugins(IngestPlugin.class), client);
                 pluginsService.filterPlugins(IngestPlugin.class), client);
+            final SetOnce<RepositoriesService> repositoriesServiceReference = new SetOnce<>();
             final ClusterInfoService clusterInfoService = newClusterInfoService(settings, clusterService, threadPool, client);
             final ClusterInfoService clusterInfoService = newClusterInfoService(settings, clusterService, threadPool, client);
             final UsageService usageService = new UsageService();
             final UsageService usageService = new UsageService();
 
 
@@ -401,7 +404,11 @@ public class Node implements Closeable {
             final MonitorService monitorService = new MonitorService(settings, nodeEnvironment, threadPool);
             final MonitorService monitorService = new MonitorService(settings, nodeEnvironment, threadPool);
             final FsHealthService fsHealthService = new FsHealthService(settings, clusterService.getClusterSettings(), threadPool,
             final FsHealthService fsHealthService = new FsHealthService(settings, clusterService.getClusterSettings(), threadPool,
                 nodeEnvironment);
                 nodeEnvironment);
-            ClusterModule clusterModule = new ClusterModule(settings, clusterService, clusterPlugins, clusterInfoService);
+            final SetOnce<RerouteService> rerouteServiceReference = new SetOnce<>();
+            final InternalSnapshotsInfoService snapshotsInfoService = new InternalSnapshotsInfoService(settings, clusterService,
+                repositoriesServiceReference::get, rerouteServiceReference::get);
+            final ClusterModule clusterModule = new ClusterModule(settings, clusterService, clusterPlugins, clusterInfoService,
+                snapshotsInfoService);
             modules.add(clusterModule);
             modules.add(clusterModule);
             IndicesModule indicesModule = new IndicesModule(pluginsService.filterPlugins(MapperPlugin.class));
             IndicesModule indicesModule = new IndicesModule(pluginsService.filterPlugins(MapperPlugin.class));
             modules.add(indicesModule);
             modules.add(indicesModule);
@@ -479,6 +486,7 @@ public class Node implements Closeable {
 
 
             final RerouteService rerouteService
             final RerouteService rerouteService
                 = new BatchedRerouteService(clusterService, clusterModule.getAllocationService()::reroute);
                 = new BatchedRerouteService(clusterService, clusterModule.getAllocationService()::reroute);
+            rerouteServiceReference.set(rerouteService);
             clusterService.setRerouteService(rerouteService);
             clusterService.setRerouteService(rerouteService);
 
 
             final IndicesService indicesService =
             final IndicesService indicesService =
@@ -512,7 +520,6 @@ public class Node implements Closeable {
             final MetadataCreateDataStreamService metadataCreateDataStreamService =
             final MetadataCreateDataStreamService metadataCreateDataStreamService =
                 new MetadataCreateDataStreamService(threadPool, clusterService, metadataCreateIndexService);
                 new MetadataCreateDataStreamService(threadPool, clusterService, metadataCreateIndexService);
 
 
-            final SetOnce<RepositoriesService> repositoriesServiceReference = new SetOnce<>();
             Collection<Object> pluginComponents = pluginsService.filterPlugins(Plugin.class).stream()
             Collection<Object> pluginComponents = pluginsService.filterPlugins(Plugin.class).stream()
                 .flatMap(p -> p.createComponents(client, clusterService, threadPool, resourceWatcherService,
                 .flatMap(p -> p.createComponents(client, clusterService, threadPool, resourceWatcherService,
                                                  scriptService, xContentRegistry, environment, nodeEnvironment,
                                                  scriptService, xContentRegistry, environment, nodeEnvironment,
@@ -634,6 +641,7 @@ public class Node implements Closeable {
                     b.bind(UpdateHelper.class).toInstance(new UpdateHelper(scriptService));
                     b.bind(UpdateHelper.class).toInstance(new UpdateHelper(scriptService));
                     b.bind(MetadataIndexUpgradeService.class).toInstance(metadataIndexUpgradeService);
                     b.bind(MetadataIndexUpgradeService.class).toInstance(metadataIndexUpgradeService);
                     b.bind(ClusterInfoService.class).toInstance(clusterInfoService);
                     b.bind(ClusterInfoService.class).toInstance(clusterInfoService);
+                    b.bind(SnapshotsInfoService.class).toInstance(snapshotsInfoService);
                     b.bind(GatewayMetaState.class).toInstance(gatewayMetaState);
                     b.bind(GatewayMetaState.class).toInstance(gatewayMetaState);
                     b.bind(Discovery.class).toInstance(discoveryModule.getDiscovery());
                     b.bind(Discovery.class).toInstance(discoveryModule.getDiscovery());
                     {
                     {
@@ -1129,7 +1137,8 @@ public class Node implements Closeable {
     /** Constructs a ClusterInfoService which may be mocked for tests. */
     /** Constructs a ClusterInfoService which may be mocked for tests. */
     protected ClusterInfoService newClusterInfoService(Settings settings, ClusterService clusterService,
     protected ClusterInfoService newClusterInfoService(Settings settings, ClusterService clusterService,
                                                        ThreadPool threadPool, NodeClient client) {
                                                        ThreadPool threadPool, NodeClient client) {
-        final InternalClusterInfoService service = new InternalClusterInfoService(settings, clusterService, threadPool, client);
+        final InternalClusterInfoService service =
+            new InternalClusterInfoService(settings, clusterService, threadPool, client);
         // listen for state changes (this node starts/stops being the elected master, or new nodes are added)
         // listen for state changes (this node starts/stops being the elected master, or new nodes are added)
         clusterService.addListener(service);
         clusterService.addListener(service);
         return service;
         return service;

+ 31 - 0
server/src/main/java/org/elasticsearch/snapshots/EmptySnapshotsInfoService.java

@@ -0,0 +1,31 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import org.elasticsearch.common.collect.ImmutableOpenMap;
+
+/**
+ * A {@link SnapshotsInfoService} that never knows any snapshot shard size: each call to
+ * {@link #snapshotShardSizes()} returns a {@link SnapshotShardSizeInfo} backed by an empty map.
+ */
+public class EmptySnapshotsInfoService implements SnapshotsInfoService {
+
+    /** Shared instance of this service. */
+    public static final EmptySnapshotsInfoService INSTANCE = new EmptySnapshotsInfoService();
+
+    /**
+     * @return a newly created {@link SnapshotShardSizeInfo} that contains no entries
+     */
+    @Override
+    public SnapshotShardSizeInfo snapshotShardSizes() {
+        final ImmutableOpenMap<InternalSnapshotsInfoService.SnapshotShard, Long> noSizes = ImmutableOpenMap.of();
+        return new SnapshotShardSizeInfo(noSizes);
+    }
+}

+ 386 - 0
server/src/main/java/org/elasticsearch/snapshots/InternalSnapshotsInfoService.java

@@ -0,0 +1,386 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import com.carrotsearch.hppc.cursors.ObjectCursor;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.ClusterChangedEvent;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateListener;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.RerouteService;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Priority;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.repositories.RepositoriesService;
+import org.elasticsearch.repositories.Repository;
+import org.elasticsearch.threadpool.ThreadPool;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.Set;
+import java.util.function.Supplier;
+
+/**
+ * Keeps track of the sizes of the snapshot shards that unassigned primary shards are going to be restored from.
+ * <p>
+ * While the local node is the elected master, every cluster state change is inspected for unassigned primaries
+ * with a snapshot recovery source. Sizes that are not known yet are fetched from the repository on the generic
+ * thread pool, with at most {@code cluster.snapshot.info.max_concurrent_fetches} fetches in flight, and a reroute
+ * is requested each time a new size becomes known. All mutable state is guarded by {@link #mutex}.
+ */
+public class InternalSnapshotsInfoService implements ClusterStateListener, SnapshotsInfoService {
+
+    /** Maximum number of snapshot shard size fetches that may run concurrently (default 5, minimum 1). */
+    public static final Setting<Integer> INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING =
+        Setting.intSetting("cluster.snapshot.info.max_concurrent_fetches", 5, 1,
+            Setting.Property.Dynamic, Setting.Property.NodeScope);
+
+    private static final Logger logger = LogManager.getLogger(InternalSnapshotsInfoService.class);
+
+    /** Listener passed to the reroute service; it only logs the outcome of the reroute. */
+    private static final ActionListener<ClusterState> REROUTE_LISTENER = ActionListener.wrap(
+        r -> logger.trace("reroute after snapshot shard size update completed"),
+        e -> logger.debug("reroute after snapshot shard size update failed", e)
+    );
+
+    private final ThreadPool threadPool;
+    private final Supplier<RepositoriesService> repositoriesService;
+    private final Supplier<RerouteService> rerouteService;
+
+    /** contains the snapshot shards for which the size is known **/
+    private volatile ImmutableOpenMap<SnapshotShard, Long> knownSnapshotShardSizes;
+
+    /** true while the local node is the elected master, as observed by the last cluster state change **/
+    private volatile boolean isMaster;
+
+    /** contains the snapshot shards for which the size is unknown and must be fetched (or is being fetched) **/
+    private final Set<SnapshotShard> unknownSnapshotShards;
+
+    /** a plain (non-blocking) queue of the snapshot shards waiting to be fetched; only accessed while holding the mutex **/
+    private final Queue<SnapshotShard> queue;
+
+    /** contains the snapshot shards for which the snapshot shard size retrieval failed **/
+    private final Set<SnapshotShard> failedSnapshotShards;
+
+    private volatile int maxConcurrentFetches;
+
+    /** number of fetches currently submitted and not yet completed; only accessed while holding the mutex **/
+    private int activeFetches;
+
+    private final Object mutex;
+
+    /**
+     * Creates the service and, if the node is master-eligible according to {@code settings}, registers it as a
+     * cluster state listener.
+     *
+     * @param settings                    node settings, used for the initial concurrency limit and the master-eligibility check
+     * @param clusterService              provides the thread pool, the cluster settings and the listener registration
+     * @param repositoriesServiceSupplier supplies the repositories service; only dereferenced when a fetch runs
+     * @param rerouteServiceSupplier      supplies the reroute service; only dereferenced after a size has been stored
+     */
+    public InternalSnapshotsInfoService(
+        final Settings settings,
+        final ClusterService clusterService,
+        final Supplier<RepositoriesService> repositoriesServiceSupplier,
+        final Supplier<RerouteService> rerouteServiceSupplier
+    ) {
+        this.threadPool = clusterService.getClusterApplierService().threadPool();
+        this.repositoriesService = repositoriesServiceSupplier;
+        this.rerouteService = rerouteServiceSupplier;
+        this.knownSnapshotShardSizes = ImmutableOpenMap.of();
+        this.unknownSnapshotShards = new LinkedHashSet<>();
+        this.failedSnapshotShards = new LinkedHashSet<>();
+        this.queue = new LinkedList<>();
+        this.mutex = new Object();
+        this.activeFetches = 0;
+        this.maxConcurrentFetches = INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING.get(settings);
+        final ClusterSettings clusterSettings = clusterService.getClusterSettings();
+        clusterSettings.addSettingsUpdateConsumer(INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING, this::setMaxConcurrentFetches);
+        if (DiscoveryNode.isMasterNode(settings)) {
+            clusterService.addListener(this);
+        }
+    }
+
+    private void setMaxConcurrentFetches(Integer maxConcurrentFetches) {
+        this.maxConcurrentFetches = maxConcurrentFetches;
+    }
+
+    /**
+     * Builds a snapshot of the current knowledge: every known size, plus an entry with
+     * {@link ShardRouting#UNAVAILABLE_EXPECTED_SHARD_SIZE} for every snapshot shard whose last fetch failed.
+     */
+    @Override
+    public SnapshotShardSizeInfo snapshotShardSizes() {
+        synchronized (mutex) {
+            final ImmutableOpenMap.Builder<SnapshotShard, Long> snapshotShardSizes = ImmutableOpenMap.builder(knownSnapshotShardSizes);
+            for (SnapshotShard snapshotShard : failedSnapshotShards) {
+                Long previous = snapshotShardSizes.put(snapshotShard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE);
+                assert previous == null : "snapshot shard size already known for " + snapshotShard;
+            }
+            return new SnapshotShardSizeInfo(snapshotShardSizes.build());
+        }
+    }
+
+    /**
+     * Reacts to a cluster state change:
+     * <ul>
+     *   <li>local node is master: enqueue the snapshot shards whose size is neither known nor already being fetched
+     *       (previously failed ones are retried), drop known sizes that are no longer needed, then start fetchers;</li>
+     *   <li>local node just lost mastership: forget known and failed sizes and drop everything still queued;</li>
+     *   <li>otherwise: only assert that no state is retained.</li>
+     * </ul>
+     */
+    @Override
+    public void clusterChanged(ClusterChangedEvent event) {
+        if (event.localNodeMaster()) {
+            final Set<SnapshotShard> onGoingSnapshotRecoveries = listOfSnapshotShards(event.state());
+
+            int unknownShards = 0;
+            synchronized (mutex) {
+                isMaster = true;
+                for (SnapshotShard snapshotShard : onGoingSnapshotRecoveries) {
+                    // check if already populated entry
+                    if (knownSnapshotShardSizes.containsKey(snapshotShard) == false) {
+                        // check if already fetching snapshot info in progress
+                        if (unknownSnapshotShards.add(snapshotShard)) {
+                            failedSnapshotShards.remove(snapshotShard); // retry the failed shard
+                            queue.add(snapshotShard);
+                            unknownShards += 1;
+                        }
+                    }
+                }
+                // Clean up keys from knownSnapshotShardSizes that are no longer needed for recoveries
+                cleanUpKnownSnapshotShardSizes(onGoingSnapshotRecoveries);
+            }
+
+            // never start more fetchers than newly queued shards; fetchNextSnapshotShard() enforces the global limit
+            final int nbFetchers = Math.min(unknownShards, maxConcurrentFetches);
+            for (int i = 0; i < nbFetchers; i++) {
+                fetchNextSnapshotShard();
+            }
+
+        } else if (event.previousState().nodes().isLocalNodeElectedMaster()) {
+            // TODO Maybe just clear out non-ongoing snapshot recoveries if the node is master eligible, so that we don't
+            // have to repopulate the data over and over in an unstable master situation?
+            synchronized (mutex) {
+                // information only needed on current master
+                knownSnapshotShardSizes = ImmutableOpenMap.of();
+                failedSnapshotShards.clear();
+                isMaster = false;
+                // shards still queued were never submitted; shards being fetched stay in unknownSnapshotShards
+                // until their runnable completes
+                SnapshotShard snapshotShard;
+                while ((snapshotShard = queue.poll()) != null) {
+                    final boolean removed = unknownSnapshotShards.remove(snapshotShard);
+                    assert removed : "snapshot shard to remove does not exist " + snapshotShard;
+                }
+                assert invariant();
+            }
+        } else {
+            synchronized (mutex) {
+                assert unknownSnapshotShards.isEmpty() || unknownSnapshotShards.size() == activeFetches;
+                assert knownSnapshotShardSizes.isEmpty();
+                assert failedSnapshotShards.isEmpty();
+                assert isMaster == false;
+                assert queue.isEmpty();
+            }
+        }
+    }
+
+    /**
+     * Submits the next queued snapshot shard to the generic thread pool, unless the concurrency limit is reached
+     * or the queue is empty.
+     */
+    private void fetchNextSnapshotShard() {
+        synchronized (mutex) {
+            if (activeFetches < maxConcurrentFetches) {
+                final SnapshotShard snapshotShard = queue.poll();
+                if (snapshotShard != null) {
+                    activeFetches += 1;
+                    threadPool.generic().execute(new FetchingSnapshotShardSizeRunnable(snapshotShard));
+                }
+            }
+            assert invariant();
+        }
+    }
+
+    /**
+     * Fetches the size of a single snapshot shard from its repository and records the result (or the failure).
+     */
+    private class FetchingSnapshotShardSizeRunnable extends AbstractRunnable {
+
+        private final SnapshotShard snapshotShard;
+
+        /** set once doRun() has removed the shard from unknownSnapshotShards, so that onFailure() does not remove it again **/
+        private boolean removed;
+
+        FetchingSnapshotShardSizeRunnable(SnapshotShard snapshotShard) {
+            super();
+            this.snapshotShard = snapshotShard;
+            this.removed = false;
+        }
+
+        @Override
+        protected void doRun() throws Exception {
+            final RepositoriesService repositories = repositoriesService.get();
+            assert repositories != null;
+            final Repository repository = repositories.repository(snapshotShard.snapshot().getRepository());
+
+            logger.debug("fetching snapshot shard size for {}", snapshotShard);
+            final long snapshotShardSize = repository.getShardSnapshotStatus(
+                snapshotShard.snapshot().getSnapshotId(),
+                snapshotShard.index(),
+                snapshotShard.shardId()
+            ).asCopy().getTotalSize();
+
+            logger.debug("snapshot shard size for {}: {} bytes", snapshotShard, snapshotShardSize);
+
+            boolean updated = false;
+            synchronized (mutex) {
+                removed = unknownSnapshotShards.remove(snapshotShard);
+                assert removed : "snapshot shard to remove does not exist " + snapshotShard;
+                // the result is discarded if mastership was lost while the fetch was running
+                if (isMaster) {
+                    final ImmutableOpenMap.Builder<SnapshotShard, Long> newSnapshotShardSizes =
+                        ImmutableOpenMap.builder(knownSnapshotShardSizes);
+                    updated = newSnapshotShardSizes.put(snapshotShard, snapshotShardSize) == null;
+                    assert updated : "snapshot shard size already exists for " + snapshotShard;
+                    knownSnapshotShardSizes = newSnapshotShardSizes.build();
+                }
+                activeFetches -= 1;
+                assert invariant();
+            }
+            // reroute outside of the mutex
+            if (updated) {
+                rerouteService.get().reroute("snapshot shard size updated", Priority.HIGH, REROUTE_LISTENER);
+            }
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            logger.warn(() -> new ParameterizedMessage("failed to retrieve shard size for {}", snapshotShard), e);
+            synchronized (mutex) {
+                if (isMaster) {
+                    final boolean added = failedSnapshotShards.add(snapshotShard);
+                    assert added : "snapshot shard size already failed for " + snapshotShard;
+                }
+                if (removed == false) {
+                    unknownSnapshotShards.remove(snapshotShard);
+                }
+                activeFetches -= 1;
+                assert invariant();
+            }
+        }
+
+        @Override
+        public void onAfter() {
+            // keep draining the queue whether this fetch succeeded or failed
+            fetchNextSnapshotShard();
+        }
+    }
+
+    /**
+     * Removes from {@link #knownSnapshotShardSizes} every entry that is not part of {@code requiredSnapshotShards}.
+     * The map is only rebuilt if at least one entry has to be removed. Must be called while holding the mutex.
+     */
+    private void cleanUpKnownSnapshotShardSizes(Set<SnapshotShard> requiredSnapshotShards) {
+        assert Thread.holdsLock(mutex);
+        ImmutableOpenMap.Builder<SnapshotShard, Long> newSnapshotShardSizes = null;
+        for (ObjectCursor<SnapshotShard> shard : knownSnapshotShardSizes.keys()) {
+            if (requiredSnapshotShards.contains(shard.value) == false) {
+                if (newSnapshotShardSizes == null) {
+                    newSnapshotShardSizes = ImmutableOpenMap.builder(knownSnapshotShardSizes);
+                }
+                newSnapshotShardSizes.remove(shard.value);
+            }
+        }
+        if (newSnapshotShardSizes != null) {
+            knownSnapshotShardSizes = newSnapshotShardSizes.build();
+        }
+    }
+
+    /**
+     * Consistency checks meant to be used as {@code assert invariant()}: the number of active fetches is within
+     * bounds and the known, unknown and failed sets are pairwise disjoint. Must be called while holding the mutex.
+     *
+     * @return always {@code true}; violations are reported through the assertions inside
+     */
+    private boolean invariant() {
+        assert Thread.holdsLock(mutex);
+        assert activeFetches >= 0 : "active fetches should be greater than or equal to zero but got: " + activeFetches;
+        assert activeFetches <= maxConcurrentFetches : activeFetches + " <= " + maxConcurrentFetches;
+        for (ObjectCursor<SnapshotShard> cursor : knownSnapshotShardSizes.keys()) {
+            assert unknownSnapshotShards.contains(cursor.value) == false : "cannot be known and unknown at same time: " + cursor.value;
+            assert failedSnapshotShards.contains(cursor.value) == false : "cannot be known and failed at same time: " + cursor.value;
+        }
+        for (SnapshotShard shard : unknownSnapshotShards) {
+            assert knownSnapshotShardSizes.keys().contains(shard) == false : "cannot be unknown and known at same time: " + shard;
+            assert failedSnapshotShards.contains(shard) == false : "cannot be unknown and failed at same time: " + shard;
+        }
+        for (SnapshotShard shard : failedSnapshotShards) {
+            assert knownSnapshotShardSizes.keys().contains(shard) == false : "cannot be failed and known at same time: " + shard;
+            assert unknownSnapshotShards.contains(shard) == false : "cannot be failed and unknown at same time: " + shard;
+        }
+        return true;
+    }
+
+    // used in tests
+    int numberOfUnknownSnapshotShardSizes() {
+        synchronized (mutex) {
+            return unknownSnapshotShards.size();
+        }
+    }
+
+    // used in tests
+    int numberOfFailedSnapshotShardSizes() {
+        synchronized (mutex) {
+            return failedSnapshotShards.size();
+        }
+    }
+
+    // used in tests
+    int numberOfKnownSnapshotShardSizes() {
+        return knownSnapshotShardSizes.size();
+    }
+
+    /**
+     * Collects the snapshot shards referenced by the unassigned primary shards of {@code state} whose recovery
+     * source is of type {@link RecoverySource.Type#SNAPSHOT}.
+     *
+     * @return an unmodifiable set of the snapshot shards to restore from
+     */
+    private static Set<SnapshotShard> listOfSnapshotShards(final ClusterState state) {
+        final Set<SnapshotShard> snapshotShards = new HashSet<>();
+        for (ShardRouting shardRouting : state.routingTable().shardsWithState(ShardRoutingState.UNASSIGNED)) {
+            if (shardRouting.primary() && shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
+                final RecoverySource.SnapshotRecoverySource snapshotRecoverySource =
+                    (RecoverySource.SnapshotRecoverySource) shardRouting.recoverySource();
+                final SnapshotShard snapshotShard = new SnapshotShard(snapshotRecoverySource.snapshot(),
+                    snapshotRecoverySource.index(), shardRouting.shardId());
+                snapshotShards.add(snapshotShard);
+            }
+        }
+        return Collections.unmodifiableSet(snapshotShards);
+    }
+
+    /**
+     * Identifies a shard within a snapshot: the snapshot, the index in the repository and the shard id.
+     * Instances are immutable and usable as map keys.
+     */
+    public static class SnapshotShard {
+
+        private final Snapshot snapshot;
+        private final IndexId index;
+        private final ShardId shardId;
+
+        public SnapshotShard(Snapshot snapshot, IndexId index, ShardId shardId) {
+            this.snapshot = snapshot;
+            this.index = index;
+            this.shardId = shardId;
+        }
+
+        public Snapshot snapshot() {
+            return snapshot;
+        }
+
+        public IndexId index() {
+            return index;
+        }
+
+        public ShardId shardId() {
+            return shardId;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            final SnapshotShard that = (SnapshotShard) o;
+            return shardId.equals(that.shardId)
+                && snapshot.equals(that.snapshot)
+                && index.equals(that.index);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(snapshot, index, shardId);
+        }
+
+        @Override
+        public String toString() {
+            return "[" +
+                "snapshot=" + snapshot +
+                ", index=" + index +
+                ", shard=" + shardId +
+                ']';
+        }
+    }
+}

+ 53 - 0
server/src/main/java/org/elasticsearch/snapshots/SnapshotShardSizeInfo.java

@@ -0,0 +1,53 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
+
+/**
+ * Holds a map of snapshot shard sizes keyed by {@link InternalSnapshotsInfoService.SnapshotShard} and resolves
+ * the size that applies to a given {@link ShardRouting}.
+ */
+public class SnapshotShardSizeInfo {
+
+    /** An instance that contains no snapshot shard size. */
+    public static final SnapshotShardSizeInfo EMPTY = new SnapshotShardSizeInfo(ImmutableOpenMap.of());
+
+    private final ImmutableOpenMap<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes;
+
+    public SnapshotShardSizeInfo(ImmutableOpenMap<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes) {
+        this.snapshotShardSizes = snapshotShardSizes;
+    }
+
+    /**
+     * Looks up the size of the snapshot shard that the given shard is restored from.
+     *
+     * @param shardRouting expected to be an inactive primary whose recovery source is of type SNAPSHOT; any other
+     *                     routing trips an assertion
+     * @return the size stored for the corresponding snapshot shard, or {@code null} if there is none (or, when
+     *         assertions are disabled, if the routing is not of the expected kind)
+     */
+    public Long getShardSize(ShardRouting shardRouting) {
+        final boolean inactivePrimaryFromSnapshot = shardRouting.primary()
+            && shardRouting.active() == false
+            && shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT;
+        if (inactivePrimaryFromSnapshot == false) {
+            assert false : "Expected shard with snapshot recovery source but was " + shardRouting;
+            return null;
+        }
+        final RecoverySource.SnapshotRecoverySource source = (RecoverySource.SnapshotRecoverySource) shardRouting.recoverySource();
+        final InternalSnapshotsInfoService.SnapshotShard key =
+            new InternalSnapshotsInfoService.SnapshotShard(source.snapshot(), source.index(), shardRouting.shardId());
+        return snapshotShardSizes.get(key);
+    }
+
+    /**
+     * Same lookup as {@link #getShardSize(ShardRouting)}, substituting {@code fallback} when no size is available.
+     *
+     * @param shardRouting the shard to resolve the size for
+     * @param fallback     value returned when the lookup yields {@code null}
+     * @return the stored size, or {@code fallback}
+     */
+    public long getShardSize(ShardRouting shardRouting, long fallback) {
+        final Long knownSize = getShardSize(shardRouting);
+        if (knownSize != null) {
+            return knownSize;
+        }
+        return fallback;
+    }
+}

+ 25 - 0
server/src/main/java/org/elasticsearch/snapshots/SnapshotsInfoService.java

@@ -0,0 +1,25 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+/**
+ * Source of snapshot shard size information.
+ */
+@FunctionalInterface
+public interface SnapshotsInfoService {
+    /**
+     * @return a {@link SnapshotShardSizeInfo} holding the snapshot shard sizes available from this service
+     */
+    SnapshotShardSizeInfo snapshotShardSizes();
+}

+ 3 - 3
server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java

@@ -56,7 +56,7 @@ public class ClusterAllocationExplainActionTests extends ESTestCase {
         ClusterState clusterState = ClusterStateCreationUtils.state("idx", randomBoolean(), shardRoutingState);
         ClusterState clusterState = ClusterStateCreationUtils.state("idx", randomBoolean(), shardRoutingState);
         ShardRouting shard = clusterState.getRoutingTable().index("idx").shard(0).primaryShard();
         ShardRouting shard = clusterState.getRoutingTable().index("idx").shard(0).primaryShard();
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            clusterState.getRoutingNodes(), clusterState, null, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, null, null, System.nanoTime());
         ClusterAllocationExplanation cae = TransportClusterAllocationExplainAction.explainShard(shard, allocation, null, randomBoolean(),
         ClusterAllocationExplanation cae = TransportClusterAllocationExplainAction.explainShard(shard, allocation, null, randomBoolean(),
             new AllocationService(null, new TestGatewayAllocator(), new ShardsAllocator() {
             new AllocationService(null, new TestGatewayAllocator(), new ShardsAllocator() {
                 @Override
                 @Override
@@ -72,7 +72,7 @@ public class ClusterAllocationExplainActionTests extends ESTestCase {
                         throw new UnsupportedOperationException("cannot explain");
                         throw new UnsupportedOperationException("cannot explain");
                     }
                     }
                 }
                 }
-            }, null));
+            }, null, null));
 
 
         assertEquals(shard.currentNodeId(), cae.getCurrentNode().getId());
         assertEquals(shard.currentNodeId(), cae.getCurrentNode().getId());
         assertFalse(cae.getShardAllocationDecision().isDecisionTaken());
         assertFalse(cae.getShardAllocationDecision().isDecisionTaken());
@@ -178,6 +178,6 @@ public class ClusterAllocationExplainActionTests extends ESTestCase {
     }
     }
 
 
     private static RoutingAllocation routingAllocation(ClusterState clusterState) {
     private static RoutingAllocation routingAllocation(ClusterState clusterState) {
-        return new RoutingAllocation(NOOP_DECIDERS, clusterState.getRoutingNodes(), clusterState, null, System.nanoTime());
+        return new RoutingAllocation(NOOP_DECIDERS, clusterState.getRoutingNodes(), clusterState, null, null, System.nanoTime());
     }
     }
 }
 }

+ 3 - 1
server/src/test/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteTests.java

@@ -41,6 +41,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.network.NetworkModule;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.io.IOException;
 import java.io.IOException;
@@ -81,7 +82,8 @@ public class ClusterRerouteTests extends ESAllocationTestCase {
     public void testClusterStateUpdateTask() {
     public void testClusterStateUpdateTask() {
         AllocationService allocationService = new AllocationService(
         AllocationService allocationService = new AllocationService(
             new AllocationDeciders(Collections.singleton(new MaxRetryAllocationDecider())),
             new AllocationDeciders(Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         ClusterState clusterState = createInitialClusterState(allocationService);
         ClusterState clusterState = createInitialClusterState(allocationService);
         ClusterRerouteRequest req = new ClusterRerouteRequest();
         ClusterRerouteRequest req = new ClusterRerouteRequest();
         req.dryRun(true);
         req.dryRun(true);

+ 9 - 4
server/src/test/java/org/elasticsearch/action/admin/indices/shrink/TransportResizeActionTests.java

@@ -40,6 +40,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.shard.DocsStats;
 import org.elasticsearch.index.shard.DocsStats;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
@@ -109,7 +110,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -127,7 +129,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -156,7 +159,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -190,7 +194,8 @@ public class TransportResizeActionTests extends ESTestCase {
             .build();
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();

+ 6 - 6
server/src/test/java/org/elasticsearch/cluster/ClusterModuleTests.java

@@ -119,7 +119,7 @@ public class ClusterModuleTests extends ModuleTestCase {
                     public Collection<AllocationDecider> createAllocationDeciders(Settings settings, ClusterSettings clusterSettings) {
                     public Collection<AllocationDecider> createAllocationDeciders(Settings settings, ClusterSettings clusterSettings) {
                         return Collections.singletonList(new EnableAllocationDecider(settings, clusterSettings));
                         return Collections.singletonList(new EnableAllocationDecider(settings, clusterSettings));
                     }
                     }
-                }), clusterInfoService));
+                }), clusterInfoService, null));
         assertEquals(e.getMessage(),
         assertEquals(e.getMessage(),
             "Cannot specify allocation decider [" + EnableAllocationDecider.class.getName() + "] twice");
             "Cannot specify allocation decider [" + EnableAllocationDecider.class.getName() + "] twice");
     }
     }
@@ -131,7 +131,7 @@ public class ClusterModuleTests extends ModuleTestCase {
                 public Collection<AllocationDecider> createAllocationDeciders(Settings settings, ClusterSettings clusterSettings) {
                 public Collection<AllocationDecider> createAllocationDeciders(Settings settings, ClusterSettings clusterSettings) {
                     return Collections.singletonList(new FakeAllocationDecider());
                     return Collections.singletonList(new FakeAllocationDecider());
                 }
                 }
-            }), clusterInfoService);
+            }), clusterInfoService, null);
         assertTrue(module.deciderList.stream().anyMatch(d -> d.getClass().equals(FakeAllocationDecider.class)));
         assertTrue(module.deciderList.stream().anyMatch(d -> d.getClass().equals(FakeAllocationDecider.class)));
     }
     }
 
 
@@ -143,7 +143,7 @@ public class ClusterModuleTests extends ModuleTestCase {
                     return Collections.singletonMap(name, supplier);
                     return Collections.singletonMap(name, supplier);
                 }
                 }
             }
             }
-        ), clusterInfoService);
+        ), clusterInfoService, null);
     }
     }
 
 
     public void testRegisterShardsAllocator() {
     public void testRegisterShardsAllocator() {
@@ -161,7 +161,7 @@ public class ClusterModuleTests extends ModuleTestCase {
     public void testUnknownShardsAllocator() {
     public void testUnknownShardsAllocator() {
         Settings settings = Settings.builder().put(ClusterModule.SHARDS_ALLOCATOR_TYPE_SETTING.getKey(), "dne").build();
         Settings settings = Settings.builder().put(ClusterModule.SHARDS_ALLOCATOR_TYPE_SETTING.getKey(), "dne").build();
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () ->
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () ->
-            new ClusterModule(settings, clusterService, Collections.emptyList(), clusterInfoService));
+            new ClusterModule(settings, clusterService, Collections.emptyList(), clusterInfoService, null));
         assertEquals("Unknown ShardsAllocator [dne]", e.getMessage());
         assertEquals("Unknown ShardsAllocator [dne]", e.getMessage());
     }
     }
 
 
@@ -204,13 +204,13 @@ public class ClusterModuleTests extends ModuleTestCase {
 
 
     public void testRejectsReservedExistingShardsAllocatorName() {
     public void testRejectsReservedExistingShardsAllocatorName() {
         final ClusterModule clusterModule = new ClusterModule(Settings.EMPTY, clusterService,
         final ClusterModule clusterModule = new ClusterModule(Settings.EMPTY, clusterService,
-            List.of(existingShardsAllocatorPlugin(GatewayAllocator.ALLOCATOR_NAME)), clusterInfoService);
+            List.of(existingShardsAllocatorPlugin(GatewayAllocator.ALLOCATOR_NAME)), clusterInfoService, null);
         expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator()));
         expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator()));
     }
     }
 
 
     public void testRejectsDuplicateExistingShardsAllocatorName() {
     public void testRejectsDuplicateExistingShardsAllocatorName() {
         final ClusterModule clusterModule = new ClusterModule(Settings.EMPTY, clusterService,
         final ClusterModule clusterModule = new ClusterModule(Settings.EMPTY, clusterService,
-            List.of(existingShardsAllocatorPlugin("duplicate"), existingShardsAllocatorPlugin("duplicate")), clusterInfoService);
+            List.of(existingShardsAllocatorPlugin("duplicate"), existingShardsAllocatorPlugin("duplicate")), clusterInfoService, null);
         expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator()));
         expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator()));
     }
     }
 
 

+ 1 - 1
server/src/test/java/org/elasticsearch/cluster/health/ClusterStateHealthTests.java

@@ -144,7 +144,7 @@ public class ClusterStateHealthTests extends ESTestCase {
 
 
         TransportClusterHealthAction action = new TransportClusterHealthAction(transportService,
         TransportClusterHealthAction action = new TransportClusterHealthAction(transportService,
             clusterService, threadPool, new ActionFilters(new HashSet<>()), indexNameExpressionResolver,
             clusterService, threadPool, new ActionFilters(new HashSet<>()), indexNameExpressionResolver,
-            new AllocationService(null, new TestGatewayAllocator(), null, null));
+            new AllocationService(null, new TestGatewayAllocator(), null, null, null));
         PlainActionFuture<ClusterHealthResponse> listener = new PlainActionFuture<>();
         PlainActionFuture<ClusterHealthResponse> listener = new PlainActionFuture<>();
         ActionTestUtils.execute(action, null, new ClusterHealthRequest().waitForGreenStatus(), listener);
         ActionTestUtils.execute(action, null, new ClusterHealthRequest().waitForGreenStatus(), listener);
 
 

+ 7 - 3
server/src/test/java/org/elasticsearch/cluster/metadata/MetadataCreateIndexServiceTests.java

@@ -59,6 +59,7 @@ import org.elasticsearch.indices.InvalidIndexNameException;
 import org.elasticsearch.indices.ShardLimitValidator;
 import org.elasticsearch.indices.ShardLimitValidator;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.indices.SystemIndices;
 import org.elasticsearch.indices.SystemIndices;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ClusterServiceUtils;
 import org.elasticsearch.test.ClusterServiceUtils;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.test.VersionUtils;
@@ -210,7 +211,8 @@ public class MetadataCreateIndexServiceTests extends ESTestCase {
             .build();
             .build();
         AllocationService service = new AllocationService(new AllocationDeciders(
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -272,7 +274,8 @@ public class MetadataCreateIndexServiceTests extends ESTestCase {
             .nodes(DiscoveryNodes.builder().add(newNode("node1"))).build();
             .nodes(DiscoveryNodes.builder().add(newNode("node1"))).build();
         AllocationService service = new AllocationService(new AllocationDeciders(
         AllocationService service = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
 
 
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         RoutingTable routingTable = service.reroute(clusterState, "reroute").routingTable();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
         clusterState = ClusterState.builder(clusterState).routingTable(routingTable).build();
@@ -408,7 +411,8 @@ public class MetadataCreateIndexServiceTests extends ESTestCase {
                 new AllocationDeciders(Collections.singleton(new MaxRetryAllocationDecider())),
                 new AllocationDeciders(Collections.singleton(new MaxRetryAllocationDecider())),
                 new TestGatewayAllocator(),
                 new TestGatewayAllocator(),
                 new BalancedShardsAllocator(Settings.EMPTY),
                 new BalancedShardsAllocator(Settings.EMPTY),
-                EmptyClusterInfoService.INSTANCE);
+                EmptyClusterInfoService.INSTANCE,
+                EmptySnapshotsInfoService.INSTANCE);
 
 
         final RoutingTable initialRoutingTable = service.reroute(initialClusterState, "reroute").routingTable();
         final RoutingTable initialRoutingTable = service.reroute(initialClusterState, "reroute").routingTable();
         final ClusterState routingTableClusterState = ClusterState.builder(initialClusterState).routingTable(initialRoutingTable).build();
         final ClusterState routingTableClusterState = ClusterState.builder(initialClusterState).routingTable(initialRoutingTable).build();

+ 3 - 2
server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationCommandsTests.java

@@ -59,6 +59,7 @@ import org.elasticsearch.index.Index;
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardNotFoundException;
 import org.elasticsearch.index.shard.ShardNotFoundException;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.HashSet;
@@ -633,7 +634,7 @@ public class AllocationCommandsTests extends ESAllocationTestCase {
         Index index = clusterState.getMetadata().index("test").getIndex();
         Index index = clusterState.getMetadata().index("test").getIndex();
         MoveAllocationCommand command = new MoveAllocationCommand(index.getName(), 0, "node1", "node2");
         MoveAllocationCommand command = new MoveAllocationCommand(index.getName(), 0, "node1", "node2");
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         logger.info("--> executing move allocation command to non-data node");
         logger.info("--> executing move allocation command to non-data node");
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> command.execute(routingAllocation, false));
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> command.execute(routingAllocation, false));
         assertEquals("[move_allocation] can't move [test][0] from " + node1 + " to " +
         assertEquals("[move_allocation] can't move [test][0] from " + node1 + " to " +
@@ -671,7 +672,7 @@ public class AllocationCommandsTests extends ESAllocationTestCase {
         Index index = clusterState.getMetadata().index("test").getIndex();
         Index index = clusterState.getMetadata().index("test").getIndex();
         MoveAllocationCommand command = new MoveAllocationCommand(index.getName(), 0, "node2", "node1");
         MoveAllocationCommand command = new MoveAllocationCommand(index.getName(), 0, "node2", "node1");
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         logger.info("--> executing move allocation command from non-data node");
         logger.info("--> executing move allocation command from non-data node");
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> command.execute(routingAllocation, false));
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> command.execute(routingAllocation, false));
         assertEquals("[move_allocation] can't move [test][0] from " + node2 + " to " + node1 +
         assertEquals("[move_allocation] can't move [test][0] from " + node2 + " to " + node1 +

+ 4 - 3
server/src/test/java/org/elasticsearch/cluster/routing/allocation/AllocationServiceTests.java

@@ -42,6 +42,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.ThrottlingAllocation
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.gateway.GatewayAllocator;
 import org.elasticsearch.gateway.GatewayAllocator;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
@@ -135,7 +136,7 @@ public class AllocationServiceTests extends ESTestCase {
                 public ShardAllocationDecision decideShardAllocation(ShardRouting shard, RoutingAllocation allocation) {
                 public ShardAllocationDecision decideShardAllocation(ShardRouting shard, RoutingAllocation allocation) {
                     return ShardAllocationDecision.NOT_TAKEN;
                     return ShardAllocationDecision.NOT_TAKEN;
                 }
                 }
-            }, new EmptyClusterInfoService());
+            }, new EmptyClusterInfoService(), EmptySnapshotsInfoService.INSTANCE);
 
 
         final String unrealisticAllocatorName = "unrealistic";
         final String unrealisticAllocatorName = "unrealistic";
         final Map<String, ExistingShardsAllocator> allocatorMap = new HashMap<>();
         final Map<String, ExistingShardsAllocator> allocatorMap = new HashMap<>();
@@ -222,7 +223,7 @@ public class AllocationServiceTests extends ESTestCase {
     }
     }
 
 
     public void testExplainsNonAllocationOfShardWithUnknownAllocator() {
     public void testExplainsNonAllocationOfShardWithUnknownAllocator() {
-        final AllocationService allocationService = new AllocationService(null, null, null);
+        final AllocationService allocationService = new AllocationService(null, null, null, null);
         allocationService.setExistingShardsAllocators(
         allocationService.setExistingShardsAllocators(
             Collections.singletonMap(GatewayAllocator.ALLOCATOR_NAME, new TestGatewayAllocator()));
             Collections.singletonMap(GatewayAllocator.ALLOCATOR_NAME, new TestGatewayAllocator()));
 
 
@@ -242,7 +243,7 @@ public class AllocationServiceTests extends ESTestCase {
             .build();
             .build();
 
 
         final RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
         final RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            clusterState.getRoutingNodes(), clusterState, ClusterInfo.EMPTY, 0L);
+            clusterState.getRoutingNodes(), clusterState, ClusterInfo.EMPTY, null,0L);
         allocation.setDebugMode(randomBoolean() ? RoutingAllocation.DebugMode.ON : RoutingAllocation.DebugMode.EXCLUDE_YES_DECISIONS);
         allocation.setDebugMode(randomBoolean() ? RoutingAllocation.DebugMode.ON : RoutingAllocation.DebugMode.EXCLUDE_YES_DECISIONS);
 
 
         final ShardAllocationDecision shardAllocationDecision
         final ShardAllocationDecision shardAllocationDecision

+ 2 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalanceConfigurationTests.java

@@ -41,6 +41,7 @@ import org.elasticsearch.cluster.routing.allocation.allocator.ShardsAllocator;
 import org.elasticsearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.hamcrest.Matchers;
 import org.hamcrest.Matchers;
 
 
@@ -347,7 +348,7 @@ public class BalanceConfigurationTests extends ESAllocationTestCase {
             public ShardAllocationDecision decideShardAllocation(ShardRouting shard, RoutingAllocation allocation) {
             public ShardAllocationDecision decideShardAllocation(ShardRouting shard, RoutingAllocation allocation) {
                 throw new UnsupportedOperationException("explain not supported");
                 throw new UnsupportedOperationException("explain not supported");
             }
             }
-        }, EmptyClusterInfoService.INSTANCE);
+        }, EmptyClusterInfoService.INSTANCE, EmptySnapshotsInfoService.INSTANCE);
         Metadata.Builder metadataBuilder = Metadata.builder();
         Metadata.Builder metadataBuilder = Metadata.builder();
         RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
         RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
         IndexMetadata.Builder indexMeta = IndexMetadata.builder("test")
         IndexMetadata.Builder indexMeta = IndexMetadata.builder("test")

+ 2 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/BalancedSingleShardTests.java

@@ -37,6 +37,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision.Type;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision.Type;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
@@ -368,7 +369,7 @@ public class BalancedSingleShardTests extends ESAllocationTestCase {
 
 
     private RoutingAllocation newRoutingAllocation(AllocationDeciders deciders, ClusterState state) {
     private RoutingAllocation newRoutingAllocation(AllocationDeciders deciders, ClusterState state) {
         RoutingAllocation allocation = new RoutingAllocation(
         RoutingAllocation allocation = new RoutingAllocation(
-            deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, System.nanoTime());
+            deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         return allocation;
         return allocation;
     }
     }

+ 3 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/DecisionsImpactOnClusterHealthTests.java

@@ -39,6 +39,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.Environment;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.io.IOException;
 import java.io.IOException;
@@ -155,7 +156,8 @@ public class DecisionsImpactOnClusterHealthTests extends ESAllocationTestCase {
         return new AllocationService(new AllocationDeciders(deciders),
         return new AllocationService(new AllocationDeciders(deciders),
                                      new TestGatewayAllocator(),
                                      new TestGatewayAllocator(),
                                      new BalancedShardsAllocator(settings),
                                      new BalancedShardsAllocator(settings),
-                                     EmptyClusterInfoService.INSTANCE);
+                                     EmptyClusterInfoService.INSTANCE,
+                                     EmptySnapshotsInfoService.INSTANCE);
     }
     }
 
 
 }
 }

+ 7 - 5
server/src/test/java/org/elasticsearch/cluster/routing/allocation/MaxRetryAllocationDeciderTests.java

@@ -35,6 +35,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.util.Collections;
 import java.util.Collections;
@@ -56,7 +57,8 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
         super.setUp();
         super.setUp();
         strategy = new AllocationService(new AllocationDeciders(
         strategy = new AllocationService(new AllocationDeciders(
             Collections.singleton(new MaxRetryAllocationDecider())),
             Collections.singleton(new MaxRetryAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
     }
     }
 
 
     private ClusterState createInitialClusterState() {
     private ClusterState createInitialClusterState() {
@@ -176,7 +178,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
             assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom" + i));
             assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom" + i));
             // MaxRetryAllocationDecider#canForceAllocatePrimary should return YES decisions because canAllocate returns YES here
             // MaxRetryAllocationDecider#canForceAllocatePrimary should return YES decisions because canAllocate returns YES here
             assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
             assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, 0)));
+                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, null,0)));
         }
         }
         // now we go and check that we are actually stick to unassigned on the next failure
         // now we go and check that we are actually stick to unassigned on the next failure
         {
         {
@@ -194,7 +196,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
             assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom"));
             assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom"));
             // MaxRetryAllocationDecider#canForceAllocatePrimary should return a NO decision because canAllocate returns NO here
             // MaxRetryAllocationDecider#canForceAllocatePrimary should return a NO decision because canAllocate returns NO here
             assertEquals(Decision.NO, new MaxRetryAllocationDecider().canForceAllocatePrimary(
             assertEquals(Decision.NO, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, 0)));
+                unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, null,0)));
         }
         }
 
 
         // change the settings and ensure we can do another round of allocation for that index.
         // change the settings and ensure we can do another round of allocation for that index.
@@ -216,7 +218,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
         assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom"));
         assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("boom"));
         // bumped up the max retry count, so canForceAllocatePrimary should return a YES decision
         // bumped up the max retry count, so canForceAllocatePrimary should return a YES decision
         assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
         assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-            routingTable.index("idx").shard(0).shards().get(0), null, new RoutingAllocation(null, null, clusterState, null, 0)));
+            routingTable.index("idx").shard(0).shards().get(0), null, new RoutingAllocation(null, null, clusterState, null, null,0)));
 
 
         // now we start the shard
         // now we start the shard
         clusterState = startShardsAndReroute(strategy, clusterState, routingTable.index("idx").shard(0).shards().get(0));
         clusterState = startShardsAndReroute(strategy, clusterState, routingTable.index("idx").shard(0).shards().get(0));
@@ -242,7 +244,7 @@ public class MaxRetryAllocationDeciderTests extends ESAllocationTestCase {
         assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("ZOOOMG"));
         assertThat(unassignedPrimary.unassignedInfo().getMessage(), containsString("ZOOOMG"));
         // Counter reset, so MaxRetryAllocationDecider#canForceAllocatePrimary should return a YES decision
         // Counter reset, so MaxRetryAllocationDecider#canForceAllocatePrimary should return a YES decision
         assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
         assertEquals(Decision.YES, new MaxRetryAllocationDecider().canForceAllocatePrimary(
-            unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, 0)));
+            unassignedPrimary, null, new RoutingAllocation(null, null, clusterState, null, null,0)));
     }
     }
 
 
 }
 }

+ 25 - 10
server/src/test/java/org/elasticsearch/cluster/routing/allocation/NodeVersionAllocationDeciderTests.java

@@ -50,12 +50,17 @@ import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.NodeVersionAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.NodeVersionAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.ReplicaAfterPrimaryActiveAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.ReplicaAfterPrimaryActiveAllocationDecider;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
@@ -338,7 +343,8 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
             Collections.singleton(new NodeVersionAllocationDecider()));
             Collections.singleton(new NodeVersionAllocationDecider()));
         AllocationService strategy = new MockAllocationService(
         AllocationService strategy = new MockAllocationService(
             allocationDeciders,
             allocationDeciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         state = strategy.reroute(state, new AllocationCommands(), true, false).getClusterState();
         state = strategy.reroute(state, new AllocationCommands(), true, false).getClusterState();
         // the two indices must stay as is, the replicas cannot move to oldNode2 because versions don't match
         // the two indices must stay as is, the replicas cannot move to oldNode2 because versions don't match
         assertThat(state.routingTable().index(shard2.getIndex()).shardsWithState(ShardRoutingState.RELOCATING).size(), equalTo(0));
         assertThat(state.routingTable().index(shard2.getIndex()).shardsWithState(ShardRoutingState.RELOCATING).size(), equalTo(0));
@@ -353,7 +359,10 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         final DiscoveryNode oldNode2 = new DiscoveryNode("oldNode2", buildNewFakeTransportAddress(), emptyMap(),
         final DiscoveryNode oldNode2 = new DiscoveryNode("oldNode2", buildNewFakeTransportAddress(), emptyMap(),
                 MASTER_DATA_ROLES, VersionUtils.getPreviousVersion());
                 MASTER_DATA_ROLES, VersionUtils.getPreviousVersion());
 
 
-        int numberOfShards = randomIntBetween(1, 3);
+        final Snapshot snapshot = new Snapshot("rep1", new SnapshotId("snp1", UUIDs.randomBase64UUID()));
+        final IndexId indexId = new IndexId("test", UUIDs.randomBase64UUID(random()));
+
+        final int numberOfShards = randomIntBetween(1, 3);
         final IndexMetadata.Builder indexMetadata = IndexMetadata.builder("test").settings(settings(Version.CURRENT))
         final IndexMetadata.Builder indexMetadata = IndexMetadata.builder("test").settings(settings(Version.CURRENT))
             .numberOfShards(numberOfShards).numberOfReplicas(randomIntBetween(0, 3));
             .numberOfShards(numberOfShards).numberOfReplicas(randomIntBetween(0, 3));
         for (int i = 0; i < numberOfShards; i++) {
         for (int i = 0; i < numberOfShards; i++) {
@@ -361,20 +370,26 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         }
         }
         Metadata metadata = Metadata.builder().put(indexMetadata).build();
         Metadata metadata = Metadata.builder().put(indexMetadata).build();
 
 
+        final ImmutableOpenMap.Builder<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes =
+            ImmutableOpenMap.builder(numberOfShards);
+        final Index index = metadata.index("test").getIndex();
+        for (int i = 0; i < numberOfShards; i++) {
+            final ShardId shardId = new ShardId(index, i);
+            snapshotShardSizes.put(new InternalSnapshotsInfoService.SnapshotShard(snapshot, indexId, shardId), randomNonNegativeLong());
+        }
+
         ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
         ClusterState state = ClusterState.builder(ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
             .metadata(metadata)
             .metadata(metadata)
             .routingTable(RoutingTable.builder().addAsRestore(metadata.index("test"),
             .routingTable(RoutingTable.builder().addAsRestore(metadata.index("test"),
-                new SnapshotRecoverySource(
-                    UUIDs.randomBase64UUID(),
-                    new Snapshot("rep1", new SnapshotId("snp1", UUIDs.randomBase64UUID())),
-                Version.CURRENT, new IndexId("test", UUIDs.randomBase64UUID(random())))).build())
+                new SnapshotRecoverySource(UUIDs.randomBase64UUID(), snapshot, Version.CURRENT, indexId)).build())
             .nodes(DiscoveryNodes.builder().add(newNode).add(oldNode1).add(oldNode2)).build();
             .nodes(DiscoveryNodes.builder().add(newNode).add(oldNode1).add(oldNode2)).build();
         AllocationDeciders allocationDeciders = new AllocationDeciders(Arrays.asList(
         AllocationDeciders allocationDeciders = new AllocationDeciders(Arrays.asList(
             new ReplicaAfterPrimaryActiveAllocationDecider(),
             new ReplicaAfterPrimaryActiveAllocationDecider(),
             new NodeVersionAllocationDecider()));
             new NodeVersionAllocationDecider()));
         AllocationService strategy = new MockAllocationService(
         AllocationService strategy = new MockAllocationService(
             allocationDeciders,
             allocationDeciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            () -> new SnapshotShardSizeInfo(snapshotShardSizes.build()));
         state = strategy.reroute(state, new AllocationCommands(), true, false).getClusterState();
         state = strategy.reroute(state, new AllocationCommands(), true, false).getClusterState();
 
 
         // Make sure that primary shards are only allocated on the new node
         // Make sure that primary shards are only allocated on the new node
@@ -463,7 +478,7 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         final ShardRouting replicaShard = clusterState.routingTable().shardRoutingTable(shardId).replicaShards().get(0);
         final ShardRouting replicaShard = clusterState.routingTable().shardRoutingTable(shardId).replicaShards().get(0);
 
 
         RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState,
         RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState,
-            null, 0);
+            null, null, 0);
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
 
 
         final NodeVersionAllocationDecider allocationDecider = new NodeVersionAllocationDecider();
         final NodeVersionAllocationDecider allocationDecider = new NodeVersionAllocationDecider();
@@ -508,7 +523,7 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
         final ShardRouting startedPrimary = routingNodes.startShard(logger,
         final ShardRouting startedPrimary = routingNodes.startShard(logger,
             routingNodes.initializeShard(primaryShard, "newNode", null, 0,
             routingNodes.initializeShard(primaryShard, "newNode", null, 0,
             routingChangesObserver), routingChangesObserver);
             routingChangesObserver), routingChangesObserver);
-        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, 0);
+        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, null,0);
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
 
 
         decision = allocationDecider.canAllocate(replicaShard, oldNode, routingAllocation);
         decision = allocationDecider.canAllocate(replicaShard, oldNode, routingAllocation);
@@ -518,7 +533,7 @@ public class NodeVersionAllocationDeciderTests extends ESAllocationTestCase {
 
 
         routingNodes.startShard(logger, routingNodes.relocateShard(startedPrimary,
         routingNodes.startShard(logger, routingNodes.relocateShard(startedPrimary,
             "oldNode", 0, routingChangesObserver).v2(), routingChangesObserver);
             "oldNode", 0, routingChangesObserver).v2(), routingChangesObserver);
-        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, 0);
+        routingAllocation = new RoutingAllocation(null, routingNodes, clusterState, null, null,0);
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
 
 
         decision = allocationDecider.canAllocate(replicaShard, newNode, routingAllocation);
         decision = allocationDecider.canAllocate(replicaShard, newNode, routingAllocation);

+ 3 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/RandomAllocationDeciderTests.java

@@ -39,6 +39,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.ReplicaAfterPrimaryA
 import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationDecider;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.hamcrest.Matchers;
 import org.hamcrest.Matchers;
 
 
@@ -62,7 +63,8 @@ public class RandomAllocationDeciderTests extends ESAllocationTestCase {
                 new HashSet<>(Arrays.asList(new SameShardAllocationDecider(Settings.EMPTY,
                 new HashSet<>(Arrays.asList(new SameShardAllocationDecider(Settings.EMPTY,
                         new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)),
                         new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)),
                     new ReplicaAfterPrimaryActiveAllocationDecider(), randomAllocationDecider))),
                     new ReplicaAfterPrimaryActiveAllocationDecider(), randomAllocationDecider))),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         int indices = scaledRandomIntBetween(1, 20);
         int indices = scaledRandomIntBetween(1, 20);
         Builder metaBuilder = Metadata.builder();
         Builder metaBuilder = Metadata.builder();
         int maxNumReplicas = 1;
         int maxNumReplicas = 1;

+ 7 - 5
server/src/test/java/org/elasticsearch/cluster/routing/allocation/ResizeAllocationDeciderTests.java

@@ -38,6 +38,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.ResizeAllocationDeci
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.util.Collections;
 import java.util.Collections;
@@ -56,7 +57,8 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
         super.setUp();
         super.setUp();
         strategy = new AllocationService(new AllocationDeciders(
         strategy = new AllocationService(new AllocationDeciders(
             Collections.singleton(new ResizeAllocationDecider())),
             Collections.singleton(new ResizeAllocationDecider())),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
     }
     }
 
 
     private ClusterState createInitialClusterState(boolean startShards) {
     private ClusterState createInitialClusterState(boolean startShards) {
@@ -104,7 +106,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
     public void testNonResizeRouting() {
     public void testNonResizeRouting() {
         ClusterState clusterState = createInitialClusterState(true);
         ClusterState clusterState = createInitialClusterState(true);
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, null, 0);
         ShardRouting shardRouting = TestShardRouting.newShardRouting("non-resize", 0, null, true, ShardRoutingState.UNASSIGNED);
         ShardRouting shardRouting = TestShardRouting.newShardRouting("non-resize", 0, null, true, ShardRoutingState.UNASSIGNED);
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, routingAllocation));
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, routingAllocation));
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, clusterState.getRoutingNodes().node("node1"),
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, clusterState.getRoutingNodes().node("node1"),
@@ -128,7 +130,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
         Index idx = clusterState.metadata().index("target").getIndex();
         Index idx = clusterState.metadata().index("target").getIndex();
 
 
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, null, clusterState, null, null, 0);
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, 0), null, true, ShardRoutingState.UNASSIGNED,
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, 0), null, true, ShardRoutingState.UNASSIGNED,
             RecoverySource.LocalShardsRecoverySource.INSTANCE);
             RecoverySource.LocalShardsRecoverySource.INSTANCE);
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, routingAllocation));
         assertEquals(Decision.ALWAYS, resizeAllocationDecider.canAllocate(shardRouting, routingAllocation));
@@ -156,7 +158,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
 
 
 
 
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, null, 0);
         int shardId = randomIntBetween(0, 3);
         int shardId = randomIntBetween(0, 3);
         int sourceShardId = IndexMetadata.selectSplitShard(shardId, clusterState.metadata().index("source"), 4).id();
         int sourceShardId = IndexMetadata.selectSplitShard(shardId, clusterState.metadata().index("source"), 4).id();
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, shardId), null, true, ShardRoutingState.UNASSIGNED,
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, shardId), null, true, ShardRoutingState.UNASSIGNED,
@@ -196,7 +198,7 @@ public class ResizeAllocationDeciderTests extends ESAllocationTestCase {
 
 
 
 
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
         ResizeAllocationDecider resizeAllocationDecider = new ResizeAllocationDecider();
-        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, 0);
+        RoutingAllocation routingAllocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, null, null, 0);
         int shardId = randomIntBetween(0, 3);
         int shardId = randomIntBetween(0, 3);
         int sourceShardId = IndexMetadata.selectSplitShard(shardId, clusterState.metadata().index("source"), 4).id();
         int sourceShardId = IndexMetadata.selectSplitShard(shardId, clusterState.metadata().index("source"), 4).id();
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, shardId), null, true, ShardRoutingState.UNASSIGNED,
         ShardRouting shardRouting = TestShardRouting.newShardRouting(new ShardId(idx, shardId), null, true, ShardRoutingState.UNASSIGNED,

+ 2 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/SameShardRoutingTests.java

@@ -43,6 +43,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationD
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.Index;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 
 
 import java.util.Collections;
 import java.util.Collections;
 
 
@@ -106,7 +107,7 @@ public class SameShardRoutingTests extends ESAllocationTestCase {
         ShardRouting primaryShard = clusterState.routingTable().index(index).shard(0).primaryShard();
         ShardRouting primaryShard = clusterState.routingTable().index(index).shard(0).primaryShard();
         RoutingNode routingNode = clusterState.getRoutingNodes().node(primaryShard.currentNodeId());
         RoutingNode routingNode = clusterState.getRoutingNodes().node(primaryShard.currentNodeId());
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
         RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(Collections.emptyList()),
-            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState, false), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
 
 
         // can't force allocate same shard copy to the same node
         // can't force allocate same shard copy to the same node
         ShardRouting newPrimary = TestShardRouting.newShardRouting(primaryShard.shardId(), null, true, ShardRoutingState.UNASSIGNED);
         ShardRouting newPrimary = TestShardRouting.newShardRouting(primaryShard.shardId(), null, true, ShardRoutingState.UNASSIGNED);

+ 45 - 12
server/src/test/java/org/elasticsearch/cluster/routing/allocation/ThrottlingAllocationTests.java

@@ -21,7 +21,6 @@ package org.elasticsearch.cluster.routing.allocation;
 
 
 import com.carrotsearch.hppc.IntHashSet;
 import com.carrotsearch.hppc.IntHashSet;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
 import com.carrotsearch.hppc.cursors.ObjectCursor;
-
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
@@ -48,8 +47,11 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.InternalSnapshotsInfoService;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
@@ -70,10 +72,11 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
     public void testPrimaryRecoveryThrottling() {
     public void testPrimaryRecoveryThrottling() {
 
 
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         AllocationService strategy = createAllocationService(Settings.builder()
         AllocationService strategy = createAllocationService(Settings.builder()
                 .put("cluster.routing.allocation.node_concurrent_recoveries", 3)
                 .put("cluster.routing.allocation.node_concurrent_recoveries", 3)
                 .put("cluster.routing.allocation.node_initial_primaries_recoveries", 3)
                 .put("cluster.routing.allocation.node_initial_primaries_recoveries", 3)
-                .build(), gatewayAllocator);
+                .build(), gatewayAllocator, snapshotsInfoService);
 
 
         logger.info("Building initial routing table");
         logger.info("Building initial routing table");
 
 
@@ -81,7 +84,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(10).numberOfReplicas(1))
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(10).numberOfReplicas(1))
                 .build();
                 .build();
 
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
 
         logger.info("start one node, do reroute, only 3 should initialize");
         logger.info("start one node, do reroute, only 3 should initialize");
         clusterState = ClusterState.builder(clusterState).nodes(DiscoveryNodes.builder().add(newNode("node1"))).build();
         clusterState = ClusterState.builder(clusterState).nodes(DiscoveryNodes.builder().add(newNode("node1"))).build();
@@ -122,12 +125,14 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
 
 
     public void testReplicaAndPrimaryRecoveryThrottling() {
     public void testReplicaAndPrimaryRecoveryThrottling() {
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         AllocationService strategy = createAllocationService(Settings.builder()
         AllocationService strategy = createAllocationService(Settings.builder()
                 .put("cluster.routing.allocation.node_concurrent_recoveries", 3)
                 .put("cluster.routing.allocation.node_concurrent_recoveries", 3)
                 .put("cluster.routing.allocation.concurrent_source_recoveries", 3)
                 .put("cluster.routing.allocation.concurrent_source_recoveries", 3)
                 .put("cluster.routing.allocation.node_initial_primaries_recoveries", 3)
                 .put("cluster.routing.allocation.node_initial_primaries_recoveries", 3)
                 .build(),
                 .build(),
-            gatewayAllocator);
+            gatewayAllocator,
+            snapshotsInfoService);
 
 
         logger.info("Building initial routing table");
         logger.info("Building initial routing table");
 
 
@@ -135,7 +140,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(5).numberOfReplicas(1))
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(5).numberOfReplicas(1))
                 .build();
                 .build();
 
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
 
         logger.info("with one node, do reroute, only 3 should initialize");
         logger.info("with one node, do reroute, only 3 should initialize");
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
@@ -184,19 +189,20 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
 
 
     public void testThrottleIncomingAndOutgoing() {
     public void testThrottleIncomingAndOutgoing() {
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         Settings settings = Settings.builder()
         Settings settings = Settings.builder()
             .put("cluster.routing.allocation.node_concurrent_recoveries", 5)
             .put("cluster.routing.allocation.node_concurrent_recoveries", 5)
             .put("cluster.routing.allocation.node_initial_primaries_recoveries", 5)
             .put("cluster.routing.allocation.node_initial_primaries_recoveries", 5)
             .put("cluster.routing.allocation.cluster_concurrent_rebalance", 5)
             .put("cluster.routing.allocation.cluster_concurrent_rebalance", 5)
             .build();
             .build();
-        AllocationService strategy = createAllocationService(settings, gatewayAllocator);
+        AllocationService strategy = createAllocationService(settings, gatewayAllocator, snapshotsInfoService);
         logger.info("Building initial routing table");
         logger.info("Building initial routing table");
 
 
         Metadata metadata = Metadata.builder()
         Metadata metadata = Metadata.builder()
             .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(9).numberOfReplicas(0))
             .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(9).numberOfReplicas(0))
             .build();
             .build();
 
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
 
         logger.info("with one node, do reroute, only 5 should initialize");
         logger.info("with one node, do reroute, only 5 should initialize");
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
@@ -243,9 +249,10 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
 
 
     public void testOutgoingThrottlesAllocation() {
     public void testOutgoingThrottlesAllocation() {
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
         TestGatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        TestSnapshotsInfoService snapshotsInfoService = new TestSnapshotsInfoService();
         AllocationService strategy = createAllocationService(Settings.builder()
         AllocationService strategy = createAllocationService(Settings.builder()
             .put("cluster.routing.allocation.node_concurrent_outgoing_recoveries", 1)
             .put("cluster.routing.allocation.node_concurrent_outgoing_recoveries", 1)
-            .build(), gatewayAllocator);
+            .build(), gatewayAllocator, snapshotsInfoService);
 
 
         logger.info("Building initial routing table");
         logger.info("Building initial routing table");
 
 
@@ -253,7 +260,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
             .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
             .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
             .build();
             .build();
 
 
-        ClusterState clusterState = createRecoveryStateAndInitalizeAllocations(metadata, gatewayAllocator);
+        ClusterState clusterState = createRecoveryStateAndInitializeAllocations(metadata, gatewayAllocator, snapshotsInfoService);
 
 
         logger.info("with one node, do reroute, only 1 should initialize");
         logger.info("with one node, do reroute, only 1 should initialize");
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
@@ -314,7 +321,7 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
                 assertEquals("reached the limit of outgoing shard recoveries [1] on the node [node1] which holds the primary, "
                 assertEquals("reached the limit of outgoing shard recoveries [1] on the node [node1] which holds the primary, "
                         + "cluster setting [cluster.routing.allocation.node_concurrent_outgoing_recoveries=1] "
                         + "cluster setting [cluster.routing.allocation.node_concurrent_outgoing_recoveries=1] "
                         + "(can also be set via [cluster.routing.allocation.node_concurrent_recoveries])",
                         + "(can also be set via [cluster.routing.allocation.node_concurrent_recoveries])",
-                        decision.getExplanation());
+                    decision.getExplanation());
                 assertEquals(Decision.Type.THROTTLE, decision.type());
                 assertEquals(Decision.Type.THROTTLE, decision.type());
                 foundThrottledMessage = true;
                 foundThrottledMessage = true;
             }
             }
@@ -331,7 +338,11 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
         assertEquals(clusterState.getRoutingNodes().getOutgoingRecoveries("node2"), 0);
         assertEquals(clusterState.getRoutingNodes().getOutgoingRecoveries("node2"), 0);
     }
     }
 
 
-    private ClusterState createRecoveryStateAndInitalizeAllocations(Metadata metadata, TestGatewayAllocator gatewayAllocator) {
+    private ClusterState createRecoveryStateAndInitializeAllocations(
+        final Metadata metadata,
+        final TestGatewayAllocator gatewayAllocator,
+        final TestSnapshotsInfoService snapshotsInfoService
+        ) {
         DiscoveryNode node1 = newNode("node1");
         DiscoveryNode node1 = newNode("node1");
         Metadata.Builder metadataBuilder = new Metadata.Builder(metadata);
         Metadata.Builder metadataBuilder = new Metadata.Builder(metadata);
         RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
         RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
@@ -387,8 +398,12 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
             ImmutableOpenMap.Builder<ShardId, RestoreInProgress.ShardRestoreStatus> restoreShards = ImmutableOpenMap.builder();
             ImmutableOpenMap.Builder<ShardId, RestoreInProgress.ShardRestoreStatus> restoreShards = ImmutableOpenMap.builder();
             for (ShardRouting shard : routingTable.allShards()) {
             for (ShardRouting shard : routingTable.allShards()) {
                 if (shard.primary() && shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
                 if (shard.primary() && shard.recoverySource().getType() == RecoverySource.Type.SNAPSHOT) {
-                    ShardId shardId = shard.shardId();
+                    final ShardId shardId = shard.shardId();
                     restoreShards.put(shardId, new RestoreInProgress.ShardRestoreStatus(node1.getId(), RestoreInProgress.State.INIT));
                     restoreShards.put(shardId, new RestoreInProgress.ShardRestoreStatus(node1.getId(), RestoreInProgress.State.INIT));
+                    // Also set the snapshot shard size
+                    final SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) shard.recoverySource();
+                    final long shardSize = randomNonNegativeLong();
+                    snapshotsInfoService.addSnapshotShardSize(recoverySource.snapshot(), recoverySource.index(), shardId, shardSize);
                 }
                 }
             }
             }
 
 
@@ -421,4 +436,22 @@ public class ThrottlingAllocationTests extends ESAllocationTestCase {
             gatewayAllocator.addKnownAllocation(started);
             gatewayAllocator.addKnownAllocation(started);
         }
         }
     }
     }
+
+    private static class TestSnapshotsInfoService implements SnapshotsInfoService {
+
+        private volatile ImmutableOpenMap<InternalSnapshotsInfoService.SnapshotShard, Long> snapshotShardSizes = ImmutableOpenMap.of();
+
+        synchronized void addSnapshotShardSize(Snapshot snapshot, IndexId index, ShardId shard, Long size) {
+            final ImmutableOpenMap.Builder<InternalSnapshotsInfoService.SnapshotShard, Long> newSnapshotShardSizes =
+                ImmutableOpenMap.builder(snapshotShardSizes);
+            boolean added = newSnapshotShardSizes.put(new InternalSnapshotsInfoService.SnapshotShard(snapshot, index, shard), size) == null;
+            assert added : "cannot add snapshot shard size twice";
+            this.snapshotShardSizes = newSnapshotShardSizes.build();
+        }
+
+        @Override
+        public SnapshotShardSizeInfo snapshotShardSizes() {
+            return new SnapshotShardSizeInfo(snapshotShardSizes);
+        }
+    }
 }
 }

+ 1 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/AllocationDecidersTests.java

@@ -91,7 +91,7 @@ public class AllocationDecidersTests extends ESTestCase {
 
 
         ClusterState clusterState = ClusterState.builder(new ClusterName("test")).build();
         ClusterState clusterState = ClusterState.builder(new ClusterName("test")).build();
         final RoutingAllocation allocation = new RoutingAllocation(deciders,
         final RoutingAllocation allocation = new RoutingAllocation(deciders,
-            clusterState.getRoutingNodes(), clusterState, null, 0L);
+            clusterState.getRoutingNodes(), clusterState, null, null,0L);
 
 
         allocation.setDebugMode(mode);
         allocation.setDebugMode(mode);
         final UnassignedInfo unassignedInfo = new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "_message");
         final UnassignedInfo unassignedInfo = new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "_message");

+ 21 - 18
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderTests.java

@@ -51,6 +51,7 @@ import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
@@ -109,7 +110,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
             return clusterInfo;
             return clusterInfo;
         };
         };
         AllocationService strategy = new AllocationService(deciders,
         AllocationService strategy = new AllocationService(deciders,
-                new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         Metadata metadata = Metadata.builder()
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(1))
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(1))
@@ -188,7 +189,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
                         makeDecider(diskSettings))));
 
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
         logShardStates(clusterState);
         logShardStates(clusterState);
@@ -216,7 +217,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
                         makeDecider(diskSettings))));
 
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
 
 
@@ -285,7 +286,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         };
         };
 
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         Metadata metadata = Metadata.builder()
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(2))
@@ -331,7 +332,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
             return clusterInfo2;
             return clusterInfo2;
         };
         };
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
         logShardStates(clusterState);
         logShardStates(clusterState);
@@ -393,7 +394,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
                         makeDecider(diskSettings))));
 
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
         logShardStates(clusterState);
         logShardStates(clusterState);
@@ -422,7 +423,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         makeDecider(diskSettings))));
                         makeDecider(diskSettings))));
 
 
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         clusterState = strategy.reroute(clusterState, "reroute");
         clusterState = strategy.reroute(clusterState, "reroute");
 
 
@@ -517,7 +518,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         };
         };
 
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         Metadata metadata = Metadata.builder()
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
@@ -577,7 +578,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         };
         };
 
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         Metadata metadata = Metadata.builder()
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
@@ -671,7 +672,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         final ClusterInfoService cis = clusterInfoReference::get;
         final ClusterInfoService cis = clusterInfoReference::get;
 
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-            new BalancedShardsAllocator(Settings.EMPTY), cis);
+            new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
 
 
         Metadata metadata = Metadata.builder()
         Metadata metadata = Metadata.builder()
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(1))
                 .put(IndexMetadata.builder("test").settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(1))
@@ -851,7 +852,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         );
         );
         ClusterState clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
         ClusterState clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
-                System.nanoTime());
+                null, System.nanoTime());
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
         Decision decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         Decision decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.NO));
         assertThat(decision.type(), equalTo(Decision.Type.NO));
@@ -877,7 +878,8 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                         )
                         )
         );
         );
         clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
         clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
-        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, System.nanoTime());
+        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, null,
+            System.nanoTime());
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
         decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.YES));
         assertThat(decision.type(), equalTo(Decision.Type.YES));
@@ -907,7 +909,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
                 diskThresholdDecider
                 diskThresholdDecider
         )));
         )));
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
         // Ensure that the reroute call doesn't alter the routing table, since the first primary is relocating away
         // Ensure that the reroute call doesn't alter the routing table, since the first primary is relocating away
         // and therefor we will have sufficient disk space on node1.
         // and therefor we will have sufficient disk space on node1.
         ClusterState result = strategy.reroute(clusterState, "reroute");
         ClusterState result = strategy.reroute(clusterState, "reroute");
@@ -979,7 +981,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         );
         );
         ClusterState clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
         ClusterState clusterState = ClusterState.builder(baseClusterState).routingTable(builder.build()).build();
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
-                System.nanoTime());
+                null, System.nanoTime());
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
         Decision decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         Decision decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
 
 
@@ -999,7 +1001,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         )));
         )));
 
 
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
         AllocationService strategy = new AllocationService(deciders, new TestGatewayAllocator(),
-                new BalancedShardsAllocator(Settings.EMPTY), cis);
+                new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
         ClusterState result = strategy.reroute(clusterState, "reroute");
         ClusterState result = strategy.reroute(clusterState, "reroute");
 
 
         assertThat(result.routingTable().index("test").getShards().get(0).primaryShard().state(), equalTo(STARTED));
         assertThat(result.routingTable().index("test").getShards().get(0).primaryShard().state(), equalTo(STARTED));
@@ -1032,7 +1034,8 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         );
         );
 
 
         clusterState = ClusterState.builder(updateClusterState).routingTable(builder.build()).build();
         clusterState = ClusterState.builder(updateClusterState).routingTable(builder.build()).build();
-        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, System.nanoTime());
+        routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo, null,
+            System.nanoTime());
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
         decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         decision = diskThresholdDecider.canRemain(firstRouting, firstRoutingNode, routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.YES));
         assertThat(decision.type(), equalTo(Decision.Type.YES));
@@ -1096,7 +1099,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
             diskThresholdDecider
             diskThresholdDecider
         )));
         )));
         AllocationService strategy = new AllocationService(deciders,
         AllocationService strategy = new AllocationService(deciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), cis, EmptySnapshotsInfoService.INSTANCE);
         ClusterState result = strategy.reroute(clusterState, "reroute");
         ClusterState result = strategy.reroute(clusterState, "reroute");
 
 
         ShardRouting shardRouting = result.routingTable().index("test").getShards().get(0).primaryShard();
         ShardRouting shardRouting = result.routingTable().index("test").getShards().get(0).primaryShard();
@@ -1117,7 +1120,7 @@ public class DiskThresholdDeciderTests extends ESAllocationTestCase {
         clusterState = ClusterState.builder(clusterState).routingTable(forceAssignedRoutingTable).build();
         clusterState = ClusterState.builder(clusterState).routingTable(forceAssignedRoutingTable).build();
 
 
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
         RoutingAllocation routingAllocation = new RoutingAllocation(null, new RoutingNodes(clusterState), clusterState, clusterInfo,
-            System.nanoTime());
+            null, System.nanoTime());
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
         Decision decision = diskThresholdDecider.canRemain(startedShard, clusterState.getRoutingNodes().node("data"), routingAllocation);
         Decision decision = diskThresholdDecider.canRemain(startedShard, clusterState.getRoutingNodes().node("data"), routingAllocation);
         assertThat(decision.type(), equalTo(Decision.Type.NO));
         assertThat(decision.type(), equalTo(Decision.Type.NO));

+ 7 - 7
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDeciderUnitTests.java

@@ -106,7 +106,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         final ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(),
         final ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(),
             mostAvailableUsage.build(), shardSizes.build(), ImmutableOpenMap.of(),  ImmutableOpenMap.of());
             mostAvailableUsage.build(), shardSizes.build(), ImmutableOpenMap.of(),  ImmutableOpenMap.of());
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, clusterInfo, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, clusterInfo, null, System.nanoTime());
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision decision = decider.canAllocate(test_0, new RoutingNode("node_0", node_0), allocation);
         Decision decision = decider.canAllocate(test_0, new RoutingNode("node_0", node_0), allocation);
         assertEquals(mostAvailableUsage.toString(), Decision.Type.YES, decision.type());
         assertEquals(mostAvailableUsage.toString(), Decision.Type.YES, decision.type());
@@ -161,7 +161,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(), mostAvailableUsage.build(),
         ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(), mostAvailableUsage.build(),
             shardSizes.build(), ImmutableOpenMap.of(),  ImmutableOpenMap.of());
             shardSizes.build(), ImmutableOpenMap.of(),  ImmutableOpenMap.of());
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, clusterInfo, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, clusterInfo, null, System.nanoTime());
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision decision = decider.canAllocate(test_0, new RoutingNode("node_0", node_0), allocation);
         Decision decision = decider.canAllocate(test_0, new RoutingNode("node_0", node_0), allocation);
         assertEquals(Decision.Type.NO, decision.type());
         assertEquals(Decision.Type.NO, decision.type());
@@ -242,7 +242,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         final ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(), mostAvailableUsage.build(),
         final ClusterInfo clusterInfo = new ClusterInfo(leastAvailableUsages.build(), mostAvailableUsage.build(),
             shardSizes.build(), shardRoutingMap.build(), ImmutableOpenMap.of());
             shardSizes.build(), shardRoutingMap.build(), ImmutableOpenMap.of());
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
         RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, clusterInfo, System.nanoTime());
+            clusterState.getRoutingNodes(), clusterState, clusterInfo, null, System.nanoTime());
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision decision = decider.canRemain(test_0, new RoutingNode("node_0", node_0), allocation);
         Decision decision = decider.canRemain(test_0, new RoutingNode("node_0", node_0), allocation);
         assertEquals(Decision.Type.YES, decision.type());
         assertEquals(Decision.Type.YES, decision.type());
@@ -296,7 +296,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         routingTableBuilder.addAsNew(metadata.index("other"));
         routingTableBuilder.addAsNew(metadata.index("other"));
         ClusterState clusterState = ClusterState.builder(org.elasticsearch.cluster.ClusterName.CLUSTER_NAME_SETTING
         ClusterState clusterState = ClusterState.builder(org.elasticsearch.cluster.ClusterName.CLUSTER_NAME_SETTING
             .getDefault(Settings.EMPTY)).metadata(metadata).routingTable(routingTableBuilder.build()).build();
             .getDefault(Settings.EMPTY)).metadata(metadata).routingTable(routingTableBuilder.build()).build();
-        RoutingAllocation allocation = new RoutingAllocation(null, null, clusterState, info, 0);
+        RoutingAllocation allocation = new RoutingAllocation(null, null, clusterState, info, null, 0);
 
 
         final Index index = new Index("test", "1234");
         final Index index = new Index("test", "1234");
         ShardRouting test_0 = ShardRouting.newUnassigned(new ShardId(index, 0), false, PeerRecoverySource.INSTANCE,
         ShardRouting test_0 = ShardRouting.newUnassigned(new ShardId(index, 0), false, PeerRecoverySource.INSTANCE,
@@ -390,7 +390,7 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
         clusterState = startShardsAndReroute(allocationService, clusterState,
         clusterState = startShardsAndReroute(allocationService, clusterState,
             clusterState.getRoutingTable().index("test").shardsWithState(ShardRoutingState.UNASSIGNED));
             clusterState.getRoutingTable().index("test").shardsWithState(ShardRoutingState.UNASSIGNED));
 
 
-        RoutingAllocation allocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, info, 0);
+        RoutingAllocation allocation = new RoutingAllocation(null, clusterState.getRoutingNodes(), clusterState, info, null,0);
 
 
         final Index index = new Index("test", "1234");
         final Index index = new Index("test", "1234");
         ShardRouting test_0 = ShardRouting.newUnassigned(new ShardId(index, 0), true,
         ShardRouting test_0 = ShardRouting.newUnassigned(new ShardId(index, 0), true,
@@ -435,14 +435,14 @@ public class DiskThresholdDeciderUnitTests extends ESAllocationTestCase {
 
 
         allocationService.reroute(clusterState, "foo");
         allocationService.reroute(clusterState, "foo");
         RoutingAllocation allocationWithMissingSourceIndex = new RoutingAllocation(null,
         RoutingAllocation allocationWithMissingSourceIndex = new RoutingAllocation(null,
-            clusterStateWithMissingSourceIndex.getRoutingNodes(), clusterStateWithMissingSourceIndex, info, 0);
+            clusterStateWithMissingSourceIndex.getRoutingNodes(), clusterStateWithMissingSourceIndex, info, null,0);
         assertEquals(42L, getExpectedShardSize(target, 42L, allocationWithMissingSourceIndex));
         assertEquals(42L, getExpectedShardSize(target, 42L, allocationWithMissingSourceIndex));
         assertEquals(42L, getExpectedShardSize(target2, 42L, allocationWithMissingSourceIndex));
         assertEquals(42L, getExpectedShardSize(target2, 42L, allocationWithMissingSourceIndex));
     }
     }
 
 
     private static long getExpectedShardSize(ShardRouting shardRouting, long defaultSize, RoutingAllocation allocation) {
     private static long getExpectedShardSize(ShardRouting shardRouting, long defaultSize, RoutingAllocation allocation) {
         return DiskThresholdDecider.getExpectedShardSize(shardRouting, defaultSize,
         return DiskThresholdDecider.getExpectedShardSize(shardRouting, defaultSize,
-            allocation.clusterInfo(), allocation.metadata(), allocation.routingTable());
+            allocation.clusterInfo(), allocation.snapshotShardSizeInfo(), allocation.metadata(), allocation.routingTable());
     }
     }
 
 
     public void testDiskUsageWithRelocations() {
     public void testDiskUsageWithRelocations() {

+ 3 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/EnableAllocationShortCircuitTests.java

@@ -37,6 +37,7 @@ import org.elasticsearch.cluster.routing.allocation.allocator.BalancedShardsAllo
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.plugins.ClusterPlugin;
 import org.elasticsearch.plugins.ClusterPlugin;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
@@ -166,7 +167,8 @@ public class EnableAllocationShortCircuitTests extends ESAllocationTestCase {
                 Collections.singletonList(plugin)));
                 Collections.singletonList(plugin)));
         return new MockAllocationService(
         return new MockAllocationService(
             new AllocationDeciders(deciders),
             new AllocationDeciders(deciders),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
     }
     }
 
 
     private static class RebalanceShortCircuitPlugin implements ClusterPlugin {
     private static class RebalanceShortCircuitPlugin implements ClusterPlugin {

+ 5 - 3
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/FilterAllocationDeciderTests.java

@@ -37,6 +37,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.IndexScopedSettings;
 import org.elasticsearch.common.settings.IndexScopedSettings;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
@@ -58,7 +59,8 @@ public class FilterAllocationDeciderTests extends ESAllocationTestCase {
                 new SameShardAllocationDecider(Settings.EMPTY, clusterSettings),
                 new SameShardAllocationDecider(Settings.EMPTY, clusterSettings),
                 new ReplicaAfterPrimaryActiveAllocationDecider()));
                 new ReplicaAfterPrimaryActiveAllocationDecider()));
         AllocationService service = new AllocationService(allocationDeciders,
         AllocationService service = new AllocationService(allocationDeciders,
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         ClusterState state = createInitialClusterState(service, Settings.builder().put("index.routing.allocation.initial_recovery._id",
         ClusterState state = createInitialClusterState(service, Settings.builder().put("index.routing.allocation.initial_recovery._id",
             "node2").build());
             "node2").build());
         RoutingTable routingTable = state.routingTable();
         RoutingTable routingTable = state.routingTable();
@@ -73,7 +75,7 @@ public class FilterAllocationDeciderTests extends ESAllocationTestCase {
 
 
         // after failing the shard we are unassigned since the node is blacklisted and we can't initialize on the other node
         // after failing the shard we are unassigned since the node is blacklisted and we can't initialize on the other node
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision.Single decision = (Decision.Single) filterAllocationDecider.canAllocate(
         Decision.Single decision = (Decision.Single) filterAllocationDecider.canAllocate(
             routingTable.index("idx").shard(0).primaryShard(),
             routingTable.index("idx").shard(0).primaryShard(),
@@ -124,7 +126,7 @@ public class FilterAllocationDeciderTests extends ESAllocationTestCase {
         assertEquals(routingTable.index("idx").shard(0).primaryShard().currentNodeId(), "node1");
         assertEquals(routingTable.index("idx").shard(0).primaryShard().currentNodeId(), "node1");
 
 
         allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         decision = (Decision.Single) filterAllocationDecider.canAllocate(
         decision = (Decision.Single) filterAllocationDecider.canAllocate(
             routingTable.index("idx").shard(0).shards().get(0),
             routingTable.index("idx").shard(0).shards().get(0),

+ 1 - 1
server/src/test/java/org/elasticsearch/cluster/routing/allocation/decider/RestoreInProgressAllocationDeciderTests.java

@@ -193,7 +193,7 @@ public class RestoreInProgressAllocationDeciderTests extends ESAllocationTestCas
     private Decision executeAllocation(final ClusterState clusterState, final ShardRouting shardRouting) {
     private Decision executeAllocation(final ClusterState clusterState, final ShardRouting shardRouting) {
         final AllocationDecider decider = new RestoreInProgressAllocationDecider();
         final AllocationDecider decider = new RestoreInProgressAllocationDecider();
         final RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
         final RoutingAllocation allocation = new RoutingAllocation(new AllocationDeciders(Collections.singleton(decider)),
-            clusterState.getRoutingNodes(), clusterState, null, 0L);
+            clusterState.getRoutingNodes(), clusterState, null, null, 0L);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
 
 
         final Decision decision;
         final Decision decision;

+ 3 - 1
server/src/test/java/org/elasticsearch/gateway/GatewayServiceTests.java

@@ -39,6 +39,7 @@ import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.hamcrest.Matchers;
 import org.hamcrest.Matchers;
@@ -60,7 +61,8 @@ public class GatewayServiceTests extends ESTestCase {
         final AllocationService allocationService = new AllocationService(new AllocationDeciders(new HashSet<>(
         final AllocationService allocationService = new AllocationService(new AllocationDeciders(new HashSet<>(
             Arrays.asList(new SameShardAllocationDecider(Settings.EMPTY, new ClusterSettings(Settings.EMPTY,
             Arrays.asList(new SameShardAllocationDecider(Settings.EMPTY, new ClusterSettings(Settings.EMPTY,
                 ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)), new ReplicaAfterPrimaryActiveAllocationDecider()))),
                 ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)), new ReplicaAfterPrimaryActiveAllocationDecider()))),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+            EmptySnapshotsInfoService.INSTANCE);
         return new GatewayService(settings.build(), allocationService, clusterService, null, null, null);
         return new GatewayService(settings.build(), allocationService, clusterService, null, null, null);
     }
     }
 
 

+ 35 - 9
server/src/test/java/org/elasticsearch/gateway/PrimaryShardAllocatorTests.java

@@ -44,6 +44,7 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.env.ShardLockObtainFailedException;
 import org.elasticsearch.env.ShardLockObtainFailedException;
@@ -51,6 +52,7 @@ import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.junit.Before;
 import org.junit.Before;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
@@ -341,7 +343,7 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * deciders say yes, we allocate to that node.
      * deciders say yes, we allocate to that node.
      */
      */
     public void testRestore() {
     public void testRestore() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), "allocId");
+        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), randomLong(), "allocId");
         testAllocator.addData(node1, "some allocId", randomBoolean());
         testAllocator.addData(node1, "some allocId", randomBoolean());
         allocateAllUnassigned(allocation);
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(true));
         assertThat(allocation.routingNodesChanged(), equalTo(true));
@@ -355,7 +357,7 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * deciders say throttle, we add it to ignored shards.
      * deciders say throttle, we add it to ignored shards.
      */
      */
     public void testRestoreThrottle() {
     public void testRestoreThrottle() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(throttleAllocationDeciders(), "allocId");
+        RoutingAllocation allocation = getRestoreRoutingAllocation(throttleAllocationDeciders(), randomLong(), "allocId");
         testAllocator.addData(node1, "some allocId", randomBoolean());
         testAllocator.addData(node1, "some allocId", randomBoolean());
         allocateAllUnassigned(allocation);
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(true));
         assertThat(allocation.routingNodesChanged(), equalTo(true));
@@ -368,12 +370,15 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * deciders say no, we still allocate to that node.
      * deciders say no, we still allocate to that node.
      */
      */
     public void testRestoreForcesAllocateIfShardAvailable() {
     public void testRestoreForcesAllocateIfShardAvailable() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(noAllocationDeciders(), "allocId");
+        final long shardSize = randomNonNegativeLong();
+        RoutingAllocation allocation = getRestoreRoutingAllocation(noAllocationDeciders(), shardSize, "allocId");
         testAllocator.addData(node1, "some allocId", randomBoolean());
         testAllocator.addData(node1, "some allocId", randomBoolean());
         allocateAllUnassigned(allocation);
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(true));
         assertThat(allocation.routingNodesChanged(), equalTo(true));
         assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(true));
         assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(true));
-        assertThat(allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING).size(), equalTo(1));
+        final List<ShardRouting> initializingShards = allocation.routingNodes().shardsWithState(ShardRoutingState.INITIALIZING);
+        assertThat(initializingShards.size(), equalTo(1));
+        assertThat(initializingShards.get(0).getExpectedShardSize(), equalTo(shardSize));
         assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
         assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
     }
     }
 
 
@@ -382,8 +387,8 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
      * the unassigned list to be allocated later.
      * the unassigned list to be allocated later.
      */
      */
     public void testRestoreDoesNotAssignIfNoShardAvailable() {
     public void testRestoreDoesNotAssignIfNoShardAvailable() {
-        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), "allocId");
-        testAllocator.addData(node1, null, false);
+        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), randomNonNegativeLong(), "allocId");
+        testAllocator.addData(node1, null, randomBoolean());
         allocateAllUnassigned(allocation);
         allocateAllUnassigned(allocation);
         assertThat(allocation.routingNodesChanged(), equalTo(false));
         assertThat(allocation.routingNodesChanged(), equalTo(false));
         assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(true));
         assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(true));
@@ -391,7 +396,22 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
         assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
         assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
     }
     }
 
 
-    private RoutingAllocation getRestoreRoutingAllocation(AllocationDeciders allocationDeciders, String... allocIds) {
+    /**
+     * Tests that when restoring from a snapshot and we don't know the shard size yet, the shard is moved to the
+     * ignored unassigned list with allocation status FETCHING_SHARD_DATA, to be allocated later.
+     */
+    public void testRestoreDoesNotAssignIfShardSizeNotAvailable() {
+        RoutingAllocation allocation = getRestoreRoutingAllocation(yesAllocationDeciders(), null, "allocId");
+        testAllocator.addData(node1, null, false);
+        allocateAllUnassigned(allocation);
+        assertThat(allocation.routingNodesChanged(), equalTo(true));
+        assertThat(allocation.routingNodes().unassigned().ignored().isEmpty(), equalTo(false));
+        ShardRouting ignoredRouting = allocation.routingNodes().unassigned().ignored().get(0);
+        assertThat(ignoredRouting.unassignedInfo().getLastAllocationStatus(), equalTo(AllocationStatus.FETCHING_SHARD_DATA));
+        assertClusterHealthStatus(allocation, ClusterHealthStatus.YELLOW);
+    }
+
+    private RoutingAllocation getRestoreRoutingAllocation(AllocationDeciders allocationDeciders, Long shardSize, String... allocIds) {
         Metadata metadata = Metadata.builder()
         Metadata metadata = Metadata.builder()
             .put(IndexMetadata.builder(shardId.getIndexName()).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0)
             .put(IndexMetadata.builder(shardId.getIndexName()).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0)
                 .putInSyncAllocationIds(0, Sets.newHashSet(allocIds)))
                 .putInSyncAllocationIds(0, Sets.newHashSet(allocIds)))
@@ -407,7 +427,13 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
             .metadata(metadata)
             .metadata(metadata)
             .routingTable(routingTable)
             .routingTable(routingTable)
             .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
             .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(allocationDeciders, new RoutingNodes(state, false), state, null, System.nanoTime());
+        return new RoutingAllocation(allocationDeciders, new RoutingNodes(state, false), state, null,
+            new SnapshotShardSizeInfo(ImmutableOpenMap.of()) {
+                @Override
+                public Long getShardSize(ShardRouting shardRouting) {
+                    return shardSize;
+                }
+            }, System.nanoTime());
     }
     }
 
 
     private RoutingAllocation routingAllocationWithOnePrimaryNoReplicas(AllocationDeciders deciders, UnassignedInfo.Reason reason,
     private RoutingAllocation routingAllocationWithOnePrimaryNoReplicas(AllocationDeciders deciders, UnassignedInfo.Reason reason,
@@ -435,7 +461,7 @@ public class PrimaryShardAllocatorTests extends ESAllocationTestCase {
                 .metadata(metadata)
                 .metadata(metadata)
                 .routingTable(routingTableBuilder.build())
                 .routingTable(routingTableBuilder.build())
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, null, System.nanoTime());
+        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, null, null, System.nanoTime());
     }
     }
 
 
     private void assertClusterHealthStatus(RoutingAllocation allocation, ClusterHealthStatus expectedStatus) {
     private void assertClusterHealthStatus(RoutingAllocation allocation, ClusterHealthStatus expectedStatus) {

+ 5 - 2
server/src/test/java/org/elasticsearch/gateway/ReplicaShardAllocatorTests.java

@@ -54,6 +54,7 @@ import org.elasticsearch.index.store.Store;
 import org.elasticsearch.index.store.StoreFileMetadata;
 import org.elasticsearch.index.store.StoreFileMetadata;
 import org.elasticsearch.indices.store.TransportNodesListShardStoreMetadata;
 import org.elasticsearch.indices.store.TransportNodesListShardStoreMetadata;
 import org.elasticsearch.cluster.ESAllocationTestCase;
 import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.junit.Before;
 import org.junit.Before;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
@@ -472,7 +473,8 @@ public class ReplicaShardAllocatorTests extends ESAllocationTestCase {
                 .metadata(metadata)
                 .metadata(metadata)
                 .routingTable(routingTable)
                 .routingTable(routingTable)
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, System.nanoTime());
+        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY,
+            System.nanoTime());
     }
     }
 
 
     private RoutingAllocation onePrimaryOnNode1And1ReplicaRecovering(AllocationDeciders deciders, UnassignedInfo unassignedInfo) {
     private RoutingAllocation onePrimaryOnNode1And1ReplicaRecovering(AllocationDeciders deciders, UnassignedInfo unassignedInfo) {
@@ -495,7 +497,8 @@ public class ReplicaShardAllocatorTests extends ESAllocationTestCase {
                 .metadata(metadata)
                 .metadata(metadata)
                 .routingTable(routingTable)
                 .routingTable(routingTable)
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
                 .nodes(DiscoveryNodes.builder().add(node1).add(node2).add(node3)).build();
-        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, System.nanoTime());
+        return new RoutingAllocation(deciders, new RoutingNodes(state, false), state, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY,
+            System.nanoTime());
     }
     }
 
 
     private RoutingAllocation onePrimaryOnNode1And1ReplicaRecovering(AllocationDeciders deciders) {
     private RoutingAllocation onePrimaryOnNode1And1ReplicaRecovering(AllocationDeciders deciders) {

+ 2 - 1
server/src/test/java/org/elasticsearch/indices/cluster/ClusterStateChanges.java

@@ -93,6 +93,7 @@ import org.elasticsearch.index.shard.IndexEventListener;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.ShardLimitValidator;
 import org.elasticsearch.indices.ShardLimitValidator;
 import org.elasticsearch.indices.SystemIndices;
 import org.elasticsearch.indices.SystemIndices;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.transport.Transport;
@@ -148,7 +149,7 @@ public class ClusterStateChanges {
                 new ReplicaAfterPrimaryActiveAllocationDecider(),
                 new ReplicaAfterPrimaryActiveAllocationDecider(),
                 new RandomAllocationDeciderTests.RandomAllocationDecider(getRandom())))),
                 new RandomAllocationDeciderTests.RandomAllocationDecider(getRandom())))),
             new TestGatewayAllocator(), new BalancedShardsAllocator(SETTINGS),
             new TestGatewayAllocator(), new BalancedShardsAllocator(SETTINGS),
-            EmptyClusterInfoService.INSTANCE);
+            EmptyClusterInfoService.INSTANCE, EmptySnapshotsInfoService.INSTANCE);
         shardFailedClusterStateTaskExecutor
         shardFailedClusterStateTaskExecutor
             = new ShardStateAction.ShardFailedClusterStateTaskExecutor(allocationService, null, logger);
             = new ShardStateAction.ShardFailedClusterStateTaskExecutor(allocationService, null, logger);
         shardStartedClusterStateTaskExecutor
         shardStartedClusterStateTaskExecutor

+ 350 - 0
server/src/test/java/org/elasticsearch/snapshots/InternalSnapshotsInfoServiceTests.java

@@ -0,0 +1,350 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.snapshots;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.IndexRoutingTable;
+import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
+import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.RerouteService;
+import org.elasticsearch.cluster.routing.RoutingTable;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.routing.TestShardRouting;
+import org.elasticsearch.cluster.service.ClusterApplier;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Priority;
+import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus;
+import org.elasticsearch.repositories.FilterRepository;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.repositories.RepositoriesService;
+import org.elasticsearch.repositories.Repository;
+import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.threadpool.ThreadPoolStats;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.Collections;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_CREATION_DATE;
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS;
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
+import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED;
+import static org.elasticsearch.snapshots.InternalSnapshotsInfoService.INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING;
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThan;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class InternalSnapshotsInfoServiceTests extends ESTestCase {
+
+    private TestThreadPool threadPool;
+    private ClusterService clusterService;
+    private RepositoriesService repositoriesService;
+    private RerouteService rerouteService;
+
+    @Before
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        threadPool = new TestThreadPool(getTestName());
+        clusterService = ClusterServiceUtils.createClusterService(threadPool);
+        repositoriesService = mock(RepositoriesService.class);
+        rerouteService = mock(RerouteService.class);
+        doAnswer(invocation -> { // every mocked reroute call answers its listener right away with the current cluster state
+            @SuppressWarnings("unchecked")
+            final ActionListener<ClusterState> listener = (ActionListener<ClusterState>) invocation.getArguments()[2];
+            listener.onResponse(clusterService.state());
+            return null;
+        }).when(rerouteService).reroute(anyString(), any(Priority.class), any());
+    }
+
+    @After
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+        final boolean terminated = terminate(threadPool);
+        assert terminated; // trips if terminate(threadPool) returned false
+        clusterService.close();
+    }
+
+    /**
+     * Checks that the size of every unassigned primary shard restoring from a snapshot is fetched from the repository.
+     * <p>
+     * The repository used here blocks every fetch on a latch. This lets the test first check that the number of active
+     * generic threads equals {@code min(numberOfShards, maxConcurrentFetches)} and that no shard size is known yet.
+     * Once the latch is released, every shard size is expected to become known, with exactly one repository call and
+     * one reroute per shard.
+     */
+    public void testSnapshotShardSizes() throws Exception {
+        final int maxConcurrentFetches = randomIntBetween(1, 10);
+        final InternalSnapshotsInfoService snapshotsInfoService =
+            new InternalSnapshotsInfoService(Settings.builder()
+                .put(INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING.getKey(), maxConcurrentFetches)
+                .build(), clusterService, () -> repositoriesService, () -> rerouteService);
+
+        final int numberOfShards = randomIntBetween(1, 50);
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final long[] expectedShardSizes = new long[numberOfShards];
+        for (int i = 0; i < expectedShardSizes.length; i++) {
+            expectedShardSizes[i] = randomNonNegativeLong();
+        }
+
+        final AtomicInteger getShardSnapshotStatusCount = new AtomicInteger(0);
+        // every fetch waits on this latch, which keeps the fetching threads busy until the test releases them
+        final CountDownLatch latch = new CountDownLatch(1);
+        final Repository mockRepository = new FilterRepository(mock(Repository.class)) {
+            @Override
+            public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId shardId) {
+                try {
+                    assertThat(indexId.getName(), equalTo(indexName));
+                    assertThat(shardId.id(), allOf(greaterThanOrEqualTo(0), lessThan(numberOfShards)));
+                    latch.await();
+                    getShardSnapshotStatusCount.incrementAndGet();
+                    return IndexShardSnapshotStatus.newDone(0L, 0L, 0, 0, 0L, expectedShardSizes[shardId.id()], null);
+                } catch (InterruptedException e) {
+                    // restore the interrupt flag before failing so that the calling thread still observes the interruption
+                    Thread.currentThread().interrupt();
+                    throw new AssertionError(e);
+                }
+            }
+        };
+        when(repositoriesService.repository("_repo")).thenReturn(mockRepository);
+
+        applyClusterState("add-unassigned-shards", clusterState -> addUnassignedShards(clusterState, indexName, numberOfShards));
+        waitForMaxActiveGenericThreads(Math.min(numberOfShards, maxConcurrentFetches));
+
+        if (randomBoolean()) {
+            applyClusterState("reapply-last-cluster-state-to-check-deduplication-works",
+                state -> ClusterState.builder(state).incrementVersion().build());
+        }
+
+        // all fetches are still blocked on the latch: every shard size is pending, none is known
+        assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(numberOfShards));
+        assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(0));
+
+        latch.countDown();
+
+        assertBusy(() -> {
+            assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(numberOfShards));
+            assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0));
+            assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(), equalTo(0));
+            // NOTE(review): the reroute is presumably requested by the service from its fetching threads, possibly after
+            // the shard size became visible, so it is polled here instead of being checked once after the busy loop
+            verify(rerouteService, times(numberOfShards)).reroute(anyString(), any(Priority.class), any());
+        });
+        assertThat(getShardSnapshotStatusCount.get(), equalTo(numberOfShards));
+
+        for (int i = 0; i < numberOfShards; i++) {
+            final ShardRouting shardRouting = clusterService.state().routingTable().index(indexName).shard(i).primaryShard();
+            assertThat(snapshotsInfoService.snapshotShardSizes().getShardSize(shardRouting), equalTo(expectedShardSizes[i]));
+        }
+    }
+
+    /**
+     * Checks the bookkeeping of the service when the repository randomly fails to return a shard snapshot status.
+     * <p>
+     * The repository records, per snapshot shard, whether its last fetch succeeded or threw. After a random number of
+     * indices has been added to the cluster state, the numbers of known and failed shard sizes reported by the service
+     * must match the recorded outcomes and no shard size may remain unknown.
+     */
+    public void testErroneousSnapshotShardSizes() throws Exception {
+        final Settings settings = Settings.builder()
+            .put(INTERNAL_SNAPSHOT_INFO_MAX_CONCURRENT_FETCHES_SETTING.getKey(), randomIntBetween(1, 10))
+            .build();
+        final InternalSnapshotsInfoService snapshotsInfoService =
+            new InternalSnapshotsInfoService(settings, clusterService, () -> repositoriesService, () -> rerouteService);
+
+        // outcome of the last fetch of each snapshot shard: TRUE when a status was returned, FALSE when the fetch threw
+        final Map<InternalSnapshotsInfoService.SnapshotShard, Boolean> results = new ConcurrentHashMap<>();
+        final Repository mockRepository = new FilterRepository(mock(Repository.class)) {
+            @Override
+            public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId shardId) {
+                final Snapshot snapshot = new Snapshot("_repo", snapshotId);
+                final InternalSnapshotsInfoService.SnapshotShard snapshotShard =
+                    new InternalSnapshotsInfoService.SnapshotShard(snapshot, indexId, shardId);
+                final boolean failure = randomBoolean();
+                results.put(snapshotShard, failure ? Boolean.FALSE : Boolean.TRUE);
+                if (failure) {
+                    throw new SnapshotException(snapshotShard.snapshot(), "simulated");
+                }
+                return IndexShardSnapshotStatus.newDone(0L, 0L, 0, 0, 0L, randomNonNegativeLong(), null);
+            }
+        };
+        when(repositoriesService.repository("_repo")).thenReturn(mockRepository);
+
+        final int maxShardsToCreate = scaledRandomIntBetween(10, 500);
+        final Thread addSnapshotRestoreIndicesThread = new Thread(() -> {
+            int shardsLeft = maxShardsToCreate;
+            while (shardsLeft > 0) {
+                final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+                final int numberOfShards = randomIntBetween(1, shardsLeft);
+                try {
+                    applyClusterState("add-more-unassigned-shards-for-" + indexName,
+                        clusterState -> addUnassignedShards(clusterState, indexName, numberOfShards));
+                } catch (Exception e) {
+                    throw new AssertionError(e);
+                }
+                shardsLeft -= numberOfShards;
+            }
+        });
+        addSnapshotRestoreIndicesThread.start();
+        addSnapshotRestoreIndicesThread.join();
+
+        assertBusy(() -> {
+            final long succeeded = results.values().stream().filter(Boolean.TRUE::equals).count();
+            assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo((int) succeeded));
+            final long failed = results.values().stream().filter(Boolean.FALSE::equals).count();
+            assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(), equalTo((int) failed));
+            assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0));
+        });
+    }
+
+    /**
+     * Checks that the service reports no known, unknown or failed shard sizes once another node has been made master.
+     * <p>
+     * Indices restoring from a snapshot are added to the cluster state, then a new node is added and set as master by
+     * {@link #demoteMasterNode(ClusterState)}, then more indices are added. All three counters of the service are
+     * expected to end up at zero.
+     */
+    public void testNoLongerMaster() throws Exception {
+        final InternalSnapshotsInfoService snapshotsInfoService =
+            new InternalSnapshotsInfoService(Settings.EMPTY, clusterService, () -> repositoriesService, () -> rerouteService);
+
+        final Repository mockRepository = new FilterRepository(mock(Repository.class)) {
+            @Override
+            public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId shardId) {
+                return IndexShardSnapshotStatus.newDone(0L, 0L, 0, 0, 0L, randomNonNegativeLong(), null);
+            }
+        };
+        when(repositoriesService.repository("_repo")).thenReturn(mockRepository);
+
+        // draw the number of indices once: a random bound in the loop condition would be redrawn on every iteration
+        final int nbIndicesWhenMaster = randomIntBetween(1, 10);
+        for (int i = 0; i < nbIndicesWhenMaster; i++) {
+            final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+            final int nbShards =  randomIntBetween(1, 5);
+            applyClusterState("restore-indices-when-master-" + indexName,
+                clusterState -> addUnassignedShards(clusterState, indexName, nbShards));
+        }
+
+        applyClusterState("demote-current-master", this::demoteMasterNode);
+
+        final int nbIndicesWhenNoLongerMaster = randomIntBetween(1, 10);
+        for (int i = 0; i < nbIndicesWhenNoLongerMaster; i++) {
+            final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+            final int nbShards =  randomIntBetween(1, 5);
+            applyClusterState("restore-indices-when-no-longer-master-" + indexName,
+                clusterState -> addUnassignedShards(clusterState, indexName, nbShards));
+        }
+
+        assertBusy(() -> {
+            assertThat(snapshotsInfoService.numberOfKnownSnapshotShardSizes(), equalTo(0));
+            assertThat(snapshotsInfoService.numberOfUnknownSnapshotShardSizes(), equalTo(0));
+            assertThat(snapshotsInfoService.numberOfFailedSnapshotShardSizes(), equalTo(0));
+        });
+    }
+
+    /**
+     * Submits a new cluster state to the cluster applier service and blocks until it reports success or failure.
+     *
+     * @param reason  the source passed to the applier service for this cluster state
+     * @param applier computes the new cluster state from the state returned by {@code clusterService.state()}
+     */
+    private void applyClusterState(final String reason, final Function<ClusterState, ClusterState> applier) {
+        PlainActionFuture.get(future -> {
+            // bridges the applier callbacks to the future the caller is blocked on
+            final ClusterApplier.ClusterApplyListener applyListener = new ClusterApplier.ClusterApplyListener() {
+                @Override
+                public void onSuccess(String source) {
+                    future.onResponse(source);
+                }
+
+                @Override
+                public void onFailure(String source, Exception e) {
+                    future.onFailure(e);
+                }
+            };
+            clusterService.getClusterApplierService()
+                .onNewClusterState(reason, () -> applier.apply(clusterService.state()), applyListener);
+        });
+    }
+
+    /**
+     * Waits, for up to 30 seconds, until the stats of the generic thread pool report exactly the given number of
+     * active threads.
+     *
+     * @param nbActive the expected number of active threads in the generic thread pool
+     */
+    private void waitForMaxActiveGenericThreads(final int nbActive) throws Exception {
+        assertBusy(() -> {
+            ThreadPoolStats.Stats genericStats = null;
+            for (ThreadPoolStats.Stats stats : clusterService.getClusterApplierService().threadPool().stats()) {
+                if (ThreadPool.Names.GENERIC.equals(stats.getName())) {
+                    genericStats = stats;
+                }
+            }
+            assertThat(genericStats, notNullValue());
+            assertThat(genericStats.getActive(), equalTo(nbActive));
+        }, 30L, TimeUnit.SECONDS);
+    }
+
+    /**
+     * Builds a cluster state containing a new index whose shards are all unassigned.
+     * <p>
+     * Each primary uses a snapshot recovery source pointing at a randomly named snapshot in repository {@code _repo};
+     * each replica (zero or one per shard, chosen randomly) uses the peer recovery source.
+     *
+     * @param currentState   the cluster state to extend; it must not already contain {@code indexName}
+     * @param indexName      the name of the index to add
+     * @param numberOfShards the number of primary shards of the new index
+     * @return a new cluster state with the index added to the metadata and to the routing table
+     */
+    private ClusterState addUnassignedShards(final ClusterState currentState, String indexName, int numberOfShards) {
+        assertThat(currentState.metadata().hasIndex(indexName), is(false));
+
+        final Settings.Builder indexSettings = Settings.builder()
+            .put(SETTING_VERSION_CREATED, Version.CURRENT)
+            .put(SETTING_NUMBER_OF_SHARDS, numberOfShards)
+            .put(SETTING_NUMBER_OF_REPLICAS, randomIntBetween(0, 1))
+            .put(SETTING_CREATION_DATE, System.currentTimeMillis());
+        final Metadata.Builder metadata = Metadata.builder(currentState.metadata())
+            .put(IndexMetadata.builder(indexName).settings(indexSettings).build(), true)
+            .generateClusterUuidIfNeeded();
+
+        // the random values below are drawn in the same order as the recovery source arguments are declared
+        final String restoreUUID = UUIDs.randomBase64UUID(random());
+        final SnapshotId snapshotId = new SnapshotId(randomAlphaOfLength(5), UUIDs.randomBase64UUID(random()));
+        final IndexId indexId = new IndexId(indexName, UUIDs.randomBase64UUID(random()));
+        final RecoverySource.SnapshotRecoverySource recoverySource =
+            new RecoverySource.SnapshotRecoverySource(restoreUUID, new Snapshot("_repo", snapshotId), Version.CURRENT, indexId);
+
+        final IndexMetadata indexMetadata = metadata.get(indexName);
+        final Index index = indexMetadata.getIndex();
+        final int numberOfReplicas = indexMetadata.getNumberOfReplicas();
+
+        final IndexRoutingTable.Builder indexRoutingTable = IndexRoutingTable.builder(index);
+        for (int shard = 0; shard < numberOfShards; shard++) {
+            final ShardId shardId = new ShardId(index, shard);
+            final IndexShardRoutingTable.Builder shardRoutings = new IndexShardRoutingTable.Builder(shardId);
+
+            final ShardRouting primary =
+                TestShardRouting.newShardRouting(shardId, null, true, ShardRoutingState.UNASSIGNED, recoverySource);
+            shardRoutings.addShard(primary);
+
+            for (int copy = 0; copy < numberOfReplicas; copy++) {
+                final ShardRouting replica = TestShardRouting.newShardRouting(shardId, null, false, ShardRoutingState.UNASSIGNED,
+                    RecoverySource.PeerRecoverySource.INSTANCE);
+                shardRoutings.addShard(replica);
+            }
+            indexRoutingTable.addIndexShard(shardRoutings.build());
+        }
+
+        final RoutingTable.Builder routingTable = RoutingTable.builder(currentState.routingTable());
+        routingTable.add(indexRoutingTable.build());
+
+        final ClusterState.Builder newState = ClusterState.builder(currentState);
+        newState.routingTable(routingTable.build());
+        newState.metadata(metadata);
+        return newState.build();
+    }
+
+    /**
+     * Builds a cluster state in which a newly added node with id {@code other} is the master node.
+     *
+     * @param currentState the cluster state to derive from; it must not already contain a node with id {@code other}
+     * @return a copy of {@code currentState} with the extra node added and set as master
+     */
+    private ClusterState demoteMasterNode(final ClusterState currentState) {
+        final DiscoveryNode otherNode = new DiscoveryNode("other", ESTestCase.buildNewFakeTransportAddress(), Collections.emptyMap(),
+            DiscoveryNodeRole.BUILT_IN_ROLES, Version.CURRENT);
+        final String otherNodeId = otherNode.getId();
+        assertThat(currentState.nodes().get(otherNodeId), nullValue());
+
+        final DiscoveryNodes.Builder nodes = DiscoveryNodes.builder(currentState.nodes());
+        nodes.add(otherNode);
+        nodes.masterNodeId(otherNodeId);
+        return ClusterState.builder(currentState).nodes(nodes).build();
+    }
+}

+ 13 - 8
server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java

@@ -65,7 +65,6 @@ import org.elasticsearch.action.admin.indices.mapping.put.TransportPutMappingAct
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresAction;
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresAction;
 import org.elasticsearch.action.admin.indices.shards.TransportIndicesShardStoresAction;
 import org.elasticsearch.action.admin.indices.shards.TransportIndicesShardStoresAction;
 import org.elasticsearch.action.bulk.BulkAction;
 import org.elasticsearch.action.bulk.BulkAction;
-import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.bulk.TransportBulkAction;
 import org.elasticsearch.action.bulk.TransportBulkAction;
@@ -125,6 +124,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.BatchedRerouteService;
 import org.elasticsearch.cluster.routing.BatchedRerouteService;
+import org.elasticsearch.cluster.routing.RerouteService;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.allocation.AllocationService;
 import org.elasticsearch.cluster.routing.allocation.AllocationService;
@@ -152,6 +152,7 @@ import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.gateway.MetaStateService;
 import org.elasticsearch.gateway.MetaStateService;
 import org.elasticsearch.gateway.TransportNodesListGatewayStartedShards;
 import org.elasticsearch.gateway.TransportNodesListGatewayStartedShards;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.Index;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.index.seqno.GlobalCheckpointSyncAction;
 import org.elasticsearch.index.seqno.GlobalCheckpointSyncAction;
 import org.elasticsearch.index.seqno.RetentionLeaseSyncer;
 import org.elasticsearch.index.seqno.RetentionLeaseSyncer;
@@ -1387,6 +1388,8 @@ public class SnapshotResiliencyTests extends ESTestCase {
 
 
             private final AllocationService allocationService;
             private final AllocationService allocationService;
 
 
+            private final RerouteService rerouteService;
+
             private final NodeClient client;
             private final NodeClient client;
 
 
             private final NodeEnvironment nodeEnv;
             private final NodeEnvironment nodeEnv;
@@ -1487,7 +1490,12 @@ public class SnapshotResiliencyTests extends ESTestCase {
                 final NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList());
                 final NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(Collections.emptyList());
                 final ScriptService scriptService = new ScriptService(settings, emptyMap(), emptyMap());
                 final ScriptService scriptService = new ScriptService(settings, emptyMap(), emptyMap());
                 client = new NodeClient(settings, threadPool);
                 client = new NodeClient(settings, threadPool);
-                allocationService = ESAllocationTestCase.createAllocationService(settings);
+                final SetOnce<RerouteService> rerouteServiceSetOnce = new SetOnce<>();
+                final SnapshotsInfoService snapshotsInfoService = new InternalSnapshotsInfoService(settings, clusterService,
+                    () -> repositoriesService, rerouteServiceSetOnce::get);
+                allocationService = ESAllocationTestCase.createAllocationService(settings, snapshotsInfoService);
+                rerouteService = new BatchedRerouteService(clusterService, allocationService::reroute);
+                rerouteServiceSetOnce.set(rerouteService);
                 final IndexScopedSettings indexScopedSettings =
                 final IndexScopedSettings indexScopedSettings =
                     new IndexScopedSettings(settings, IndexScopedSettings.BUILT_IN_INDEX_SETTINGS);
                     new IndexScopedSettings(settings, IndexScopedSettings.BUILT_IN_INDEX_SETTINGS);
                 final BigArrays bigArrays = new BigArrays(new PageCacheRecycler(settings), null, "test");
                 final BigArrays bigArrays = new BigArrays(new PageCacheRecycler(settings), null, "test");
@@ -1518,11 +1526,8 @@ public class SnapshotResiliencyTests extends ESTestCase {
                 final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings);
                 final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings);
                 snapshotShardsService =
                 snapshotShardsService =
                         new SnapshotShardsService(settings, clusterService, repositoriesService, transportService, indicesService);
                         new SnapshotShardsService(settings, clusterService, repositoriesService, transportService, indicesService);
-                final ShardStateAction shardStateAction = new ShardStateAction(
-                    clusterService, transportService, allocationService,
-                    new BatchedRerouteService(clusterService, allocationService::reroute),
-                    threadPool
-                );
+                final ShardStateAction shardStateAction =
+                    new ShardStateAction(clusterService, transportService, allocationService, rerouteService, threadPool);
                 nodeConnectionsService =
                 nodeConnectionsService =
                     new NodeConnectionsService(clusterService.getSettings(), threadPool, transportService);
                     new NodeConnectionsService(clusterService.getSettings(), threadPool, transportService);
                 @SuppressWarnings("rawtypes")
                 @SuppressWarnings("rawtypes")
@@ -1717,7 +1722,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
                     hostsResolver -> nodes.values().stream().filter(n -> n.node.isMasterNode())
                     hostsResolver -> nodes.values().stream().filter(n -> n.node.isMasterNode())
                         .map(n -> n.node.getAddress()).collect(Collectors.toList()),
                         .map(n -> n.node.getAddress()).collect(Collectors.toList()),
                     clusterService.getClusterApplierService(), Collections.emptyList(), random(),
                     clusterService.getClusterApplierService(), Collections.emptyList(), random(),
-                    new BatchedRerouteService(clusterService, allocationService::reroute), ElectionStrategy.DEFAULT_INSTANCE,
+                    rerouteService, ElectionStrategy.DEFAULT_INSTANCE,
                     () -> new StatusInfo(HEALTHY, "healthy-info"));
                     () -> new StatusInfo(HEALTHY, "healthy-info"));
                 masterService.setClusterStatePublisher(coordinator);
                 masterService.setClusterStatePublisher(coordinator);
                 coordinator.start();
                 coordinator.start();

+ 35 - 6
test/framework/src/main/java/org/elasticsearch/cluster/ESAllocationTestCase.java

@@ -22,6 +22,7 @@ package org.elasticsearch.cluster;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
@@ -34,9 +35,12 @@ import org.elasticsearch.cluster.routing.allocation.decider.AllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.Decision;
 import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationDecider;
 import org.elasticsearch.cluster.routing.allocation.decider.SameShardAllocationDecider;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.gateway.GatewayAllocator;
 import org.elasticsearch.gateway.GatewayAllocator;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+import org.elasticsearch.snapshots.SnapshotsInfoService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 
 
@@ -55,6 +59,16 @@ public abstract class ESAllocationTestCase extends ESTestCase {
     private static final ClusterSettings EMPTY_CLUSTER_SETTINGS =
     private static final ClusterSettings EMPTY_CLUSTER_SETTINGS =
         new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
         new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
 
 
+    /**
+     * A {@link SnapshotsInfoService} that holds no snapshot shard size. The {@link SnapshotShardSizeInfo} it returns
+     * asserts that the shard it is asked about recovers from a snapshot and then rejects the lookup with an
+     * {@link UnsupportedOperationException}.
+     */
+    public static final SnapshotsInfoService SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES = () ->
+        new SnapshotShardSizeInfo(ImmutableOpenMap.of()) {
+            @Override
+            public Long getShardSize(ShardRouting shardRouting) {
+                assert shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT :
+                    "Expecting a recovery source of type [SNAPSHOT] but got [" + shardRouting.recoverySource().getType() + ']';
+                // say why the lookup is rejected so that a test hitting this path is easy to diagnose
+                throw new UnsupportedOperationException("snapshot shard sizes are not supported by this SnapshotsInfoService");
+            }
+        };
+
     public static MockAllocationService createAllocationService() {
     public static MockAllocationService createAllocationService() {
         return createAllocationService(Settings.Builder.EMPTY_SETTINGS);
         return createAllocationService(Settings.Builder.EMPTY_SETTINGS);
     }
     }
@@ -70,19 +84,33 @@ public abstract class ESAllocationTestCase extends ESTestCase {
     public static MockAllocationService createAllocationService(Settings settings, ClusterSettings clusterSettings, Random random) {
     public static MockAllocationService createAllocationService(Settings settings, ClusterSettings clusterSettings, Random random) {
         return new MockAllocationService(
         return new MockAllocationService(
                 randomAllocationDeciders(settings, clusterSettings, random),
                 randomAllocationDeciders(settings, clusterSettings, random),
-                new TestGatewayAllocator(), new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE);
+                new TestGatewayAllocator(), new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE,
+            SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES);
     }
     }
 
 
     public static MockAllocationService createAllocationService(Settings settings, ClusterInfoService clusterInfoService) {
     public static MockAllocationService createAllocationService(Settings settings, ClusterInfoService clusterInfoService) {
         return new MockAllocationService(
         return new MockAllocationService(
                 randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
                 randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
-            new TestGatewayAllocator(), new BalancedShardsAllocator(settings), clusterInfoService);
+            new TestGatewayAllocator(), new BalancedShardsAllocator(settings), clusterInfoService,
+            SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES);
     }
     }
 
 
     public static MockAllocationService createAllocationService(Settings settings, GatewayAllocator gatewayAllocator) {
     public static MockAllocationService createAllocationService(Settings settings, GatewayAllocator gatewayAllocator) {
+        return createAllocationService(settings, gatewayAllocator, SNAPSHOT_INFO_SERVICE_WITH_NO_SHARD_SIZES);
+    }
+
+    /**
+     * Creates a {@link MockAllocationService} that uses a new {@link TestGatewayAllocator} together with the given
+     * snapshots info service.
+     */
+    public static MockAllocationService createAllocationService(Settings settings, SnapshotsInfoService snapshotsInfoService) {
+        final GatewayAllocator gatewayAllocator = new TestGatewayAllocator();
+        return createAllocationService(settings, gatewayAllocator, snapshotsInfoService);
+    }
+
+    public static MockAllocationService createAllocationService(
+        Settings settings,
+        GatewayAllocator gatewayAllocator,
+        SnapshotsInfoService snapshotsInfoService
+    ) {
         return new MockAllocationService(
         return new MockAllocationService(
-                randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
-                gatewayAllocator, new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE);
+            randomAllocationDeciders(settings, EMPTY_CLUSTER_SETTINGS, random()),
+            gatewayAllocator, new BalancedShardsAllocator(settings), EmptyClusterInfoService.INSTANCE, snapshotsInfoService);
     }
     }
 
 
     public static AllocationDeciders randomAllocationDeciders(Settings settings, ClusterSettings clusterSettings, Random random) {
     public static AllocationDeciders randomAllocationDeciders(Settings settings, ClusterSettings clusterSettings, Random random) {
@@ -230,8 +258,9 @@ public abstract class ESAllocationTestCase extends ESTestCase {
         private volatile long nanoTimeOverride = -1L;
         private volatile long nanoTimeOverride = -1L;
 
 
         public MockAllocationService(AllocationDeciders allocationDeciders, GatewayAllocator gatewayAllocator,
         public MockAllocationService(AllocationDeciders allocationDeciders, GatewayAllocator gatewayAllocator,
-                                     ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService) {
-            super(allocationDeciders, gatewayAllocator, shardsAllocator, clusterInfoService);
+                                     ShardsAllocator shardsAllocator, ClusterInfoService clusterInfoService,
+                                     SnapshotsInfoService snapshotsInfoService) {
+            super(allocationDeciders, gatewayAllocator, shardsAllocator, clusterInfoService, snapshotsInfoService);
         }
         }
 
 
         public void setNanoTimeOverride(long nanoTime) {
         public void setNanoTimeOverride(long nanoTime) {

+ 2 - 1
test/framework/src/main/java/org/elasticsearch/cluster/MockInternalClusterInfoService.java

@@ -47,7 +47,8 @@ public class MockInternalClusterInfoService extends InternalClusterInfoService {
     @Nullable // if no fakery should take place
     @Nullable // if no fakery should take place
     private volatile BiFunction<DiscoveryNode, FsInfo.Path, FsInfo.Path> diskUsageFunction;
     private volatile BiFunction<DiscoveryNode, FsInfo.Path, FsInfo.Path> diskUsageFunction;
 
 
-    public MockInternalClusterInfoService(Settings settings, ClusterService clusterService, ThreadPool threadPool, NodeClient client) {
+    public MockInternalClusterInfoService(Settings settings, ClusterService clusterService,
+                                          ThreadPool threadPool, NodeClient client) {
         super(settings, clusterService, threadPool, client);
         super(settings, clusterService, threadPool, client);
     }
     }
 
 

+ 21 - 2
x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.index.IndexCommit;
 import org.apache.lucene.index.IndexCommit;
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
@@ -18,6 +19,8 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
 import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest;
+import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
+import org.elasticsearch.action.admin.indices.stats.ShardStats;
 import org.elasticsearch.action.support.ListenerTimeouts;
 import org.elasticsearch.action.support.ListenerTimeouts;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.action.support.ThreadedActionListener;
@@ -30,6 +33,7 @@ import org.elasticsearch.cluster.metadata.MappingMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.RepositoryMetadata;
 import org.elasticsearch.cluster.metadata.RepositoryMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
@@ -429,8 +433,23 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit
     }
     }
 
 
     @Override
     @Override
-    public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId indexId, ShardId leaderShardId) {
-        throw new UnsupportedOperationException("Unsupported for repository of type: " + TYPE);
+    public IndexShardSnapshotStatus getShardSnapshotStatus(SnapshotId snapshotId, IndexId index, ShardId shardId) {
+        assert SNAPSHOT_ID.equals(snapshotId) : "RemoteClusterRepository only supports " + SNAPSHOT_ID + " as the SnapshotId";
+        final String leaderIndex = index.getName();
+        final IndicesStatsResponse response = getRemoteClusterClient().admin().indices().prepareStats(leaderIndex)
+            .clear().setStore(true)
+            .get(ccrSettings.getRecoveryActionTimeout());
+        for (ShardStats shardStats : response.getIndex(leaderIndex).getShards()) {
+            final ShardRouting shardRouting = shardStats.getShardRouting();
+            if (shardRouting.shardId().id() == shardId.getId()
+                && shardRouting.primary()
+                && shardRouting.active()) {
+                // we only care about the shard size here for shard allocation, populate the rest with dummy values
+                final long totalSize = shardStats.getStats().getStore().getSizeInBytes();
+                return IndexShardSnapshotStatus.newDone(0L, 0L, 1, 1, totalSize, totalSize, "");
+            }
+        }
+        throw new ElasticsearchException("Could not get shard stats for primary of index " + leaderIndex + " on leader cluster");
     }
     }
 
 
     @Override
     @Override

+ 2 - 1
x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/allocation/CcrPrimaryFollowerAllocationDeciderTests.java

@@ -51,6 +51,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
 import org.elasticsearch.xpack.ccr.CcrSettings;
 import org.elasticsearch.xpack.ccr.CcrSettings;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
@@ -181,7 +182,7 @@ public class CcrPrimaryFollowerAllocationDeciderTests extends ESAllocationTestCa
     static Decision executeAllocation(ClusterState clusterState, ShardRouting shardRouting, DiscoveryNode node) {
     static Decision executeAllocation(ClusterState clusterState, ShardRouting shardRouting, DiscoveryNode node) {
         final AllocationDecider decider = new CcrPrimaryFollowerAllocationDecider();
         final AllocationDecider decider = new CcrPrimaryFollowerAllocationDecider();
         final RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(List.of(decider)),
         final RoutingAllocation routingAllocation = new RoutingAllocation(new AllocationDeciders(List.of(decider)),
-            new RoutingNodes(clusterState), clusterState, ClusterInfo.EMPTY, System.nanoTime());
+            new RoutingNodes(clusterState), clusterState, ClusterInfo.EMPTY, SnapshotShardSizeInfo.EMPTY, System.nanoTime());
         routingAllocation.debugDecision(true);
         routingAllocation.debugDecision(true);
         return decider.canAllocate(shardRouting, new RoutingNode(node.getId(), node), routingAllocation);
         return decider.canAllocate(shardRouting, new RoutingNode(node.getId(), node), routingAllocation);
     }
     }

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/AllocationRoutedStep.java

@@ -71,7 +71,7 @@ public class AllocationRoutedStep extends ClusterStateWaitStep {
         // All the allocation attributes are already set so just need to check
         // All the allocation attributes are already set so just need to check
         // if the allocation has happened
         // if the allocation has happened
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, clusterState.getRoutingNodes(), clusterState, null,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, clusterState.getRoutingNodes(), clusterState, null,
-            System.nanoTime());
+                null, System.nanoTime());
 
 
         int allocationPendingAllShards = 0;
         int allocationPendingAllShards = 0;
 
 

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SetSingleNodeAllocateStep.java

@@ -64,7 +64,7 @@ public class SetSingleNodeAllocateStep extends AsyncActionStep {
     public void performAction(IndexMetadata indexMetadata, ClusterState clusterState, ClusterStateObserver observer, Listener listener) {
     public void performAction(IndexMetadata indexMetadata, ClusterState clusterState, ClusterStateObserver observer, Listener listener) {
         final RoutingNodes routingNodes = clusterState.getRoutingNodes();
         final RoutingNodes routingNodes = clusterState.getRoutingNodes();
         RoutingAllocation allocation = new RoutingAllocation(ALLOCATION_DECIDERS, routingNodes, clusterState, null,
         RoutingAllocation allocation = new RoutingAllocation(ALLOCATION_DECIDERS, routingNodes, clusterState, null,
-                System.nanoTime());
+                null, System.nanoTime());
         List<String> validNodeIds = new ArrayList<>();
         List<String> validNodeIds = new ArrayList<>();
         String indexName = indexMetadata.getIndex().getName();
         String indexName = indexMetadata.getIndex().getName();
         final Map<ShardId, List<ShardRouting>> routingsByShardId = clusterState.getRoutingTable()
         final Map<ShardId, List<ShardRouting>> routingsByShardId = clusterState.getRoutingTable()

+ 15 - 17
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/cluster/routing/allocation/DataTierAllocationDeciderTests.java

@@ -30,6 +30,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.EmptySnapshotsInfoService;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.test.gateway.TestGatewayAllocator;
 import org.elasticsearch.xpack.core.DataTier;
 import org.elasticsearch.xpack.core.DataTier;
 
 
@@ -58,7 +59,8 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
             new SameShardAllocationDecider(Settings.EMPTY, clusterSettings),
             new SameShardAllocationDecider(Settings.EMPTY, clusterSettings),
             new ReplicaAfterPrimaryActiveAllocationDecider()));
             new ReplicaAfterPrimaryActiveAllocationDecider()));
     private final AllocationService service = new AllocationService(allocationDeciders,
     private final AllocationService service = new AllocationService(allocationDeciders,
-        new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE);
+        new TestGatewayAllocator(), new BalancedShardsAllocator(Settings.EMPTY), EmptyClusterInfoService.INSTANCE,
+        EmptySnapshotsInfoService.INSTANCE);
 
 
     private final ShardRouting shard = ShardRouting.newUnassigned(new ShardId("myindex", "myindex", 0), true,
     private final ShardRouting shard = ShardRouting.newUnassigned(new ShardId("myindex", "myindex", 0), true,
         RecoverySource.EmptyStoreRecoverySource.INSTANCE, new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "index created"));
         RecoverySource.EmptyStoreRecoverySource.INSTANCE, new UnassignedInfo(UnassignedInfo.Reason.INDEX_CREATED, "index created"));
@@ -74,7 +76,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
     public void testClusterRequires() {
     public void testClusterRequires() {
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         clusterSettings.applySettings(Settings.builder()
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_REQUIRE, "data_hot")
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_REQUIRE, "data_hot")
@@ -108,7 +110,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
     public void testClusterIncludes() {
     public void testClusterIncludes() {
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         clusterSettings.applySettings(Settings.builder()
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_INCLUDE, "data_warm,data_cold")
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_INCLUDE, "data_warm,data_cold")
@@ -143,7 +145,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
     public void testClusterExcludes() {
     public void testClusterExcludes() {
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         ClusterState state = prepareState(service.reroute(ClusterState.EMPTY_STATE, "initial state"));
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         clusterSettings.applySettings(Settings.builder()
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_EXCLUDE, "data_warm")
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_EXCLUDE, "data_warm")
@@ -181,7 +183,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_REQUIRE, "data_hot")
                 .put(DataTierAllocationDecider.INDEX_ROUTING_REQUIRE, "data_hot")
                 .build());
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision d;
         Decision d;
         RoutingNode node;
         RoutingNode node;
@@ -213,7 +215,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_INCLUDE, "data_warm,data_cold")
                 .put(DataTierAllocationDecider.INDEX_ROUTING_INCLUDE, "data_warm,data_cold")
                 .build());
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision d;
         Decision d;
         RoutingNode node;
         RoutingNode node;
@@ -247,7 +249,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_EXCLUDE, "data_warm,data_cold")
                 .put(DataTierAllocationDecider.INDEX_ROUTING_EXCLUDE, "data_warm,data_cold")
                 .build());
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null,0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision d;
         Decision d;
         RoutingNode node;
         RoutingNode node;
@@ -292,8 +294,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                         .build()))
                 .build())
                 .build())
             .build();
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision d;
         Decision d;
         RoutingNode node;
         RoutingNode node;
@@ -328,7 +329,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                         .build()))
                 .build())
                 .build())
             .build();
             .build();
-        allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, 0);
+        allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
 
 
         for (DiscoveryNode n : Arrays.asList(HOT_NODE, WARM_NODE)) {
         for (DiscoveryNode n : Arrays.asList(HOT_NODE, WARM_NODE)) {
@@ -376,8 +377,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                         .build()))
                 .build())
                 .build())
             .build();
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision d;
         Decision d;
         RoutingNode node;
         RoutingNode node;
@@ -439,8 +439,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                         .build()))
                 .build())
                 .build())
             .build();
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision d;
         Decision d;
         RoutingNode node;
         RoutingNode node;
@@ -502,8 +501,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                         .build()))
                         .build()))
                 .build())
                 .build())
             .build();
             .build();
-        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+        RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state, null, null, 0);
         allocation.debugDecision(true);
         allocation.debugDecision(true);
         Decision d;
         Decision d;
         RoutingNode node;
         RoutingNode node;
@@ -553,7 +551,7 @@ public class DataTierAllocationDeciderTests extends ESAllocationTestCase {
                 .put(DataTierAllocationDecider.INDEX_ROUTING_INCLUDE, "data_warm,data_cold")
                 .put(DataTierAllocationDecider.INDEX_ROUTING_INCLUDE, "data_warm,data_cold")
                 .build());
                 .build());
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
         RoutingAllocation allocation = new RoutingAllocation(allocationDeciders, state.getRoutingNodes(), state,
-            null, 0);
+            null, null,0);
         clusterSettings.applySettings(Settings.builder()
         clusterSettings.applySettings(Settings.builder()
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_EXCLUDE, "data_cold")
             .put(DataTierAllocationDecider.CLUSTER_ROUTING_EXCLUDE, "data_cold")
             .build());
             .build());

+ 6 - 0
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocator.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.searchablesnapshots;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.AllocationDecision;
 import org.elasticsearch.cluster.routing.allocation.AllocationDecision;
 import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator;
 import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator;
@@ -84,6 +85,11 @@ public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
             allocation.metadata().getIndexSafe(shardRouting.index()).getSettings()
             allocation.metadata().getIndexSafe(shardRouting.index()).getSettings()
         ).equals(ALLOCATOR_NAME);
         ).equals(ALLOCATOR_NAME);
 
 
+        if (shardRouting.recoverySource().getType() == RecoverySource.Type.SNAPSHOT
+            && allocation.snapshotShardSizeInfo().getShardSize(shardRouting) == null) {
+            return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, null);
+        }
+
         // let BalancedShardsAllocator take care of allocating this shard
         // let BalancedShardsAllocator take care of allocating this shard
         // TODO: once we have persistent cache, choose a node that has existing data
         // TODO: once we have persistent cache, choose a node that has existing data
         return AllocateUnassignedDecision.NOT_TAKEN;
         return AllocateUnassignedDecision.NOT_TAKEN;