Browse Source

Add recovery state tracking for Searchable Snapshots (#60505)

This pull request adds recovery state tracking for Searchable Snapshots.

In order to track recoveries for searchable snapshot backed indices, this pull
request adds a new type of RecoveryState.
This newRecoveryState instance is able to deal with the
small differences that arise during Searchable snapshots recoveries.

Those differences can be summarized as follows:

-  The Directory implementation that's provided by SearchableSnapshots mark the
    snapshot files as reused during recovery. In order to keep track of the
    recovery process as the cache is pre-warmed, those files shouldn't be marked
    as reused.
 - Once the shard is created, the cache starts its pre-warming phase, meaning that
    we should keep track of those downloads during that process and tie the recovery
    to this pre-warming phase. The shard is considered recovered once this pre-warming
    phase has finished.
Francisco Fernández Castaño 5 years ago
parent
commit
9b71cdea7e
17 changed files with 853 additions and 132 deletions
  1. 2 2
      server/src/internalClusterTest/java/org/elasticsearch/gateway/RecoveryFromGatewayIT.java
  2. 45 32
      server/src/main/java/org/elasticsearch/indices/recovery/RecoveryState.java
  3. 1 1
      server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java
  4. 10 10
      server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTargetTests.java
  5. 2 2
      server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java
  6. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsConstants.java
  7. 73 46
      x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/index/store/SearchableSnapshotDirectory.java
  8. 126 0
      x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/indices/recovery/SearchableSnapshotRecoveryState.java
  9. 1 1
      x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotIndexEventListener.java
  10. 12 1
      x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshots.java
  11. 2 0
      x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/TransportMountSearchableSnapshotAction.java
  12. 16 1
      x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/index/store/SearchableSnapshotDirectoryStatsTests.java
  13. 167 3
      x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/index/store/SearchableSnapshotDirectoryTests.java
  14. 29 2
      x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/index/store/cache/CachedBlobContainerIndexInputTests.java
  15. 141 0
      x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/indices/recovery/SearchableSnapshotsRecoveryStateTests.java
  16. 159 0
      x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotRecoveryStateIntegrationTests.java
  17. 65 31
      x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java

+ 2 - 2
server/src/internalClusterTest/java/org/elasticsearch/gateway/RecoveryFromGatewayIT.java

@@ -451,7 +451,7 @@ public class RecoveryFromGatewayIT extends ESIntegTestCase {
         final Set<String> files = new HashSet<>();
         for (final RecoveryState recoveryState : initialRecoveryReponse.shardRecoveryStates().get("test")) {
             if (recoveryState.getTargetNode().getName().equals(replicaNode)) {
-                for (final RecoveryState.File file : recoveryState.getIndex().fileDetails()) {
+                for (final RecoveryState.FileDetail file : recoveryState.getIndex().fileDetails()) {
                     files.add(file.name());
                 }
                 break;
@@ -494,7 +494,7 @@ public class RecoveryFromGatewayIT extends ESIntegTestCase {
             long reused = 0;
             int filesRecovered = 0;
             int filesReused = 0;
-            for (final RecoveryState.File file : recoveryState.getIndex().fileDetails()) {
+            for (final RecoveryState.FileDetail file : recoveryState.getIndex().fileDetails()) {
                 if (files.contains(file.name()) == false) {
                     recovered += file.length();
                     filesRecovered++;

+ 45 - 32
server/src/main/java/org/elasticsearch/indices/recovery/RecoveryState.java

@@ -116,7 +116,16 @@ public class RecoveryState implements ToXContentFragment, Writeable {
     private DiscoveryNode targetNode;
     private boolean primary;
 
-    public RecoveryState(ShardRouting shardRouting, DiscoveryNode targetNode, @Nullable DiscoveryNode sourceNode) {
+    public RecoveryState(ShardRouting shardRouting,
+                         DiscoveryNode targetNode,
+                         @Nullable DiscoveryNode sourceNode) {
+        this(shardRouting, targetNode, sourceNode, new Index());
+    }
+
+    public RecoveryState(ShardRouting shardRouting,
+                         DiscoveryNode targetNode,
+                         @Nullable DiscoveryNode sourceNode,
+                         Index index) {
         assert shardRouting.initializing() : "only allow initializing shard routing to be recovered: " + shardRouting;
         RecoverySource recoverySource = shardRouting.recoverySource();
         assert (recoverySource.getType() == RecoverySource.Type.PEER) == (sourceNode != null) :
@@ -127,7 +136,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
         this.sourceNode = sourceNode;
         this.targetNode = targetNode;
         stage = Stage.INIT;
-        index = new Index();
+        this.index = index;
         translog = new Translog();
         verifyIndex = new VerifyIndex();
         timer = new Timer();
@@ -170,7 +179,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
     }
 
 
-    private void validateAndSetStage(Stage expected, Stage next) {
+    protected void validateAndSetStage(Stage expected, Stage next) {
         if (stage != expected) {
             assert false : "can't move recovery to stage [" + next + "]. current stage: [" + stage + "] (expected [" + expected + "])";
             throw new IllegalStateException("can't move recovery to stage [" + next + "]. current stage: ["
@@ -598,20 +607,20 @@ public class RecoveryState implements ToXContentFragment, Writeable {
         }
     }
 
-    public static class File implements ToXContentObject, Writeable {
+    public static class FileDetail implements ToXContentObject, Writeable {
         private String name;
         private long length;
         private long recovered;
         private boolean reused;
 
-        public File(String name, long length, boolean reused) {
+        public FileDetail(String name, long length, boolean reused) {
             assert name != null;
             this.name = name;
             this.length = length;
             this.reused = reused;
         }
 
-        public File(StreamInput in) throws IOException {
+        public FileDetail(StreamInput in) throws IOException {
             name = in.readString();
             length = in.readVLong();
             recovered = in.readVLong();
@@ -677,8 +686,8 @@ public class RecoveryState implements ToXContentFragment, Writeable {
 
         @Override
         public boolean equals(Object obj) {
-            if (obj instanceof File) {
-                File other = (File) obj;
+            if (obj instanceof FileDetail) {
+                FileDetail other = (FileDetail) obj;
                 return name.equals(other.name) && length == other.length() && reused == other.reused() && recovered == other.recovered();
             }
             return false;
@@ -700,16 +709,16 @@ public class RecoveryState implements ToXContentFragment, Writeable {
     }
 
     public static class RecoveryFilesDetails implements ToXContentFragment, Writeable {
-        private final Map<String, File> fileDetails = new HashMap<>();
-        private boolean complete;
+        protected final Map<String, FileDetail> fileDetails = new HashMap<>();
+        protected boolean complete;
 
-        RecoveryFilesDetails() {
+        public RecoveryFilesDetails() {
         }
 
         RecoveryFilesDetails(StreamInput in) throws IOException {
             int size = in.readVInt();
             for (int i = 0; i < size; i++) {
-                File file = new File(in);
+                FileDetail file = new FileDetail(in);
                 fileDetails.put(file.name, file);
             }
             if (in.getVersion().onOrAfter(StoreStats.RESERVED_BYTES_VERSION)) {
@@ -725,9 +734,9 @@ public class RecoveryState implements ToXContentFragment, Writeable {
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            final File[] files = values().toArray(new File[0]);
+            final FileDetail[] files = values().toArray(new FileDetail[0]);
             out.writeVInt(files.length);
-            for (File file : files) {
+            for (FileDetail file : files) {
                 file.writeTo(out);
             }
             if (out.getVersion().onOrAfter(StoreStats.RESERVED_BYTES_VERSION)) {
@@ -739,7 +748,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             if (params.paramAsBoolean("detailed", false)) {
                 builder.startArray(Fields.DETAILS);
-                for (File file : values()) {
+                for (FileDetail file : values()) {
                     file.toXContent(builder, params);
                 }
                 builder.endArray();
@@ -750,17 +759,17 @@ public class RecoveryState implements ToXContentFragment, Writeable {
 
         public void addFileDetails(String name, long length, boolean reused) {
             assert complete == false : "addFileDetail for [" + name + "] when file details are already complete";
-            File existing = fileDetails.put(name, new File(name, length, reused));
+            FileDetail existing = fileDetails.put(name, new FileDetail(name, length, reused));
             assert existing == null : "file [" + name + "] is already reported";
         }
 
         public void addRecoveredBytesToFile(String name, long bytes) {
-            File file = fileDetails.get(name);
+            FileDetail file = fileDetails.get(name);
             assert file != null : "file [" + name + "] hasn't been reported";
             file.addRecoveredBytes(bytes);
         }
 
-        public File get(String name) {
+        public FileDetail get(String name) {
             return fileDetails.get(name);
         }
 
@@ -781,7 +790,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
             complete = false;
         }
 
-        public Collection<File> values() {
+        public Collection<FileDetail> values() {
             return fileDetails.values();
         }
 
@@ -799,7 +808,11 @@ public class RecoveryState implements ToXContentFragment, Writeable {
         private long targetThrottleTimeInNanos = UNKNOWN;
 
         public Index() {
-            this.fileDetails = new RecoveryFilesDetails();
+            this(new RecoveryFilesDetails());
+        }
+
+        public Index(RecoveryFilesDetails recoveryFilesDetails) {
+            this.fileDetails = recoveryFilesDetails;
         }
 
         public Index(StreamInput in) throws IOException {
@@ -817,7 +830,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
             out.writeLong(targetThrottleTimeInNanos);
         }
 
-        public synchronized List<File> fileDetails() {
+        public synchronized List<FileDetail> fileDetails() {
             return List.copyOf(fileDetails.values());
         }
 
@@ -876,7 +889,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
          */
         public synchronized int totalRecoverFiles() {
             int total = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.reused() == false) {
                     total++;
                 }
@@ -889,7 +902,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
          */
         public synchronized int recoveredFileCount() {
             int count = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.fullyRecovered()) {
                     count++;
                 }
@@ -903,7 +916,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
         public synchronized float recoveredFilesPercent() {
             int total = 0;
             int recovered = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.reused() == false) {
                     total++;
                     if (file.fullyRecovered()) {
@@ -927,7 +940,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
          */
         public synchronized long totalBytes() {
             long total = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 total += file.length();
             }
             return total;
@@ -938,7 +951,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
          */
         public synchronized long recoveredBytes() {
             long recovered = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 recovered += file.recovered();
             }
             return recovered;
@@ -949,7 +962,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
          */
         public synchronized long totalRecoverBytes() {
             long total = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.reused() == false) {
                     total += file.length();
                 }
@@ -966,7 +979,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
                 return -1L;
             }
             long total = 0L;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.reused() == false) {
                     total += file.length() - file.recovered();
                 }
@@ -980,7 +993,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
         public synchronized float recoveredBytesPercent() {
             long total = 0;
             long recovered = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.reused() == false) {
                     total += file.length();
                     recovered += file.recovered();
@@ -999,7 +1012,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
 
         public synchronized int reusedFileCount() {
             int reused = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.reused()) {
                     reused++;
                 }
@@ -1009,7 +1022,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
 
         public synchronized long reusedBytes() {
             long reused = 0;
-            for (File file : fileDetails.values()) {
+            for (FileDetail file : fileDetails.values()) {
                 if (file.reused()) {
                     reused += file.length();
                 }
@@ -1053,7 +1066,7 @@ public class RecoveryState implements ToXContentFragment, Writeable {
             }
         }
 
-        public synchronized File getFileDetails(String dest) {
+        public synchronized FileDetail getFileDetails(String dest) {
             return fileDetails.get(dest);
         }
     }

+ 1 - 1
server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java

@@ -2831,7 +2831,7 @@ public class IndexShardTests extends IndexShardTestCase {
             RecoveryState recoveryState = targetShard.recoveryState();
             assertEquals(RecoveryState.Stage.DONE, recoveryState.getStage());
             assertTrue(recoveryState.getIndex().fileDetails().size() > 0);
-            for (RecoveryState.File file : recoveryState.getIndex().fileDetails()) {
+            for (RecoveryState.FileDetail file : recoveryState.getIndex().fileDetails()) {
                 if (file.reused()) {
                     assertEquals(file.recovered(), 0);
                 } else {

+ 10 - 10
server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTargetTests.java

@@ -28,7 +28,7 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.index.shard.ShardId;
-import org.elasticsearch.indices.recovery.RecoveryState.File;
+import org.elasticsearch.indices.recovery.RecoveryState.FileDetail;
 import org.elasticsearch.indices.recovery.RecoveryState.Index;
 import org.elasticsearch.indices.recovery.RecoveryState.Stage;
 import org.elasticsearch.indices.recovery.RecoveryState.Timer;
@@ -180,8 +180,8 @@ public class RecoveryTargetTests extends ESTestCase {
     }
 
     public void testIndex() throws Throwable {
-        File[] files = new File[randomIntBetween(1, 20)];
-        ArrayList<File> filesToRecover = new ArrayList<>();
+        FileDetail[] files = new FileDetail[randomIntBetween(1, 20)];
+        ArrayList<FileDetail> filesToRecover = new ArrayList<>();
         long totalFileBytes = 0;
         long totalReusedBytes = 0;
         int totalReused = 0;
@@ -189,7 +189,7 @@ public class RecoveryTargetTests extends ESTestCase {
             final int fileLength = randomIntBetween(1, 1000);
             final boolean reused = randomBoolean();
             totalFileBytes += fileLength;
-            files[i] = new RecoveryState.File("f_" + i, fileLength, reused);
+            files[i] = new FileDetail("f_" + i, fileLength, reused);
             if (reused) {
                 totalReused++;
                 totalReusedBytes += fileLength;
@@ -230,7 +230,7 @@ public class RecoveryTargetTests extends ESTestCase {
         assertThat(index.targetThrottling().nanos(), equalTo(Index.UNKNOWN));
 
         index.start();
-        for (File file : files) {
+        for (FileDetail file : files) {
             index.addFileDetail(file.name(), file.length(), file.reused());
         }
 
@@ -271,7 +271,7 @@ public class RecoveryTargetTests extends ESTestCase {
         long sourceThrottling = Index.UNKNOWN;
         long targetThrottling = Index.UNKNOWN;
         while (bytesToRecover > 0) {
-            File file = randomFrom(filesToRecover);
+            FileDetail file = randomFrom(filesToRecover);
             final long toRecover = Math.min(bytesToRecover, randomIntBetween(1, (int) (file.length() - file.recovered())));
             final long throttledOnSource = rarely() ? randomIntBetween(10, 200) : 0;
             index.addSourceThrottling(throttledOnSource);
@@ -534,14 +534,14 @@ public class RecoveryTargetTests extends ESTestCase {
     }
 
     public void testFileHashCodeAndEquals() {
-        File f = new File("foo", randomIntBetween(0, 100), randomBoolean());
-        File anotherFile = new File(f.name(), f.length(), f.reused());
+        FileDetail f = new FileDetail("foo", randomIntBetween(0, 100), randomBoolean());
+        FileDetail anotherFile = new FileDetail(f.name(), f.length(), f.reused());
         assertEquals(f, anotherFile);
         assertEquals(f.hashCode(), anotherFile.hashCode());
         int iters = randomIntBetween(10, 100);
         for (int i = 0; i < iters; i++) {
-            f = new File("foo", randomIntBetween(0, 100), randomBoolean());
-            anotherFile = new File(f.name(), randomIntBetween(0, 100), randomBoolean());
+            f = new FileDetail("foo", randomIntBetween(0, 100), randomBoolean());
+            anotherFile = new FileDetail(f.name(), randomIntBetween(0, 100), randomBoolean());
             if (f.equals(anotherFile)) {
                 assertEquals(f.hashCode(), anotherFile.hashCode());
             } else if (f.hashCode() != anotherFile.hashCode()) {

+ 2 - 2
server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java

@@ -159,9 +159,9 @@ public class FsRepositoryTests extends ESTestCase {
             futureC.actionGet();
             assertEquals(secondState.getIndex().reusedFileCount(), commitFileNames.size()-2);
             assertEquals(secondState.getIndex().recoveredFileCount(), 2);
-            List<RecoveryState.File> recoveredFiles =
+            List<RecoveryState.FileDetail> recoveredFiles =
                 secondState.getIndex().fileDetails().stream().filter(f -> f.reused() == false).collect(Collectors.toList());
-            Collections.sort(recoveredFiles, Comparator.comparing(RecoveryState.File::name));
+            Collections.sort(recoveredFiles, Comparator.comparing(RecoveryState.FileDetail::name));
             assertTrue(recoveredFiles.get(0).name(), recoveredFiles.get(0).name().endsWith(".liv"));
             assertTrue(recoveredFiles.get(1).name(), recoveredFiles.get(1).name().endsWith("segments_" + incIndexCommit.getGeneration()));
         } finally {

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsConstants.java

@@ -30,6 +30,8 @@ public class SearchableSnapshotsConstants {
 
     public static final String SNAPSHOT_DIRECTORY_FACTORY_KEY = "snapshot";
 
+    public static final String SNAPSHOT_RECOVERY_STATE_FACTORY_KEY = "snapshot_prewarm";
+
     public static boolean isSearchableSnapshotStore(Settings indexSettings) {
         return SEARCHABLE_SNAPSHOTS_FEATURE_ENABLED
             && SNAPSHOT_DIRECTORY_FACTORY_KEY.equals(INDEX_STORE_TYPE_SETTING.get(indexSettings));

+ 73 - 46
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/index/store/SearchableSnapshotDirectory.java

@@ -19,6 +19,7 @@ import org.apache.lucene.store.SingleInstanceLockFactory;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRunnable;
+import org.elasticsearch.action.StepListener;
 import org.elasticsearch.action.support.GroupedActionListener;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.CheckedRunnable;
@@ -41,6 +42,8 @@ import org.elasticsearch.index.store.cache.CacheKey;
 import org.elasticsearch.index.store.cache.CachedBlobContainerIndexInput;
 import org.elasticsearch.index.store.checksum.ChecksumBlobContainerIndexInput;
 import org.elasticsearch.index.store.direct.DirectBlobContainerIndexInput;
+import org.elasticsearch.indices.recovery.RecoveryState;
+import org.elasticsearch.indices.recovery.SearchableSnapshotRecoveryState;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.repositories.RepositoriesService;
 import org.elasticsearch.repositories.Repository;
@@ -121,6 +124,7 @@ public class SearchableSnapshotDirectory extends BaseDirectory {
     private volatile BlobStoreIndexShardSnapshot snapshot;
     private volatile BlobContainer blobContainer;
     private volatile boolean loaded;
+    private volatile SearchableSnapshotRecoveryState recoveryState;
 
     public SearchableSnapshotDirectory(
         Supplier<BlobContainer> blobContainer,
@@ -176,8 +180,13 @@ public class SearchableSnapshotDirectory extends BaseDirectory {
      *
      * @return true if the snapshot was loaded by executing this method, false otherwise
      */
-    public boolean loadSnapshot() {
+    public boolean loadSnapshot(RecoveryState recoveryState) {
+        assert recoveryState != null;
+        assert recoveryState instanceof SearchableSnapshotRecoveryState;
         assert assertCurrentThreadMayLoadSnapshot();
+        if (recoveryState instanceof SearchableSnapshotRecoveryState == false) {
+            throw new IllegalArgumentException("A SearchableSnapshotRecoveryState instance was expected");
+        }
         boolean alreadyLoaded = this.loaded;
         if (alreadyLoaded == false) {
             synchronized (this) {
@@ -187,6 +196,7 @@ public class SearchableSnapshotDirectory extends BaseDirectory {
                     this.snapshot = snapshotSupplier.get();
                     this.loaded = true;
                     cleanExistingRegularShardFiles();
+                    this.recoveryState = (SearchableSnapshotRecoveryState) recoveryState;
                     prewarmCache();
                 }
             }
@@ -388,57 +398,74 @@ public class SearchableSnapshotDirectory extends BaseDirectory {
     }
 
     private void prewarmCache() {
-        if (prewarmCache) {
-            final BlockingQueue<Tuple<ActionListener<Void>, CheckedRunnable<Exception>>> queue = new LinkedBlockingQueue<>();
-            final Executor executor = prewarmExecutor();
+        if (prewarmCache == false) {
+            recoveryState.preWarmFinished();
+            return;
+        }
+
+        final BlockingQueue<Tuple<ActionListener<Void>, CheckedRunnable<Exception>>> queue = new LinkedBlockingQueue<>();
+        final Executor executor = prewarmExecutor();
+
+        final GroupedActionListener<Void> completionListener = new GroupedActionListener<>(
+            ActionListener.wrap(voids -> recoveryState.preWarmFinished(), e -> {}), // Ignore pre-warm errors
+            snapshot().totalFileCount()
+        );
 
-            for (BlobStoreIndexShardSnapshot.FileInfo file : snapshot().indexFiles()) {
-                if (file.metadata().hashEqualsContents() || isExcludedFromCache(file.physicalName())) {
-                    continue;
+        for (BlobStoreIndexShardSnapshot.FileInfo file : snapshot().indexFiles()) {
+            if (file.metadata().hashEqualsContents() || isExcludedFromCache(file.physicalName())) {
+                if (file.metadata().hashEqualsContents()) {
+                    recoveryState.getIndex().addFileDetail(file.physicalName(), file.length(), true);
+                } else {
+                    recoveryState.ignoreFile(file.physicalName());
                 }
-                try {
-                    final IndexInput input = openInput(file.physicalName(), CachedBlobContainerIndexInput.CACHE_WARMING_CONTEXT);
-                    assert input instanceof CachedBlobContainerIndexInput : "expected cached index input but got " + input.getClass();
-
-                    final int numberOfParts = Math.toIntExact(file.numberOfParts());
-                    final GroupedActionListener<Void> listener = new GroupedActionListener<>(
-                        ActionListener.wrap(voids -> input.close(), e -> IOUtils.closeWhileHandlingException(input)),
-                        numberOfParts
-                    );
-
-                    for (int p = 0; p < numberOfParts; p++) {
-                        final int part = p;
-                        queue.add(Tuple.tuple(listener, () -> {
-                            ensureOpen();
-
-                            logger.trace("{} warming cache for [{}] part [{}/{}]", shardId, file.physicalName(), part + 1, numberOfParts);
-                            final long startTimeInNanos = statsCurrentTimeNanosSupplier.getAsLong();
-                            ((CachedBlobContainerIndexInput) input).prefetchPart(part);
-
-                            logger.trace(
-                                () -> new ParameterizedMessage(
-                                    "{} part [{}/{}] of [{}] warmed in [{}] ms",
-                                    shardId,
-                                    part + 1,
-                                    numberOfParts,
-                                    file.physicalName(),
-                                    TimeValue.timeValueNanos(statsCurrentTimeNanosSupplier.getAsLong() - startTimeInNanos).millis()
-                                )
-                            );
-                        }));
-                    }
-                } catch (IOException e) {
-                    logger.warn(() -> new ParameterizedMessage("{} unable to prewarm file [{}]", shardId, file.physicalName()), e);
+                completionListener.onResponse(null);
+                continue;
+            }
+            recoveryState.getIndex().addFileDetail(file.physicalName(), file.length(), false);
+            try {
+                final IndexInput input = openInput(file.physicalName(), CachedBlobContainerIndexInput.CACHE_WARMING_CONTEXT);
+                assert input instanceof CachedBlobContainerIndexInput : "expected cached index input but got " + input.getClass();
+
+                final int numberOfParts = Math.toIntExact(file.numberOfParts());
+                final StepListener<Collection<Void>> fileCompletionListener = new StepListener<>();
+                fileCompletionListener.whenComplete(voids -> input.close(), e -> IOUtils.closeWhileHandlingException(input));
+                fileCompletionListener.whenComplete(voids -> completionListener.onResponse(null), completionListener::onFailure);
+
+                final GroupedActionListener<Void> listener = new GroupedActionListener<>(fileCompletionListener, numberOfParts);
+
+                for (int p = 0; p < numberOfParts; p++) {
+                    final int part = p;
+                    queue.add(Tuple.tuple(listener, () -> {
+                        ensureOpen();
+
+                        logger.trace("{} warming cache for [{}] part [{}/{}]", shardId, file.physicalName(), part + 1, numberOfParts);
+                        final long startTimeInNanos = statsCurrentTimeNanosSupplier.getAsLong();
+                        ((CachedBlobContainerIndexInput) input).prefetchPart(part);
+                        recoveryState.getIndex().addRecoveredBytesToFile(file.physicalName(), file.partBytes(part));
+
+                        logger.trace(
+                            () -> new ParameterizedMessage(
+                                "{} part [{}/{}] of [{}] warmed in [{}] ms",
+                                shardId,
+                                part + 1,
+                                numberOfParts,
+                                file.physicalName(),
+                                TimeValue.timeValueNanos(statsCurrentTimeNanosSupplier.getAsLong() - startTimeInNanos).millis()
+                            )
+                        );
+                    }));
                 }
+            } catch (IOException e) {
+                logger.warn(() -> new ParameterizedMessage("{} unable to prewarm file [{}]", shardId, file.physicalName()), e);
             }
+        }
 
-            logger.debug("{} warming shard cache for [{}] files", shardId, queue.size());
+        logger.debug("{} warming shard cache for [{}] files", shardId, queue.size());
 
-            // Start as many workers as fit into the searchable snapshot pool at once at the most
-            final int workers = Math.min(threadPool.info(CACHE_FETCH_ASYNC_THREAD_POOL_NAME).getMax(), queue.size());
-            for (int i = 0; i < workers; ++i) {
-                prewarmNext(executor, queue);
-            }
+        // Start as many workers as fit into the searchable snapshot pool at once at the most
+        final int workers = Math.min(threadPool.info(CACHE_FETCH_ASYNC_THREAD_POOL_NAME).getMax(), queue.size());
+        for (int i = 0; i < workers; ++i) {
+            prewarmNext(executor, queue);
         }
     }
 

+ 126 - 0
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/indices/recovery/SearchableSnapshotRecoveryState.java

@@ -0,0 +1,126 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.indices.recovery;
+
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.common.Nullable;
+
+import java.util.HashSet;
+import java.util.Set;
+
+public final class SearchableSnapshotRecoveryState extends RecoveryState {
+    private boolean preWarmFinished;
+
+    public SearchableSnapshotRecoveryState(ShardRouting shardRouting, DiscoveryNode targetNode, @Nullable DiscoveryNode sourceNode) {
+        super(shardRouting, targetNode, sourceNode, new Index());
+    }
+
+    @Override
+    public synchronized RecoveryState setStage(Stage stage) {
+        // The transition to the final state was done by #prewarmCompleted, just ignore the transition
+        if (getStage() == Stage.DONE) {
+            return this;
+        }
+
+        // Pre-warm is still running, hold the state transition
+        // until the pre-warm process finishes
+        if (preWarmFinished == false && stage == Stage.DONE) {
+            validateCurrentStage(Stage.FINALIZE);
+            return this;
+        }
+
+        return super.setStage(stage);
+    }
+
+    public synchronized void preWarmFinished() {
+        // For small shards it's possible that the
+        // cache is pre-warmed before the stage has transitioned
+        // to FINALIZE, so the transition to the final state is delayed until
+        // the recovery process catches up.
+        if (getStage() == Stage.FINALIZE) {
+            super.setStage(Stage.DONE);
+        }
+
+        SearchableSnapshotRecoveryState.Index index = (Index) getIndex();
+        index.stopTimer();
+        preWarmFinished = true;
+    }
+
+    public synchronized void ignoreFile(String name) {
+        SearchableSnapshotRecoveryState.Index index = (Index) getIndex();
+        index.addFileToIgnore(name);
+    }
+
+    private static final class Index extends RecoveryState.Index {
+        // We ignore the files that won't be part of the pre-warming
+        // phase since the information for those files won't be
+        // updated and marking them as reused might be confusing,
+        // as they are fetched on-demand from the underlying repository.
+        private final Set<String> filesToIgnore = new HashSet<>();
+
+        private Index() {
+            super(new SearchableSnapshotRecoveryFilesDetails());
+            // We start loading data just at the beginning
+            super.start();
+        }
+
+        private synchronized void addFileToIgnore(String name) {
+            filesToIgnore.add(name);
+        }
+
+        @Override
+        public synchronized void addFileDetail(String name, long length, boolean reused) {
+            if (filesToIgnore.contains(name)) {
+                return;
+            }
+
+            super.addFileDetail(name, length, reused);
+        }
+
+        // We have to bypass all the calls to the timer
+        @Override
+        public synchronized void start() {}
+
+        @Override
+        public synchronized void stop() {}
+
+        @Override
+        public synchronized void reset() {}
+
+        private synchronized void stopTimer() {
+            super.stop();
+        }
+    }
+
+    private static class SearchableSnapshotRecoveryFilesDetails extends RecoveryFilesDetails {
+        @Override
+        public void addFileDetails(String name, long length, boolean reused) {
+            // We allow reporting the same file details multiple times as we populate the file
+            // details before the recovery is executed (see SearchableSnapshotDirectory#prewarmCache)
+            // and therefore we ignore the rest of the calls for the same files.
+            // Additionally, it's possible that a segments_n file that wasn't part of the snapshot is
+            // sent over during peer recoveries as after restore a new segments file is generated
+            // (see StoreRecovery#bootstrap).
+            FileDetail fileDetail = fileDetails.computeIfAbsent(name, n -> new FileDetail(name, length, reused));
+            assert fileDetail == null || fileDetail.name().equals(name) && fileDetail.length() == length : "The file "
+                + name
+                + " was reported multiple times with different lengths: ["
+                + fileDetail.length()
+                + "] and ["
+                + length
+                + "]";
+        }
+
+        @Override
+        public void clear() {
+            // Since we don't want to remove the recovery information that might have been
+            // populated during cache pre-warming we just ignore clearing the file details.
+            complete = false;
+        }
+    }
+}

+ 1 - 1
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotIndexEventListener.java

@@ -34,7 +34,7 @@ public class SearchableSnapshotIndexEventListener implements IndexEventListener
         final SearchableSnapshotDirectory directory = SearchableSnapshotDirectory.unwrapDirectory(indexShard.store().directory());
         assert directory != null;
 
-        final boolean success = directory.loadSnapshot();
+        final boolean success = directory.loadSnapshot(indexShard.recoveryState());
         assert directory.listAll().length > 0 : "expecting directory listing to be non-empty";
         assert success
             || indexShard.routingEntry()

+ 12 - 1
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshots.java

@@ -32,6 +32,7 @@ import org.elasticsearch.index.engine.EngineFactory;
 import org.elasticsearch.index.engine.ReadOnlyEngine;
 import org.elasticsearch.index.store.SearchableSnapshotDirectory;
 import org.elasticsearch.index.translog.TranslogStats;
+import org.elasticsearch.indices.recovery.SearchableSnapshotRecoveryState;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.plugins.ActionPlugin;
@@ -73,12 +74,13 @@ import java.util.Optional;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
+import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.CACHE_FETCH_ASYNC_THREAD_POOL_NAME;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.CACHE_FETCH_ASYNC_THREAD_POOL_SETTING;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.CACHE_PREWARMING_THREAD_POOL_NAME;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.CACHE_PREWARMING_THREAD_POOL_SETTING;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.SEARCHABLE_SNAPSHOTS_FEATURE_ENABLED;
-import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.CACHE_FETCH_ASYNC_THREAD_POOL_NAME;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.SNAPSHOT_DIRECTORY_FACTORY_KEY;
+import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.SNAPSHOT_RECOVERY_STATE_FACTORY_KEY;
 
 /**
  * Plugin for Searchable Snapshots feature
@@ -312,6 +314,15 @@ public class SearchableSnapshots extends Plugin implements IndexStorePlugin, Eng
         }
     }
 
+    @Override
+    public Map<String, RecoveryStateFactory> getRecoveryStateFactories() {
+        if (SEARCHABLE_SNAPSHOTS_FEATURE_ENABLED) {
+            return Map.of(SNAPSHOT_RECOVERY_STATE_FACTORY_KEY, SearchableSnapshotRecoveryState::new);
+        } else {
+            return Map.of();
+        }
+    }
+
     public static ScalingExecutorBuilder[] executorBuilders() {
         return new ScalingExecutorBuilder[] {
             new ScalingExecutorBuilder(

+ 2 - 0
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/TransportMountSearchableSnapshotAction.java

@@ -45,6 +45,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 
+import static org.elasticsearch.index.IndexModule.INDEX_RECOVERY_TYPE_SETTING;
 import static org.elasticsearch.index.IndexModule.INDEX_STORE_TYPE_SETTING;
 
 /**
@@ -118,6 +119,7 @@ public class TransportMountSearchableSnapshotAction extends TransportMasterNodeA
             .put(INDEX_STORE_TYPE_SETTING.getKey(), SearchableSnapshotsConstants.SNAPSHOT_DIRECTORY_FACTORY_KEY)
             .put(IndexMetadata.SETTING_BLOCKS_WRITE, true)
             .put(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_SETTING.getKey(), SearchableSnapshotAllocator.ALLOCATOR_NAME)
+            .put(INDEX_RECOVERY_TYPE_SETTING.getKey(), SearchableSnapshotsConstants.SNAPSHOT_RECOVERY_STATE_FACTORY_KEY)
             .build();
     }
 

+ 16 - 1
x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/index/store/SearchableSnapshotDirectoryStatsTests.java

@@ -9,6 +9,10 @@ import org.apache.lucene.store.BufferedIndexInput;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.common.TriConsumer;
 import org.elasticsearch.common.blobstore.BlobContainer;
 import org.elasticsearch.common.lucene.store.ESIndexInputTestCase;
@@ -22,6 +26,8 @@ import org.elasticsearch.index.shard.ShardPath;
 import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot;
 import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo;
 import org.elasticsearch.index.store.cache.TestUtils;
+import org.elasticsearch.indices.recovery.RecoveryState;
+import org.elasticsearch.indices.recovery.SearchableSnapshotRecoveryState;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.snapshots.SnapshotId;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -637,7 +643,16 @@ public class SearchableSnapshotDirectoryStatsTests extends ESIndexInputTestCase
             cacheService.start();
             assertThat(directory.getStats(fileName), nullValue());
 
-            final boolean loaded = directory.loadSnapshot();
+            ShardRouting shardRouting = TestShardRouting.newShardRouting(
+                randomAlphaOfLength(10),
+                0,
+                randomAlphaOfLength(10),
+                true,
+                ShardRoutingState.INITIALIZING
+            );
+            DiscoveryNode targetNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT);
+            RecoveryState recoveryState = new SearchableSnapshotRecoveryState(shardRouting, targetNode, null);
+            final boolean loaded = directory.loadSnapshot(recoveryState);
             assertThat("Failed to load snapshot", loaded, is(true));
             assertThat("Snapshot should be loaded", directory.snapshot(), notNullValue());
             assertThat("BlobContainer should be loaded", directory.blobContainer(), notNullValue());

+ 167 - 3
x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/index/store/SearchableSnapshotDirectoryTests.java

@@ -23,6 +23,8 @@ import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.SegmentInfos;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.Terms;
+import org.apache.lucene.mockfile.FilterFileSystemProvider;
+import org.apache.lucene.mockfile.FilterSeekableByteChannel;
 import org.apache.lucene.search.CheckHits;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
@@ -38,6 +40,10 @@ import org.elasticsearch.Version;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.RepositoryMetadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.CheckedFunction;
 import org.elasticsearch.common.UUIDs;
@@ -45,6 +51,8 @@ import org.elasticsearch.common.blobstore.BlobContainer;
 import org.elasticsearch.common.blobstore.BlobPath;
 import org.elasticsearch.common.blobstore.fs.FsBlobContainer;
 import org.elasticsearch.common.blobstore.fs.FsBlobStore;
+import org.elasticsearch.common.io.PathUtils;
+import org.elasticsearch.common.io.PathUtilsForTesting;
 import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.lease.Releasables;
 import org.elasticsearch.common.lucene.BytesRefs;
@@ -70,6 +78,8 @@ import org.elasticsearch.index.store.cache.TestUtils;
 import org.elasticsearch.index.store.checksum.ChecksumBlobContainerIndexInput;
 import org.elasticsearch.index.translog.Translog;
 import org.elasticsearch.indices.recovery.RecoverySettings;
+import org.elasticsearch.indices.recovery.RecoveryState;
+import org.elasticsearch.indices.recovery.SearchableSnapshotRecoveryState;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.repositories.blobstore.BlobStoreRepository;
 import org.elasticsearch.repositories.blobstore.BlobStoreTestUtil;
@@ -89,21 +99,31 @@ import java.io.EOFException;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.SeekableByteChannel;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.DirectoryStream;
+import java.nio.file.FileSystem;
 import java.nio.file.Files;
 import java.nio.file.NoSuchFileException;
+import java.nio.file.OpenOption;
 import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
 import java.nio.file.attribute.BasicFileAttributes;
+import java.nio.file.attribute.FileAttribute;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyMap;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_CACHE_ENABLED_SETTING;
+import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_CACHE_EXCLUDED_FILE_TYPES_SETTING;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_CACHE_PREWARM_ENABLED_SETTING;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_INDEX_ID_SETTING;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_REPOSITORY_SETTING;
@@ -116,6 +136,7 @@ import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -442,6 +463,16 @@ public class SearchableSnapshotDirectoryTests extends ESTestCase {
         final boolean enableCache,
         final boolean prewarmCache,
         final CheckedBiConsumer<Directory, SearchableSnapshotDirectory, Exception> consumer
+    ) throws Exception {
+        testDirectories(enableCache, prewarmCache, createRecoveryState(), Settings.EMPTY, consumer);
+    }
+
+    private void testDirectories(
+        final boolean enableCache,
+        final boolean prewarmCache,
+        final SearchableSnapshotRecoveryState recoveryState,
+        final Settings searchableSnapshotDirectorySettings,
+        final CheckedBiConsumer<Directory, SearchableSnapshotDirectory, Exception> consumer
     ) throws Exception {
         final IndexSettings indexSettings = newIndexSettings();
         final ShardId shardId = new ShardId(indexSettings.getIndex(), randomIntBetween(0, 10));
@@ -571,6 +602,7 @@ public class SearchableSnapshotDirectoryTests extends ESTestCase {
                         indexId,
                         shardId,
                         Settings.builder()
+                            .put(searchableSnapshotDirectorySettings)
                             .put(SNAPSHOT_CACHE_ENABLED_SETTING.getKey(), enableCache)
                             .put(SNAPSHOT_CACHE_PREWARM_ENABLED_SETTING.getKey(), prewarmCache)
                             .build(),
@@ -581,7 +613,7 @@ public class SearchableSnapshotDirectoryTests extends ESTestCase {
                         threadPool
                     )
                 ) {
-                    final boolean loaded = snapshotDirectory.loadSnapshot();
+                    final boolean loaded = snapshotDirectory.loadSnapshot(recoveryState);
                     assertThat("Failed to load snapshot", loaded, is(true));
                     assertThat("Snapshot should be loaded", snapshotDirectory.snapshot(), sameInstance(snapshot));
                     assertThat("BlobContainer should be loaded", snapshotDirectory.blobContainer(), sameInstance(blobContainer));
@@ -677,8 +709,8 @@ public class SearchableSnapshotDirectoryTests extends ESTestCase {
                     threadPool
                 )
             ) {
-
-                final boolean loaded = directory.loadSnapshot();
+                final RecoveryState recoveryState = createRecoveryState();
+                final boolean loaded = directory.loadSnapshot(recoveryState);
                 assertThat("Failed to load snapshot", loaded, is(true));
                 assertThat("Snapshot should be loaded", directory.snapshot(), sameInstance(snapshot));
                 assertThat("BlobContainer should be loaded", directory.blobContainer(), sameInstance(blobContainer));
@@ -736,6 +768,92 @@ public class SearchableSnapshotDirectoryTests extends ESTestCase {
         }
     }
 
+    public void testRecoveryStateIsKeptOpenAfterPreWarmFailures() throws Exception {
+        FileSystem fileSystem = PathUtils.getDefaultFileSystem();
+        FaultyReadsFileSystem disruptFileSystemProvider = new FaultyReadsFileSystem(fileSystem);
+        fileSystem = disruptFileSystemProvider.getFileSystem(null);
+        PathUtilsForTesting.installMock(fileSystem);
+
+        try {
+            SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+            testDirectories(true, true, recoveryState, Settings.EMPTY, (directory, snapshotDirectory) -> {
+                assertExecutorIsIdle(snapshotDirectory.prewarmExecutor());
+
+                assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.FINALIZE));
+                // All pre-warm tasks failed
+                assertThat(recoveryState.getIndex().recoveredBytes(), equalTo(0L));
+            });
+        } finally {
+            PathUtilsForTesting.teardown();
+        }
+    }
+
+    public void testRecoveryStateIsEmptyWhenTheCacheIsNotPreWarmed() throws Exception {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+        testDirectories(true, false, recoveryState, Settings.EMPTY, (directory, snapshotDirectory) -> {
+            assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE));
+            assertThat(recoveryState.getIndex().recoveredBytes(), equalTo(0L));
+            assertThat(recoveryState.getIndex().totalRecoverFiles(), equalTo(0));
+        });
+    }
+
+    public void testNonCachedFilesAreExcludedFromRecoveryState() throws Exception {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+
+        List<String> allFileExtensions = List.of(
+            "fdt",
+            "fdx",
+            "nvd",
+            "dvd",
+            "tip",
+            "cfs",
+            "dim",
+            "fnm",
+            "dvm",
+            "tmd",
+            "doc",
+            "tim",
+            "pos",
+            "cfe",
+            "fdm",
+            "nvm"
+        );
+        List<String> fileTypesExcludedFromCaching = randomSubsetOf(allFileExtensions);
+        Settings settings = Settings.builder()
+            .putList(SNAPSHOT_CACHE_EXCLUDED_FILE_TYPES_SETTING.getKey(), fileTypesExcludedFromCaching)
+            .build();
+        testDirectories(true, true, recoveryState, settings, (directory, snapshotDirectory) -> {
+            assertExecutorIsIdle(snapshotDirectory.prewarmExecutor());
+
+            assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE));
+            for (RecoveryState.FileDetail fileDetail : recoveryState.getIndex().fileDetails()) {
+                boolean fileHasExcludedType = fileTypesExcludedFromCaching.stream().anyMatch(type -> fileDetail.name().endsWith(type));
+                assertFalse(fileHasExcludedType);
+            }
+        });
+    }
+
+    public void testFilesWithHashEqualsContentsAreMarkedAsReusedOnRecoveryState() throws Exception {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+
+        testDirectories(true, true, recoveryState, Settings.EMPTY, (directory, snapshotDirectory) -> {
+            assertExecutorIsIdle(snapshotDirectory.prewarmExecutor());
+            assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE));
+
+            List<BlobStoreIndexShardSnapshot.FileInfo> filesWithEqualContent = snapshotDirectory.snapshot()
+                .indexFiles()
+                .stream()
+                .filter(f -> f.metadata().hashEqualsContents())
+                .collect(Collectors.toList());
+
+            for (BlobStoreIndexShardSnapshot.FileInfo fileWithEqualContent : filesWithEqualContent) {
+                RecoveryState.FileDetail fileDetail = recoveryState.getIndex().getFileDetails(fileWithEqualContent.physicalName());
+                assertThat(fileDetail, is(notNullValue()));
+                assertTrue(fileDetail.reused());
+            }
+        });
+    }
+
     private static <T> void assertThat(
         String reason,
         IndexInput actual,
@@ -771,6 +889,14 @@ public class SearchableSnapshotDirectoryTests extends ESTestCase {
         assertThat("Sum of file sizes mismatch, got: " + files, files.values().stream().mapToLong(Long::longValue).sum(), matchSizeOfFiles);
     }
 
+    private void assertExecutorIsIdle(Executor executor) throws Exception {
+        ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) executor;
+        assertBusy(() -> {
+            assertThat(threadPoolExecutor.getActiveCount(), equalTo(0));
+            assertThat(threadPoolExecutor.getQueue().size(), equalTo(0));
+        });
+    }
+
     private static IndexSettings newIndexSettings() {
         return IndexSettingsModule.newIndexSettings(
             "_index",
@@ -781,4 +907,42 @@ public class SearchableSnapshotDirectoryTests extends ESTestCase {
         );
     }
 
+    private SearchableSnapshotRecoveryState createRecoveryState() {
+        ShardRouting shardRouting = TestShardRouting.newShardRouting(
+            randomAlphaOfLength(10),
+            0,
+            randomAlphaOfLength(10),
+            true,
+            ShardRoutingState.INITIALIZING
+        );
+        DiscoveryNode targetNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT);
+        SearchableSnapshotRecoveryState recoveryState = new SearchableSnapshotRecoveryState(shardRouting, targetNode, null);
+
+        recoveryState.setStage(RecoveryState.Stage.INIT)
+            .setStage(RecoveryState.Stage.INDEX)
+            .setStage(RecoveryState.Stage.VERIFY_INDEX)
+            .setStage(RecoveryState.Stage.TRANSLOG);
+        recoveryState.getIndex().setFileDetailsComplete();
+        recoveryState.setStage(RecoveryState.Stage.FINALIZE).setStage(RecoveryState.Stage.DONE);
+
+        return recoveryState;
+    }
+
+    private static class FaultyReadsFileSystem extends FilterFileSystemProvider {
+        FaultyReadsFileSystem(FileSystem inner) {
+            super("faulty_fs://", inner);
+        }
+
+        @Override
+        public SeekableByteChannel newByteChannel(Path path, Set<? extends OpenOption> options, FileAttribute<?>... attrs)
+            throws IOException {
+            return new FilterSeekableByteChannel(super.newByteChannel(path, options, attrs)) {
+                @Override
+                public int read(ByteBuffer dst) throws IOException {
+                    throw new IOException("IO Failure");
+                }
+            };
+        }
+    }
+
 }

+ 29 - 2
x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/index/store/cache/CachedBlobContainerIndexInputTests.java

@@ -7,6 +7,10 @@ package org.elasticsearch.index.store.cache;
 
 import org.apache.lucene.store.IndexInput;
 import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.common.blobstore.BlobContainer;
 import org.elasticsearch.common.blobstore.support.FilterBlobContainer;
 import org.elasticsearch.common.lucene.store.ESIndexInputTestCase;
@@ -18,6 +22,8 @@ import org.elasticsearch.index.shard.ShardPath;
 import org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot;
 import org.elasticsearch.index.store.SearchableSnapshotDirectory;
 import org.elasticsearch.index.store.StoreFileMetadata;
+import org.elasticsearch.indices.recovery.RecoveryState;
+import org.elasticsearch.indices.recovery.SearchableSnapshotRecoveryState;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.snapshots.SnapshotId;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -112,7 +118,15 @@ public class CachedBlobContainerIndexInputTests extends ESIndexInputTestCase {
                         threadPool
                     )
                 ) {
-                    final boolean loaded = directory.loadSnapshot();
+                    ShardRouting shardRouting = TestShardRouting.newShardRouting(
+                        randomAlphaOfLength(10),
+                        0,
+                        randomAlphaOfLength(10),
+                        true,
+                        ShardRoutingState.INITIALIZING
+                    );
+                    RecoveryState recoveryState = createRecoveryState();
+                    final boolean loaded = directory.loadSnapshot(recoveryState);
                     assertThat("Failed to load snapshot", loaded, is(true));
                     assertThat("Snapshot should be loaded", directory.snapshot(), notNullValue());
                     assertThat("BlobContainer should be loaded", directory.blobContainer(), notNullValue());
@@ -192,7 +206,8 @@ public class CachedBlobContainerIndexInputTests extends ESIndexInputTestCase {
                     threadPool
                 )
             ) {
-                final boolean loaded = searchableSnapshotDirectory.loadSnapshot();
+                RecoveryState recoveryState = createRecoveryState();
+                final boolean loaded = searchableSnapshotDirectory.loadSnapshot(recoveryState);
                 assertThat("Failed to load snapshot", loaded, is(true));
                 assertThat("Snapshot should be loaded", searchableSnapshotDirectory.snapshot(), notNullValue());
                 assertThat("BlobContainer should be loaded", searchableSnapshotDirectory.blobContainer(), notNullValue());
@@ -225,6 +240,18 @@ public class CachedBlobContainerIndexInputTests extends ESIndexInputTestCase {
         return containsEOFException(throwable.getCause(), seenThrowables);
     }
 
+    private SearchableSnapshotRecoveryState createRecoveryState() {
+        ShardRouting shardRouting = TestShardRouting.newShardRouting(
+            randomAlphaOfLength(10),
+            0,
+            randomAlphaOfLength(10),
+            true,
+            ShardRoutingState.INITIALIZING
+        );
+        DiscoveryNode targetNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT);
+        return new SearchableSnapshotRecoveryState(shardRouting, targetNode, null);
+    }
+
     /**
      * BlobContainer that counts the number of {@link java.io.InputStream} it opens, as well as the
      * total number of bytes read from them.

+ 141 - 0
x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/indices/recovery/SearchableSnapshotsRecoveryStateTests.java

@@ -0,0 +1,141 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.indices.recovery;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.ShardRoutingState;
+import org.elasticsearch.cluster.routing.TestShardRouting;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+
+public class SearchableSnapshotsRecoveryStateTests extends ESTestCase {
+    public void testStageDoesNotTransitionToDoneUntilPreWarmingHasFinished() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+
+        recoveryState.setStage(RecoveryState.Stage.INIT)
+            .setStage(RecoveryState.Stage.INDEX)
+            .setStage(RecoveryState.Stage.VERIFY_INDEX)
+            .setStage(RecoveryState.Stage.TRANSLOG);
+        recoveryState.getIndex().setFileDetailsComplete();
+        recoveryState.setStage(RecoveryState.Stage.FINALIZE).setStage(RecoveryState.Stage.DONE);
+
+        assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.FINALIZE));
+    }
+
+    public void testsetStageThrowsAnExceptionOnInvalidTransitions() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+        expectThrows(AssertionError.class, () -> recoveryState.setStage(RecoveryState.Stage.DONE));
+    }
+
+    public void testStageTransitionsToDoneOncePreWarmingHasFinished() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+        assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.INIT));
+        recoveryState.preWarmFinished();
+
+        assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.INIT));
+
+        recoveryState.setStage(RecoveryState.Stage.INDEX).setStage(RecoveryState.Stage.VERIFY_INDEX).setStage(RecoveryState.Stage.TRANSLOG);
+        recoveryState.getIndex().setFileDetailsComplete();
+        recoveryState.setStage(RecoveryState.Stage.FINALIZE).setStage(RecoveryState.Stage.DONE);
+
+        assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE));
+    }
+
+    public void testStageTransitionsToDoneOncePreWarmingFinishesOnShardStartedStage() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+
+        recoveryState.setStage(RecoveryState.Stage.INDEX).setStage(RecoveryState.Stage.VERIFY_INDEX).setStage(RecoveryState.Stage.TRANSLOG);
+        recoveryState.getIndex().setFileDetailsComplete();
+        recoveryState.setStage(RecoveryState.Stage.FINALIZE);
+
+        recoveryState.preWarmFinished();
+
+        recoveryState.setStage(RecoveryState.Stage.DONE);
+
+        assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE));
+
+        assertThat(recoveryState.getTimer().stopTime(), greaterThan(0L));
+    }
+
+    public void testStageTransitionsToDoneOncePreWarmingFinishesOnHoldShardStartedStage() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+
+        recoveryState.setStage(RecoveryState.Stage.INDEX).setStage(RecoveryState.Stage.VERIFY_INDEX).setStage(RecoveryState.Stage.TRANSLOG);
+        recoveryState.getIndex().setFileDetailsComplete();
+        recoveryState.setStage(RecoveryState.Stage.FINALIZE).setStage(RecoveryState.Stage.DONE);
+
+        recoveryState.preWarmFinished();
+
+        assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE));
+
+        assertThat(recoveryState.getTimer().stopTime(), greaterThan(0L));
+    }
+
+    public void testIndexTimerIsStartedDuringConstruction() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+
+        assertThat(recoveryState.getIndex().startTime(), not(equalTo(0L)));
+    }
+
+    public void testIndexTimerMethodsAreBypassed() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+
+        RecoveryState.Index index = recoveryState.getIndex();
+        long initialStartTime = index.startTime();
+        assertThat(initialStartTime, not(equalTo(0L)));
+
+        index.reset();
+
+        assertThat(index.startTime(), equalTo(initialStartTime));
+
+        index.start();
+
+        assertThat(index.startTime(), equalTo(initialStartTime));
+
+        assertThat(index.stopTime(), equalTo(0L));
+
+        index.stop();
+
+        assertThat(index.stopTime(), equalTo(0L));
+    }
+
+    public void testIndexTimerIsStoppedOncePreWarmingFinishes() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+        assertThat(recoveryState.getIndex().stopTime(), equalTo(0L));
+
+        recoveryState.preWarmFinished();
+
+        assertThat(recoveryState.getIndex().stopTime(), greaterThan(0L));
+    }
+
+    public void testFilesAreIgnored() {
+        SearchableSnapshotRecoveryState recoveryState = createRecoveryState();
+        recoveryState.ignoreFile("non_pre_warmed_file");
+        recoveryState.getIndex().addFileDetail("non_pre_warmed_file", 100, false);
+
+        assertThat(recoveryState.getIndex().getFileDetails("non_pre_warmed_file"), is(nullValue()));
+    }
+
+    private SearchableSnapshotRecoveryState createRecoveryState() {
+        ShardRouting shardRouting = TestShardRouting.newShardRouting(
+            randomAlphaOfLength(10),
+            0,
+            randomAlphaOfLength(10),
+            true,
+            ShardRoutingState.INITIALIZING
+        );
+        DiscoveryNode targetNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT);
+        return new SearchableSnapshotRecoveryState(shardRouting, targetNode, null);
+    }
+}

+ 159 - 0
x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotRecoveryStateIntegrationTests.java

@@ -0,0 +1,159 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.searchablesnapshots;
+
+import com.carrotsearch.hppc.ObjectContainer;
+import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
+import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse;
+import org.elasticsearch.action.admin.indices.recovery.RecoveryResponse;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.SuppressForbidden;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.IndexService;
+import org.elasticsearch.indices.IndicesService;
+import org.elasticsearch.indices.recovery.RecoveryState;
+import org.elasticsearch.snapshots.SnapshotInfo;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.searchablesnapshots.MountSearchableSnapshotAction;
+import org.elasticsearch.xpack.core.searchablesnapshots.MountSearchableSnapshotRequest;
+import org.elasticsearch.xpack.searchablesnapshots.cache.CacheService;
+
+import java.io.File;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.index.IndexSettings.INDEX_SOFT_DELETES_SETTING;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+
+@ESIntegTestCase.ClusterScope(numDataNodes = 1)
+public class SearchableSnapshotRecoveryStateIntegrationTests extends BaseSearchableSnapshotsIntegTestCase {
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal) {
+        final Settings.Builder builder = Settings.builder().put(super.nodeSettings(nodeOrdinal));
+        builder.put(CacheService.SNAPSHOT_CACHE_SIZE_SETTING.getKey(), new ByteSizeValue(Long.MAX_VALUE, ByteSizeUnit.BYTES));
+
+        return builder.build();
+    }
+
+    public void testRecoveryStateRecoveredBytesMatchPhysicalCacheState() throws Exception {
+        final String fsRepoName = randomAlphaOfLength(10);
+        final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final String restoredIndexName = randomBoolean() ? indexName : randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+        final String snapshotName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
+
+        createRepo(fsRepoName);
+
+        final Settings.Builder originalIndexSettings = Settings.builder();
+        originalIndexSettings.put(INDEX_SOFT_DELETES_SETTING.getKey(), true);
+        originalIndexSettings.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1);
+
+        createAndPopulateIndex(indexName, originalIndexSettings);
+
+        CreateSnapshotResponse createSnapshotResponse = client().admin()
+            .cluster()
+            .prepareCreateSnapshot(fsRepoName, snapshotName)
+            .setWaitForCompletion(true)
+            .get();
+
+        final SnapshotInfo snapshotInfo = createSnapshotResponse.getSnapshotInfo();
+        assertThat(snapshotInfo.successfulShards(), greaterThan(0));
+        assertThat(snapshotInfo.successfulShards(), equalTo(snapshotInfo.totalShards()));
+
+        assertAcked(client().admin().indices().prepareDelete(indexName));
+
+        final MountSearchableSnapshotRequest req = new MountSearchableSnapshotRequest(
+            restoredIndexName,
+            fsRepoName,
+            snapshotInfo.snapshotId().getName(),
+            indexName,
+            Settings.EMPTY,
+            Strings.EMPTY_ARRAY,
+            true
+        );
+
+        final RestoreSnapshotResponse restoreSnapshotResponse = client().execute(MountSearchableSnapshotAction.INSTANCE, req).get();
+        assertThat(restoreSnapshotResponse.getRestoreInfo().failedShards(), equalTo(0));
+        ensureGreen(restoredIndexName);
+
+        final Index restoredIndex = client().admin()
+            .cluster()
+            .prepareState()
+            .clear()
+            .setMetadata(true)
+            .get()
+            .getState()
+            .metadata()
+            .index(restoredIndexName)
+            .getIndex();
+
+        assertExecutorIsIdle(SearchableSnapshotsConstants.CACHE_PREWARMING_THREAD_POOL_NAME);
+        assertExecutorIsIdle(SearchableSnapshotsConstants.CACHE_FETCH_ASYNC_THREAD_POOL_NAME);
+
+        final RecoveryResponse recoveryResponse = client().admin().indices().prepareRecoveries(restoredIndexName).get();
+        Map<String, List<RecoveryState>> shardRecoveries = recoveryResponse.shardRecoveryStates();
+        assertThat(shardRecoveries.containsKey(restoredIndexName), equalTo(true));
+        List<RecoveryState> recoveryStates = shardRecoveries.get(restoredIndexName);
+        assertThat(recoveryStates.size(), equalTo(1));
+        RecoveryState recoveryState = recoveryStates.get(0);
+
+        assertThat(recoveryState.getStage(), equalTo(RecoveryState.Stage.DONE));
+
+        long recoveredBytes = recoveryState.getIndex().recoveredBytes();
+        long physicalCacheSize = getPhysicalCacheSize(restoredIndex, snapshotInfo.snapshotId().getUUID());
+
+        assertThat("Physical cache size doesn't match with recovery state data", physicalCacheSize, equalTo(recoveredBytes));
+        assertThat("Expected to recover 100% of files", recoveryState.getIndex().recoveredBytesPercent(), equalTo(100.0f));
+    }
+
+    @SuppressForbidden(reason = "Uses FileSystem APIs")
+    private long getPhysicalCacheSize(Index index, String snapshotUUID) throws Exception {
+        final ObjectContainer<DiscoveryNode> dataNodes = getDiscoveryNodes().getDataNodes().values();
+
+        assertThat(dataNodes.size(), equalTo(1));
+
+        final String dataNode = dataNodes.iterator().next().value.getName();
+
+        final IndexService indexService = internalCluster().getInstance(IndicesService.class, dataNode).indexService(index);
+        final Path shardCachePath = CacheService.getShardCachePath(indexService.getShard(0).shardPath());
+
+        long physicalCacheSize;
+        try (Stream<Path> files = Files.list(shardCachePath.resolve(snapshotUUID))) {
+            physicalCacheSize = files.map(Path::toFile).mapToLong(File::length).sum();
+        }
+        return physicalCacheSize;
+    }
+
+    private void assertExecutorIsIdle(String executorName) throws Exception {
+        assertBusy(() -> {
+            for (DiscoveryNode node : getDiscoveryNodes()) {
+                ThreadPool threadPool = internalCluster().getInstance(ThreadPool.class, node.getName());
+                ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) threadPool.executor(executorName);
+                assertThat(threadPoolExecutor.getQueue().size(), equalTo(0));
+                assertThat(threadPoolExecutor.getActiveCount(), equalTo(0));
+            }
+        });
+    }
+
+    private DiscoveryNodes getDiscoveryNodes() {
+        return client().admin().cluster().prepareState().clear().setNodes(true).get().getState().nodes();
+    }
+}

+ 65 - 31
x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsIntegTests.java

@@ -11,6 +11,7 @@ import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
 import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse;
+import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
 import org.elasticsearch.action.admin.indices.alias.IndicesAliasesRequest;
 import org.elasticsearch.action.admin.indices.recovery.RecoveryResponse;
 import org.elasticsearch.action.admin.indices.shrink.ResizeType;
@@ -73,6 +74,8 @@ import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.SNAPSHOT_DIRECTORY_FACTORY_KEY;
+import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsConstants.SNAPSHOT_RECOVERY_STATE_FACTORY_KEY;
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -150,8 +153,10 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
         Settings.Builder indexSettingsBuilder = Settings.builder()
             .put(SearchableSnapshots.SNAPSHOT_CACHE_ENABLED_SETTING.getKey(), cacheEnabled)
             .put(IndexSettings.INDEX_CHECK_ON_STARTUP.getKey(), Boolean.FALSE.toString());
+        boolean preWarmEnabled = false;
         if (cacheEnabled) {
-            indexSettingsBuilder.put(SearchableSnapshots.SNAPSHOT_CACHE_PREWARM_ENABLED_SETTING.getKey(), randomBoolean());
+            preWarmEnabled = randomBoolean();
+            indexSettingsBuilder.put(SearchableSnapshots.SNAPSHOT_CACHE_PREWARM_ENABLED_SETTING.getKey(), preWarmEnabled);
         }
         final List<String> nonCachedExtensions;
         if (randomBoolean()) {
@@ -195,13 +200,15 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
         assertThat(SearchableSnapshots.SNAPSHOT_REPOSITORY_SETTING.get(settings), equalTo(fsRepoName));
         assertThat(SearchableSnapshots.SNAPSHOT_SNAPSHOT_NAME_SETTING.get(settings), equalTo(snapshotName));
         assertThat(IndexModule.INDEX_STORE_TYPE_SETTING.get(settings), equalTo(SNAPSHOT_DIRECTORY_FACTORY_KEY));
+        assertThat(IndexModule.INDEX_RECOVERY_TYPE_SETTING.get(settings), equalTo(SNAPSHOT_RECOVERY_STATE_FACTORY_KEY));
         assertTrue(IndexMetadata.INDEX_BLOCKS_WRITE_SETTING.get(settings));
         assertTrue(SearchableSnapshots.SNAPSHOT_SNAPSHOT_ID_SETTING.exists(settings));
         assertTrue(SearchableSnapshots.SNAPSHOT_INDEX_ID_SETTING.exists(settings));
         assertThat(IndexMetadata.INDEX_AUTO_EXPAND_REPLICAS_SETTING.get(settings).toString(), equalTo("false"));
         assertThat(IndexMetadata.INDEX_NUMBER_OF_REPLICAS_SETTING.get(settings), equalTo(expectedReplicas));
 
-        assertRecovered(restoredIndexName, originalAllHits, originalBarHits);
+        assertTotalHits(restoredIndexName, originalAllHits, originalBarHits);
+        assertRecoveryStats(restoredIndexName, preWarmEnabled);
         assertSearchableSnapshotStats(restoredIndexName, cacheEnabled, nonCachedExtensions);
         ensureGreen(restoredIndexName);
         assertShardFolders(restoredIndexName, true);
@@ -220,11 +227,12 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
             );
         }
         assertThat(client().admin().indices().prepareGetAliases(aliasName).get().getAliases().size(), equalTo(1));
-        assertRecovered(aliasName, originalAllHits, originalBarHits, false);
+        assertTotalHits(aliasName, originalAllHits, originalBarHits);
 
         internalCluster().fullRestart();
-        assertRecovered(restoredIndexName, originalAllHits, originalBarHits);
-        assertRecovered(aliasName, originalAllHits, originalBarHits, false);
+        assertTotalHits(restoredIndexName, originalAllHits, originalBarHits);
+        assertRecoveryStats(restoredIndexName, preWarmEnabled);
+        assertTotalHits(aliasName, originalAllHits, originalBarHits);
         assertSearchableSnapshotStats(restoredIndexName, cacheEnabled, nonCachedExtensions);
 
         internalCluster().ensureAtLeastNumDataNodes(2);
@@ -260,7 +268,8 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
                 .isTimedOut()
         );
 
-        assertRecovered(restoredIndexName, originalAllHits, originalBarHits);
+        assertTotalHits(restoredIndexName, originalAllHits, originalBarHits);
+        assertRecoveryStats(restoredIndexName, preWarmEnabled);
         assertSearchableSnapshotStats(restoredIndexName, cacheEnabled, nonCachedExtensions);
 
         assertAcked(
@@ -274,16 +283,24 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
                 )
         );
 
+        assertTotalHits(restoredIndexName, originalAllHits, originalBarHits);
+        assertRecoveryStats(restoredIndexName, preWarmEnabled);
+
         final String clonedIndexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
         assertAcked(
             client().admin()
                 .indices()
                 .prepareResizeIndex(restoredIndexName, clonedIndexName)
                 .setResizeType(ResizeType.CLONE)
-                .setSettings(Settings.builder().putNull(IndexModule.INDEX_STORE_TYPE_SETTING.getKey()).build())
+                .setSettings(
+                    Settings.builder()
+                        .putNull(IndexModule.INDEX_STORE_TYPE_SETTING.getKey())
+                        .putNull(IndexModule.INDEX_RECOVERY_TYPE_SETTING.getKey())
+                        .build()
+                )
         );
         ensureGreen(clonedIndexName);
-        assertRecovered(clonedIndexName, originalAllHits, originalBarHits, false);
+        assertTotalHits(clonedIndexName, originalAllHits, originalBarHits);
 
         final Settings clonedIndexSettings = client().admin()
             .indices()
@@ -296,12 +313,12 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
         assertFalse(clonedIndexSettings.hasValue(SearchableSnapshots.SNAPSHOT_SNAPSHOT_NAME_SETTING.getKey()));
         assertFalse(clonedIndexSettings.hasValue(SearchableSnapshots.SNAPSHOT_SNAPSHOT_ID_SETTING.getKey()));
         assertFalse(clonedIndexSettings.hasValue(SearchableSnapshots.SNAPSHOT_INDEX_ID_SETTING.getKey()));
+        assertFalse(clonedIndexSettings.hasValue(IndexModule.INDEX_RECOVERY_TYPE_SETTING.getKey()));
 
         assertAcked(client().admin().indices().prepareDelete(restoredIndexName));
         assertThat(client().admin().indices().prepareGetAliases(aliasName).get().getAliases().size(), equalTo(0));
         assertAcked(client().admin().indices().prepareAliases().addAlias(clonedIndexName, aliasName));
-        assertRecovered(aliasName, originalAllHits, originalBarHits, false);
-
+        assertTotalHits(aliasName, originalAllHits, originalBarHits);
     }
 
     private void assertShardFolders(String indexName, boolean snapshotDirectory) throws IOException {
@@ -640,13 +657,7 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
         }
     }
 
-    private void assertRecovered(String indexName, TotalHits originalAllHits, TotalHits originalBarHits) throws Exception {
-        assertRecovered(indexName, originalAllHits, originalBarHits, true);
-    }
-
-    private void assertRecovered(String indexName, TotalHits originalAllHits, TotalHits originalBarHits, boolean checkRecoveryStats)
-        throws Exception {
-
+    private void assertTotalHits(String indexName, TotalHits originalAllHits, TotalHits originalBarHits) throws Exception {
         final Thread[] threads = new Thread[between(1, 5)];
         final AtomicArray<TotalHits> allHits = new AtomicArray<>(threads.length);
         final AtomicArray<TotalHits> barHits = new AtomicArray<>(threads.length);
@@ -677,20 +688,6 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
         ensureGreen(indexName);
         latch.countDown();
 
-        if (checkRecoveryStats) {
-            final RecoveryResponse recoveryResponse = client().admin().indices().prepareRecoveries(indexName).get();
-            for (List<RecoveryState> recoveryStates : recoveryResponse.shardRecoveryStates().values()) {
-                for (RecoveryState recoveryState : recoveryStates) {
-                    logger.info("Checking {}[{}]", recoveryState.getShardId(), recoveryState.getPrimary() ? "p" : "r");
-                    assertThat(
-                        Strings.toString(recoveryState), // we make a new commit so we write a new `segments_n` file
-                        recoveryState.getIndex().recoveredFileCount(),
-                        lessThanOrEqualTo(1)
-                    );
-                }
-            }
-        }
-
         for (int i = 0; i < threads.length; i++) {
             threads[i].join();
 
@@ -703,6 +700,34 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
         }
     }
 
+    private void assertRecoveryStats(String indexName, boolean preWarmEnabled) {
+        int shardCount = getNumShards(indexName).totalNumShards;
+        final RecoveryResponse recoveryResponse = client().admin().indices().prepareRecoveries(indexName).get();
+        assertThat(recoveryResponse.shardRecoveryStates().get(indexName).size(), equalTo(shardCount));
+
+        for (List<RecoveryState> recoveryStates : recoveryResponse.shardRecoveryStates().values()) {
+            for (RecoveryState recoveryState : recoveryStates) {
+                ByteSizeValue cacheSize = getCacheSizeForShard(recoveryState.getShardId());
+                boolean unboundedCache = cacheSize.equals(new ByteSizeValue(Long.MAX_VALUE, ByteSizeUnit.BYTES));
+                RecoveryState.Index index = recoveryState.getIndex();
+                assertThat(
+                    Strings.toString(recoveryState),
+                    index.recoveredFileCount(),
+                    preWarmEnabled && unboundedCache ? equalTo(index.totalRecoverFiles()) : greaterThanOrEqualTo(0)
+                );
+
+                // Since the cache size is variable, the pre-warm phase might fail as some of the files can be evicted
+                // while a part is pre-fetched, in that case the recovery state stage is left as FINALIZE.
+                assertThat(
+                    recoveryState.getStage(),
+                    unboundedCache
+                        ? equalTo(RecoveryState.Stage.DONE)
+                        : anyOf(equalTo(RecoveryState.Stage.DONE), equalTo(RecoveryState.Stage.FINALIZE))
+                );
+            }
+        }
+    }
+
     private void assertSearchableSnapshotStats(String indexName, boolean cacheEnabled, List<String> nonCachedExtensions) {
         final SearchableSnapshotsStatsResponse statsResponse = client().execute(
             SearchableSnapshotsStatsAction.INSTANCE,
@@ -804,4 +829,13 @@ public class SearchableSnapshotsIntegTests extends BaseSearchableSnapshotsIntegT
         }
     }
 
+    private ByteSizeValue getCacheSizeForShard(ShardId shardId) {
+        ClusterStateResponse clusterStateResponse = client().admin().cluster().prepareState().setRoutingTable(true).setNodes(true).get();
+        ClusterState clusterStateResponseState = clusterStateResponse.getState();
+        String nodeId = clusterStateResponseState.getRoutingTable().shardRoutingTable(shardId).primaryShard().currentNodeId();
+        DiscoveryNode discoveryNode = clusterStateResponseState.nodes().get(nodeId);
+
+        final Settings nodeSettings = internalCluster().getInstance(Environment.class, discoveryNode.getName()).settings();
+        return CacheService.SNAPSHOT_CACHE_SIZE_SETTING.get(nodeSettings);
+    }
 }