1
0
Эх сурвалжийг харах

Restore from Individual Shard Snapshot Files in Parallel (#48110)

Make restoring shard snapshots run in parallel on the `SNAPSHOT` thread-pool.
Armin Braun 6 жил өмнө
parent
commit
e58fc03d42
15 өөрчлөгдсөн 383 нэмэгдсэн , 285 устгасан
  1. 20 17
      server/src/main/java/org/elasticsearch/index/shard/IndexShard.java
  2. 99 76
      server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java
  3. 3 2
      server/src/main/java/org/elasticsearch/repositories/FilterRepository.java
  4. 3 2
      server/src/main/java/org/elasticsearch/repositories/Repository.java
  5. 45 20
      server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java
  6. 40 25
      server/src/main/java/org/elasticsearch/repositories/blobstore/FileRestoreContext.java
  7. 1 1
      server/src/main/java/org/elasticsearch/snapshots/RestoreService.java
  8. 8 7
      server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java
  9. 2 1
      server/src/test/java/org/elasticsearch/repositories/RepositoriesServiceTests.java
  10. 9 5
      server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java
  11. 4 1
      test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java
  12. 124 113
      x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java
  13. 11 6
      x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java
  14. 9 7
      x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/index/engine/FollowEngineIndexShardTests.java
  15. 5 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/snapshots/SourceOnlySnapshotShardTests.java

+ 20 - 17
server/src/main/java/org/elasticsearch/index/shard/IndexShard.java

@@ -42,6 +42,7 @@ import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.admin.indices.flush.FlushRequest;
 import org.elasticsearch.action.admin.indices.flush.FlushRequest;
 import org.elasticsearch.action.admin.indices.forcemerge.ForceMergeRequest;
 import org.elasticsearch.action.admin.indices.forcemerge.ForceMergeRequest;
 import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeRequest;
 import org.elasticsearch.action.admin.indices.upgrade.post.UpgradeRequest;
@@ -1816,12 +1817,16 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
         return storeRecovery.recoverFromStore(this);
         return storeRecovery.recoverFromStore(this);
     }
     }
 
 
-    public boolean restoreFromRepository(Repository repository) {
-        assert shardRouting.primary() : "recover from store only makes sense if the shard is a primary shard";
-        assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.SNAPSHOT : "invalid recovery type: " +
-            recoveryState.getRecoverySource();
-        StoreRecovery storeRecovery = new StoreRecovery(shardId, logger);
-        return storeRecovery.recoverFromRepository(this, repository);
+    public void restoreFromRepository(Repository repository, ActionListener<Boolean> listener) {
+        try {
+            assert shardRouting.primary() : "recover from store only makes sense if the shard is a primary shard";
+            assert recoveryState.getRecoverySource().getType() == RecoverySource.Type.SNAPSHOT : "invalid recovery type: " +
+                recoveryState.getRecoverySource();
+            StoreRecovery storeRecovery = new StoreRecovery(shardId, logger);
+            storeRecovery.recoverFromRepository(this, repository, listener);
+        } catch (Exception e) {
+            listener.onFailure(e);
+        }
     }
     }
 
 
     /**
     /**
@@ -2504,17 +2509,15 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl
             case SNAPSHOT:
             case SNAPSHOT:
                 markAsRecovering("from snapshot", recoveryState); // mark the shard as recovering on the cluster state thread
                 markAsRecovering("from snapshot", recoveryState); // mark the shard as recovering on the cluster state thread
                 SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) recoveryState.getRecoverySource();
                 SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) recoveryState.getRecoverySource();
-                threadPool.generic().execute(() -> {
-                    try {
-                        final Repository repository = repositoriesService.repository(recoverySource.snapshot().getRepository());
-                        if (restoreFromRepository(repository)) {
-                            recoveryListener.onRecoveryDone(recoveryState);
-                        }
-                    } catch (Exception e) {
-                        recoveryListener.onRecoveryFailure(recoveryState,
-                            new RecoveryFailedException(recoveryState, null, e), true);
-                    }
-                });
+                threadPool.generic().execute(
+                    ActionRunnable.<Boolean>wrap(ActionListener.wrap(r -> {
+                            if (r) {
+                                recoveryListener.onRecoveryDone(recoveryState);
+                            }
+                        },
+                        e -> recoveryListener.onRecoveryFailure(recoveryState, new RecoveryFailedException(recoveryState, null, e), true)),
+                        restoreListener -> restoreFromRepository(
+                            repositoriesService.repository(recoverySource.snapshot().getRepository()), restoreListener)));
                 break;
                 break;
             case LOCAL_SHARDS:
             case LOCAL_SHARDS:
                 final IndexMetaData indexMetaData = indexSettings().getIndexMetaData();
                 final IndexMetaData indexMetaData = indexSettings().getIndexMetaData();

+ 99 - 76
server/src/main/java/org/elasticsearch/index/shard/StoreRecovery.java

@@ -30,6 +30,8 @@ import org.apache.lucene.store.FilterDirectory;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexInput;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ExceptionsHelper;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.MappingMetaData;
 import org.elasticsearch.cluster.metadata.MappingMetaData;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RecoverySource;
@@ -88,10 +90,16 @@ final class StoreRecovery {
             RecoverySource.Type recoveryType = indexShard.recoveryState().getRecoverySource().getType();
             RecoverySource.Type recoveryType = indexShard.recoveryState().getRecoverySource().getType();
             assert recoveryType == RecoverySource.Type.EMPTY_STORE || recoveryType == RecoverySource.Type.EXISTING_STORE :
             assert recoveryType == RecoverySource.Type.EMPTY_STORE || recoveryType == RecoverySource.Type.EXISTING_STORE :
                 "expected store recovery type but was: " + recoveryType;
                 "expected store recovery type but was: " + recoveryType;
-            return executeRecovery(indexShard, () -> {
+            final PlainActionFuture<Boolean> future = PlainActionFuture.newFuture();
+            final ActionListener<Boolean> recoveryListener = recoveryListener(indexShard, future);
+            try {
                 logger.debug("starting recovery from store ...");
                 logger.debug("starting recovery from store ...");
                 internalRecoverFromStore(indexShard);
                 internalRecoverFromStore(indexShard);
-            });
+                recoveryListener.onResponse(true);
+            } catch (Exception e) {
+                recoveryListener.onFailure(e);
+            }
+            return future.actionGet();
         }
         }
         return false;
         return false;
     }
     }
@@ -117,14 +125,15 @@ final class StoreRecovery {
             Sort indexSort = indexShard.getIndexSort();
             Sort indexSort = indexShard.getIndexSort();
             final boolean hasNested = indexShard.mapperService().hasNested();
             final boolean hasNested = indexShard.mapperService().hasNested();
             final boolean isSplit = sourceMetaData.getNumberOfShards() < indexShard.indexSettings().getNumberOfShards();
             final boolean isSplit = sourceMetaData.getNumberOfShards() < indexShard.indexSettings().getNumberOfShards();
-            return executeRecovery(indexShard, () -> {
+            final PlainActionFuture<Boolean> future = PlainActionFuture.newFuture();
+            ActionListener.completeWith(recoveryListener(indexShard, future), () -> {
                 logger.debug("starting recovery from local shards {}", shards);
                 logger.debug("starting recovery from local shards {}", shards);
                 try {
                 try {
                     final Directory directory = indexShard.store().directory(); // don't close this directory!!
                     final Directory directory = indexShard.store().directory(); // don't close this directory!!
                     final Directory[] sources = shards.stream().map(LocalShardSnapshot::getSnapshotDirectory).toArray(Directory[]::new);
                     final Directory[] sources = shards.stream().map(LocalShardSnapshot::getSnapshotDirectory).toArray(Directory[]::new);
                     final long maxSeqNo = shards.stream().mapToLong(LocalShardSnapshot::maxSeqNo).max().getAsLong();
                     final long maxSeqNo = shards.stream().mapToLong(LocalShardSnapshot::maxSeqNo).max().getAsLong();
                     final long maxUnsafeAutoIdTimestamp =
                     final long maxUnsafeAutoIdTimestamp =
-                            shards.stream().mapToLong(LocalShardSnapshot::maxUnsafeAutoIdTimestamp).max().getAsLong();
+                        shards.stream().mapToLong(LocalShardSnapshot::maxUnsafeAutoIdTimestamp).max().getAsLong();
                     addIndices(indexShard.recoveryState().getIndex(), directory, indexSort, sources, maxSeqNo, maxUnsafeAutoIdTimestamp,
                     addIndices(indexShard.recoveryState().getIndex(), directory, indexSort, sources, maxSeqNo, maxUnsafeAutoIdTimestamp,
                         indexShard.indexSettings().getIndexMetaData(), indexShard.shardId().id(), isSplit, hasNested);
                         indexShard.indexSettings().getIndexMetaData(), indexShard.shardId().id(), isSplit, hasNested);
                     internalRecoverFromStore(indexShard);
                     internalRecoverFromStore(indexShard);
@@ -132,11 +141,13 @@ final class StoreRecovery {
                     // copied segments - we will also see them in stats etc.
                     // copied segments - we will also see them in stats etc.
                     indexShard.getEngine().forceMerge(false, -1, false,
                     indexShard.getEngine().forceMerge(false, -1, false,
                         false, false);
                         false, false);
+                    return true;
                 } catch (IOException ex) {
                 } catch (IOException ex) {
                     throw new IndexShardRecoveryException(indexShard.shardId(), "failed to recover from local shards", ex);
                     throw new IndexShardRecoveryException(indexShard.shardId(), "failed to recover from local shards", ex);
                 }
                 }
-
             });
             });
+            assert future.isDone();
+            return future.actionGet();
         }
         }
         return false;
         return false;
     }
     }
@@ -262,21 +273,22 @@ final class StoreRecovery {
      * previously created index snapshot into an existing initializing shard.
      * previously created index snapshot into an existing initializing shard.
      * @param indexShard the index shard instance to recovery the snapshot from
      * @param indexShard the index shard instance to recovery the snapshot from
      * @param repository the repository holding the physical files the shard should be recovered from
      * @param repository the repository holding the physical files the shard should be recovered from
-     * @return <code>true</code> if the shard has been recovered successfully, <code>false</code> if the recovery
-     * has been ignored due to a concurrent modification of if the clusters state has changed due to async updates.
+     * @param listener resolves to <code>true</code> if the shard has been recovered successfully, <code>false</code> if the recovery
+     *                 has been ignored due to a concurrent modification or if the cluster state has changed due to async updates.
      */
      */
-    boolean recoverFromRepository(final IndexShard indexShard, Repository repository) {
-        if (canRecover(indexShard)) {
-            RecoverySource.Type recoveryType = indexShard.recoveryState().getRecoverySource().getType();
-            assert recoveryType == RecoverySource.Type.SNAPSHOT : "expected snapshot recovery type: " + recoveryType;
-            SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) indexShard.recoveryState().getRecoverySource();
-            return executeRecovery(indexShard, () -> {
-                logger.debug("restoring from {} ...", indexShard.recoveryState().getRecoverySource());
-                restore(indexShard, repository, recoverySource);
-            });
+    void recoverFromRepository(final IndexShard indexShard, Repository repository, ActionListener<Boolean> listener) {
+        try {
+            if (canRecover(indexShard)) {
+                RecoverySource.Type recoveryType = indexShard.recoveryState().getRecoverySource().getType();
+                assert recoveryType == RecoverySource.Type.SNAPSHOT : "expected snapshot recovery type: " + recoveryType;
+                SnapshotRecoverySource recoverySource = (SnapshotRecoverySource) indexShard.recoveryState().getRecoverySource();
+                restore(indexShard, repository, recoverySource, recoveryListener(indexShard, listener));
+            } else {
+                listener.onResponse(false);
+            }
+        } catch (Exception e) {
+            listener.onFailure(e);
         }
         }
-        return false;
-
     }
     }
 
 
     private boolean canRecover(IndexShard indexShard) {
     private boolean canRecover(IndexShard indexShard) {
@@ -290,59 +302,62 @@ final class StoreRecovery {
         return true;
         return true;
     }
     }
 
 
-    /**
-     * Recovers the state of the shard from the store.
-     */
-    private boolean executeRecovery(final IndexShard indexShard, Runnable recoveryRunnable) throws IndexShardRecoveryException {
-        try {
-            recoveryRunnable.run();
-            // Check that the gateway didn't leave the shard in init or recovering stage. it is up to the gateway
-            // to call post recovery.
-            final IndexShardState shardState = indexShard.state();
-            final RecoveryState recoveryState = indexShard.recoveryState();
-            assert shardState != IndexShardState.CREATED && shardState != IndexShardState.RECOVERING :
-                "recovery process of " + shardId + " didn't get to post_recovery. shardState [" + shardState + "]";
-
-            if (logger.isTraceEnabled()) {
-                RecoveryState.Index index = recoveryState.getIndex();
-                StringBuilder sb = new StringBuilder();
-                sb.append("    index    : files           [").append(index.totalFileCount()).append("] with total_size [")
+    private ActionListener<Boolean> recoveryListener(IndexShard indexShard, ActionListener<Boolean> listener) {
+        return ActionListener.wrap(res -> {
+            if (res) {
+                // Check that the gateway didn't leave the shard in init or recovering stage. it is up to the gateway
+                // to call post recovery.
+                final IndexShardState shardState = indexShard.state();
+                final RecoveryState recoveryState = indexShard.recoveryState();
+                assert shardState != IndexShardState.CREATED && shardState != IndexShardState.RECOVERING :
+                    "recovery process of " + shardId + " didn't get to post_recovery. shardState [" + shardState + "]";
+
+                if (logger.isTraceEnabled()) {
+                    RecoveryState.Index index = recoveryState.getIndex();
+                    StringBuilder sb = new StringBuilder();
+                    sb.append("    index    : files           [").append(index.totalFileCount()).append("] with total_size [")
                         .append(new ByteSizeValue(index.totalBytes())).append("], took[")
                         .append(new ByteSizeValue(index.totalBytes())).append("], took[")
                         .append(TimeValue.timeValueMillis(index.time())).append("]\n");
                         .append(TimeValue.timeValueMillis(index.time())).append("]\n");
-                sb.append("             : recovered_files [").append(index.recoveredFileCount()).append("] with total_size [")
+                    sb.append("             : recovered_files [").append(index.recoveredFileCount()).append("] with total_size [")
                         .append(new ByteSizeValue(index.recoveredBytes())).append("]\n");
                         .append(new ByteSizeValue(index.recoveredBytes())).append("]\n");
-                sb.append("             : reusing_files   [").append(index.reusedFileCount()).append("] with total_size [")
+                    sb.append("             : reusing_files   [").append(index.reusedFileCount()).append("] with total_size [")
                         .append(new ByteSizeValue(index.reusedBytes())).append("]\n");
                         .append(new ByteSizeValue(index.reusedBytes())).append("]\n");
-                sb.append("    verify_index    : took [")
-                    .append(TimeValue.timeValueMillis(recoveryState.getVerifyIndex().time())).append("], check_index [")
-                    .append(timeValueMillis(recoveryState.getVerifyIndex().checkIndexTime())).append("]\n");
-                sb.append("    translog : number_of_operations [").append(recoveryState.getTranslog().recoveredOperations())
+                    sb.append("    verify_index    : took [")
+                        .append(TimeValue.timeValueMillis(recoveryState.getVerifyIndex().time())).append("], check_index [")
+                        .append(timeValueMillis(recoveryState.getVerifyIndex().checkIndexTime())).append("]\n");
+                    sb.append("    translog : number_of_operations [").append(recoveryState.getTranslog().recoveredOperations())
                         .append("], took [").append(TimeValue.timeValueMillis(recoveryState.getTranslog().time())).append("]");
                         .append("], took [").append(TimeValue.timeValueMillis(recoveryState.getTranslog().time())).append("]");
-                logger.trace("recovery completed from [shard_store], took [{}]\n{}",
-                    timeValueMillis(recoveryState.getTimer().time()), sb);
-            } else if (logger.isDebugEnabled()) {
-                logger.debug("recovery completed from [shard_store], took [{}]", timeValueMillis(recoveryState.getTimer().time()));
-            }
-            return true;
-        } catch (IndexShardRecoveryException e) {
-            if (indexShard.state() == IndexShardState.CLOSED) {
-                // got closed on us, just ignore this recovery
-                return false;
-            }
-            if ((e.getCause() instanceof IndexShardClosedException) || (e.getCause() instanceof IndexShardNotStartedException)) {
-                // got closed on us, just ignore this recovery
-                return false;
+                    logger.trace("recovery completed from [shard_store], took [{}]\n{}",
+                        timeValueMillis(recoveryState.getTimer().time()), sb);
+                } else if (logger.isDebugEnabled()) {
+                    logger.debug("recovery completed from [shard_store], took [{}]", timeValueMillis(recoveryState.getTimer().time()));
+                }
             }
             }
-            throw e;
-        } catch (IndexShardClosedException | IndexShardNotStartedException e) {
-        } catch (Exception e) {
-            if (indexShard.state() == IndexShardState.CLOSED) {
-                // got closed on us, just ignore this recovery
-                return false;
+            listener.onResponse(res);
+        }, ex -> {
+            if (ex instanceof IndexShardRecoveryException) {
+                if (indexShard.state() == IndexShardState.CLOSED) {
+                    // got closed on us, just ignore this recovery
+                    listener.onResponse(false);
+                    return;
+                }
+                if ((ex.getCause() instanceof IndexShardClosedException) || (ex.getCause() instanceof IndexShardNotStartedException)) {
+                    // got closed on us, just ignore this recovery
+                    listener.onResponse(false);
+                    return;
+                }
+                listener.onFailure(ex);
+            } else if (ex instanceof IndexShardClosedException || ex instanceof IndexShardNotStartedException) {
+                listener.onResponse(false);
+            } else {
+                if (indexShard.state() == IndexShardState.CLOSED) {
+                    // got closed on us, just ignore this recovery
+                    listener.onResponse(false);
+                } else {
+                    listener.onFailure(new IndexShardRecoveryException(shardId, "failed recovery", ex));
+                }
             }
             }
-            throw new IndexShardRecoveryException(shardId, "failed recovery", e);
-        }
-        return false;
+        });
     }
     }
 
 
     /**
     /**
@@ -436,14 +451,30 @@ final class StoreRecovery {
     /**
     /**
      * Restores shard from {@link SnapshotRecoverySource} associated with this shard in routing table
      * Restores shard from {@link SnapshotRecoverySource} associated with this shard in routing table
      */
      */
-    private void restore(final IndexShard indexShard, final Repository repository, final SnapshotRecoverySource restoreSource) {
+    private void restore(IndexShard indexShard, Repository repository, SnapshotRecoverySource restoreSource,
+                         ActionListener<Boolean> listener) {
+        logger.debug("restoring from {} ...", indexShard.recoveryState().getRecoverySource());
         final RecoveryState.Translog translogState = indexShard.recoveryState().getTranslog();
         final RecoveryState.Translog translogState = indexShard.recoveryState().getTranslog();
         if (restoreSource == null) {
         if (restoreSource == null) {
-            throw new IndexShardRestoreFailedException(shardId, "empty restore source");
+            listener.onFailure(new IndexShardRestoreFailedException(shardId, "empty restore source"));
+            return;
         }
         }
         if (logger.isTraceEnabled()) {
         if (logger.isTraceEnabled()) {
             logger.trace("[{}] restoring shard [{}]", restoreSource.snapshot(), shardId);
             logger.trace("[{}] restoring shard [{}]", restoreSource.snapshot(), shardId);
         }
         }
+        final ActionListener<Void> restoreListener = ActionListener.wrap(
+            v -> {
+                final Store store = indexShard.store();
+                bootstrap(indexShard, store);
+                assert indexShard.shardRouting.primary() : "only primary shards can recover from store";
+                writeEmptyRetentionLeasesFile(indexShard);
+                indexShard.openEngineAndRecoverFromTranslog();
+                indexShard.getEngine().fillSeqNoGaps(indexShard.getPendingPrimaryTerm());
+                indexShard.finalizeRecovery();
+                indexShard.postRecovery("restore done");
+                listener.onResponse(true);
+            }, e -> listener.onFailure(new IndexShardRestoreFailedException(shardId, "restore failed", e))
+        );
         try {
         try {
             translogState.totalOperations(0);
             translogState.totalOperations(0);
             translogState.totalOperationsOnStart(0);
             translogState.totalOperationsOnStart(0);
@@ -456,17 +487,9 @@ final class StoreRecovery {
             final IndexId indexId = repository.getRepositoryData().resolveIndexId(indexName);
             final IndexId indexId = repository.getRepositoryData().resolveIndexId(indexName);
             assert indexShard.getEngineOrNull() == null;
             assert indexShard.getEngineOrNull() == null;
             repository.restoreShard(indexShard.store(), restoreSource.snapshot().getSnapshotId(), indexId, snapshotShardId,
             repository.restoreShard(indexShard.store(), restoreSource.snapshot().getSnapshotId(), indexId, snapshotShardId,
-                indexShard.recoveryState());
-            final Store store = indexShard.store();
-            bootstrap(indexShard, store);
-            assert indexShard.shardRouting.primary() : "only primary shards can recover from store";
-            writeEmptyRetentionLeasesFile(indexShard);
-            indexShard.openEngineAndRecoverFromTranslog();
-            indexShard.getEngine().fillSeqNoGaps(indexShard.getPendingPrimaryTerm());
-            indexShard.finalizeRecovery();
-            indexShard.postRecovery("restore done");
+                indexShard.recoveryState(), restoreListener);
         } catch (Exception e) {
         } catch (Exception e) {
-            throw new IndexShardRestoreFailedException(shardId, "restore failed", e);
+            restoreListener.onFailure(e);
         }
         }
     }
     }
 
 

+ 3 - 2
server/src/main/java/org/elasticsearch/repositories/FilterRepository.java

@@ -123,8 +123,9 @@ public class FilterRepository implements Repository {
         in.snapshotShard(store, mapperService, snapshotId, indexId, snapshotIndexCommit, snapshotStatus, writeShardGens, listener);
         in.snapshotShard(store, mapperService, snapshotId, indexId, snapshotIndexCommit, snapshotStatus, writeShardGens, listener);
     }
     }
     @Override
     @Override
-    public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId, RecoveryState recoveryState) {
-        in.restoreShard(store, snapshotId, indexId, snapshotShardId, recoveryState);
+    public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId, RecoveryState recoveryState,
+                             ActionListener<Void> listener) {
+        in.restoreShard(store, snapshotId, indexId, snapshotShardId, recoveryState, listener);
     }
     }
 
 
     @Override
     @Override

+ 3 - 2
server/src/main/java/org/elasticsearch/repositories/Repository.java

@@ -211,9 +211,10 @@ public interface Repository extends LifecycleComponent {
      * @param indexId         id of the index in the repository from which the restore is occurring
      * @param indexId         id of the index in the repository from which the restore is occurring
      * @param snapshotShardId shard id (in the snapshot)
      * @param snapshotShardId shard id (in the snapshot)
      * @param recoveryState   recovery state
      * @param recoveryState   recovery state
+     * @param listener        listener to invoke once done
      */
      */
-    void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId, RecoveryState recoveryState);
-
+    void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId, RecoveryState recoveryState,
+                      ActionListener<Void> listener);
     /**
     /**
      * Retrieve shard snapshot status for the stored snapshot
      * Retrieve shard snapshot status for the stored snapshot
      *
      *

+ 45 - 20
server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java

@@ -1195,11 +1195,7 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
             final Executor executor = threadPool.executor(ThreadPool.Names.SNAPSHOT);
             final Executor executor = threadPool.executor(ThreadPool.Names.SNAPSHOT);
             // Start as many workers as fit into the snapshot pool at once at the most
             // Start as many workers as fit into the snapshot pool at once at the most
             final int workers = Math.min(threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), indexIncrementalFileCount);
             final int workers = Math.min(threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), indexIncrementalFileCount);
-            final ActionListener<Void> filesListener = ActionListener.delegateResponse(
-                new GroupedActionListener<>(allFilesUploadedListener, workers), (l, e) -> {
-                filesToSnapshot.clear(); // Stop uploading the remaining files if we run into any exception
-                l.onFailure(e);
-            });
+            final ActionListener<Void> filesListener = fileQueueListener(filesToSnapshot, workers, allFilesUploadedListener);
             for (int i = 0; i < workers; ++i) {
             for (int i = 0; i < workers; ++i) {
                 executor.execute(ActionRunnable.run(filesListener, () -> {
                 executor.execute(ActionRunnable.run(filesListener, () -> {
                     BlobStoreIndexShardSnapshot.FileInfo snapshotFileInfo = filesToSnapshot.poll(0L, TimeUnit.MILLISECONDS);
                     BlobStoreIndexShardSnapshot.FileInfo snapshotFileInfo = filesToSnapshot.poll(0L, TimeUnit.MILLISECONDS);
@@ -1223,19 +1219,42 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
 
 
     @Override
     @Override
     public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
     public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
-                             RecoveryState recoveryState) {
-        ShardId shardId = store.shardId();
-        try {
-            final BlobContainer container = shardContainer(indexId, snapshotShardId);
-            BlobStoreIndexShardSnapshot snapshot = loadShardSnapshot(container, snapshotId);
-            SnapshotFiles snapshotFiles = new SnapshotFiles(snapshot.snapshot(), snapshot.indexFiles());
+                             RecoveryState recoveryState, ActionListener<Void> listener) {
+        final ShardId shardId = store.shardId();
+        final ActionListener<Void> restoreListener = ActionListener.delegateResponse(listener,
+            (l, e) -> l.onFailure(new IndexShardRestoreFailedException(shardId, "failed to restore snapshot [" + snapshotId + "]", e)));
+        final Executor executor = threadPool.executor(ThreadPool.Names.SNAPSHOT);
+        final BlobContainer container = shardContainer(indexId, snapshotShardId);
+        executor.execute(ActionRunnable.wrap(restoreListener, l -> {
+            final BlobStoreIndexShardSnapshot snapshot = loadShardSnapshot(container, snapshotId);
+            final SnapshotFiles snapshotFiles = new SnapshotFiles(snapshot.snapshot(), snapshot.indexFiles());
             new FileRestoreContext(metadata.name(), shardId, snapshotId, recoveryState) {
             new FileRestoreContext(metadata.name(), shardId, snapshotId, recoveryState) {
                 @Override
                 @Override
-                protected void restoreFiles(List<BlobStoreIndexShardSnapshot.FileInfo> filesToRecover, Store store) throws IOException {
-                    // restore the files from the snapshot to the Lucene store
-                    for (final BlobStoreIndexShardSnapshot.FileInfo fileToRecover : filesToRecover) {
-                        logger.trace("[{}] [{}] restoring file [{}]", shardId, snapshotId, fileToRecover.name());
-                        restoreFile(fileToRecover, store);
+                protected void restoreFiles(List<BlobStoreIndexShardSnapshot.FileInfo> filesToRecover, Store store,
+                                            ActionListener<Void> listener) {
+                    if (filesToRecover.isEmpty()) {
+                        listener.onResponse(null);
+                    } else {
+                        // Start as many workers as fit into the snapshot pool at once at the most
+                        final int workers =
+                            Math.min(threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(), snapshotFiles.indexFiles().size());
+                        final BlockingQueue<BlobStoreIndexShardSnapshot.FileInfo> files = new LinkedBlockingQueue<>(filesToRecover);
+                        final ActionListener<Void> allFilesListener =
+                            fileQueueListener(files, workers, ActionListener.map(listener, v -> null));
+                        // restore the files from the snapshot to the Lucene store
+                        for (int i = 0; i < workers; ++i) {
+                            executor.execute(ActionRunnable.run(allFilesListener, () -> {
+                                store.incRef();
+                                try {
+                                    BlobStoreIndexShardSnapshot.FileInfo fileToRecover;
+                                    while ((fileToRecover = files.poll(0L, TimeUnit.MILLISECONDS)) != null) {
+                                        restoreFile(fileToRecover, store);
+                                    }
+                                } finally {
+                                    store.decRef();
+                                }
+                            }));
+                        }
                     }
                     }
                 }
                 }
 
 
@@ -1275,10 +1294,16 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
                         }
                         }
                     }
                     }
                 }
                 }
-            }.restore(snapshotFiles, store);
-        } catch (Exception e) {
-            throw new IndexShardRestoreFailedException(shardId, "failed to restore snapshot [" + snapshotId + "]", e);
-        }
+            }.restore(snapshotFiles, store, l);
+        }));
+    }
+
+    private static ActionListener<Void> fileQueueListener(BlockingQueue<BlobStoreIndexShardSnapshot.FileInfo> files, int workers,
+                                                          ActionListener<Collection<Void>> listener) {
+        return ActionListener.delegateResponse(new GroupedActionListener<>(listener, workers), (l, e) -> {
+            files.clear(); // Stop processing the remaining files if we run into any exception
+            l.onFailure(e);
+        });
     }
     }
 
 
     private static InputStream maybeRateLimit(InputStream stream, @Nullable RateLimiter rateLimiter, CounterMetric metric) {
     private static InputStream maybeRateLimit(InputStream stream, @Nullable RateLimiter rateLimiter, CounterMetric metric) {

+ 40 - 25
server/src/main/java/org/elasticsearch/repositories/blobstore/FileRestoreContext.java

@@ -21,6 +21,7 @@ package org.elasticsearch.repositories.blobstore;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.util.iterable.Iterables;
 import org.elasticsearch.common.util.iterable.Iterables;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
@@ -74,7 +75,7 @@ public abstract class FileRestoreContext {
     /**
     /**
      * Performs restore operation
      * Performs restore operation
      */
      */
-    public void restore(SnapshotFiles snapshotFiles, Store store) {
+    public void restore(SnapshotFiles snapshotFiles, Store store, ActionListener<Void> listener) {
         store.incRef();
         store.incRef();
         try {
         try {
             logger.debug("[{}] [{}] restoring to [{}] ...", snapshotId, repositoryName, shardId);
             logger.debug("[{}] [{}] restoring to [{}] ...", snapshotId, repositoryName, shardId);
@@ -150,36 +151,49 @@ public abstract class FileRestoreContext {
                     }
                     }
                 }
                 }
 
 
-                restoreFiles(filesToRecover, store);
+                restoreFiles(filesToRecover, store, ActionListener.wrap(
+                    v -> {
+                        store.incRef();
+                        try {
+                            afterRestore(snapshotFiles, store, restoredSegmentsFile);
+                            listener.onResponse(null);
+                        } finally {
+                            store.decRef();
+                        }
+                    }, listener::onFailure));
             } catch (IOException ex) {
             } catch (IOException ex) {
                 throw new IndexShardRestoreFailedException(shardId, "Failed to recover index", ex);
                 throw new IndexShardRestoreFailedException(shardId, "Failed to recover index", ex);
             }
             }
+        } catch (Exception e) {
+            listener.onFailure(e);
+        } finally {
+            store.decRef();
+        }
+    }
 
 
-            // read the snapshot data persisted
-            try {
-                Lucene.pruneUnreferencedFiles(restoredSegmentsFile.name(), store.directory());
-            } catch (IOException e) {
-                throw new IndexShardRestoreFailedException(shardId, "Failed to fetch index version after copying it over", e);
-            }
+    private void afterRestore(SnapshotFiles snapshotFiles, Store store, StoreFileMetaData restoredSegmentsFile) {
+        // read the snapshot data persisted
+        try {
+            Lucene.pruneUnreferencedFiles(restoredSegmentsFile.name(), store.directory());
+        } catch (IOException e) {
+            throw new IndexShardRestoreFailedException(shardId, "Failed to fetch index version after copying it over", e);
+        }
 
 
-            /// now, go over and clean files that are in the store, but were not in the snapshot
-            try {
-                for (String storeFile : store.directory().listAll()) {
-                    if (Store.isAutogenerated(storeFile) || snapshotFiles.containPhysicalIndexFile(storeFile)) {
-                        continue; //skip write.lock, checksum files and files that exist in the snapshot
-                    }
-                    try {
-                        store.deleteQuiet("restore", storeFile);
-                        store.directory().deleteFile(storeFile);
-                    } catch (IOException e) {
-                        logger.warn("[{}] [{}] failed to delete file [{}] during snapshot cleanup", shardId, snapshotId, storeFile);
-                    }
+        /// now, go over and clean files that are in the store, but were not in the snapshot
+        try {
+            for (String storeFile : store.directory().listAll()) {
+                if (Store.isAutogenerated(storeFile) || snapshotFiles.containPhysicalIndexFile(storeFile)) {
+                    continue; //skip write.lock, checksum files and files that exist in the snapshot
+                }
+                try {
+                    store.deleteQuiet("restore", storeFile);
+                    store.directory().deleteFile(storeFile);
+                } catch (IOException e) {
+                    logger.warn("[{}] [{}] failed to delete file [{}] during snapshot cleanup", shardId, snapshotId, storeFile);
                 }
                 }
-            } catch (IOException e) {
-                logger.warn("[{}] [{}] failed to list directory - some of files might not be deleted", shardId, snapshotId);
             }
             }
-        } finally {
-            store.decRef();
+        } catch (IOException e) {
+            logger.warn("[{}] [{}] failed to list directory - some of files might not be deleted", shardId, snapshotId);
         }
         }
     }
     }
 
 
@@ -189,7 +203,8 @@ public abstract class FileRestoreContext {
      * @param filesToRecover List of files to restore
      * @param filesToRecover List of files to restore
      * @param store          Store to restore into
      * @param store          Store to restore into
      */
      */
-    protected abstract void restoreFiles(List<BlobStoreIndexShardSnapshot.FileInfo> filesToRecover, Store store) throws IOException;
+    protected abstract void restoreFiles(List<BlobStoreIndexShardSnapshot.FileInfo> filesToRecover, Store store,
+                                         ActionListener<Void> listener);
 
 
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     private static Iterable<StoreFileMetaData> concat(Store.RecoveryDiff diff) {
     private static Iterable<StoreFileMetaData> concat(Store.RecoveryDiff diff) {

+ 1 - 1
server/src/main/java/org/elasticsearch/snapshots/RestoreService.java

@@ -107,7 +107,7 @@ import static org.elasticsearch.snapshots.SnapshotUtils.filterIndices;
  * {@link RoutingTable.Builder#addAsRestore(IndexMetaData, SnapshotRecoverySource)} method.
  * {@link RoutingTable.Builder#addAsRestore(IndexMetaData, SnapshotRecoverySource)} method.
  * <p>
  * <p>
  * Individual shards are getting restored as part of normal recovery process in
  * Individual shards are getting restored as part of normal recovery process in
- * {@link IndexShard#restoreFromRepository(Repository)} )}
+ * {@link IndexShard#restoreFromRepository}
  * method, which detects that shard should be restored from snapshot rather than recovered from gateway by looking
  * method, which detects that shard should be restored from snapshot rather than recovered from gateway by looking
  * at the {@link ShardRouting#recoverySource()} property.
  * at the {@link ShardRouting#recoverySource()} property.
  * <p>
  * <p>

+ 8 - 7
server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java

@@ -2338,11 +2338,12 @@ public class IndexShardTests extends IndexShardTestCase {
 
 
         DiscoveryNode localNode = new DiscoveryNode("foo", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
         DiscoveryNode localNode = new DiscoveryNode("foo", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
         target.markAsRecovering("store", new RecoveryState(routing, localNode, null));
         target.markAsRecovering("store", new RecoveryState(routing, localNode, null));
-        assertTrue(target.restoreFromRepository(new RestoreOnlyRepository("test") {
+        final PlainActionFuture<Boolean> future = PlainActionFuture.newFuture();
+        target.restoreFromRepository(new RestoreOnlyRepository("test") {
             @Override
             @Override
             public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
             public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
-                                     RecoveryState recoveryState) {
-                try {
+                                     RecoveryState recoveryState, ActionListener<Void> listener) {
+                ActionListener.completeWith(listener, () -> {
                     cleanLuceneIndex(targetStore.directory());
                     cleanLuceneIndex(targetStore.directory());
                     for (String file : sourceStore.directory().listAll()) {
                     for (String file : sourceStore.directory().listAll()) {
                         if (file.equals("write.lock") || file.startsWith("extra")) {
                         if (file.equals("write.lock") || file.startsWith("extra")) {
@@ -2350,11 +2351,11 @@ public class IndexShardTests extends IndexShardTestCase {
                         }
                         }
                         targetStore.directory().copyFrom(sourceStore.directory(), file, file, IOContext.DEFAULT);
                         targetStore.directory().copyFrom(sourceStore.directory(), file, file, IOContext.DEFAULT);
                     }
                     }
-                } catch (Exception ex) {
-                    throw new RuntimeException(ex);
-                }
+                    return null;
+                });
             }
             }
-        }));
+        }, future);
+        assertTrue(future.actionGet());
         assertThat(target.getLocalCheckpoint(), equalTo(2L));
         assertThat(target.getLocalCheckpoint(), equalTo(2L));
         assertThat(target.seqNoStats().getMaxSeqNo(), equalTo(2L));
         assertThat(target.seqNoStats().getMaxSeqNo(), equalTo(2L));
         assertThat(target.seqNoStats().getGlobalCheckpoint(), equalTo(0L));
         assertThat(target.seqNoStats().getGlobalCheckpoint(), equalTo(0L));

+ 2 - 1
server/src/test/java/org/elasticsearch/repositories/RepositoriesServiceTests.java

@@ -205,7 +205,8 @@ public class RepositoriesServiceTests extends ESTestCase {
 
 
         @Override
         @Override
         public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
         public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
-                                 RecoveryState recoveryState) {
+                                 RecoveryState recoveryState, ActionListener<Void> listener) {
+
         }
         }
 
 
         @Override
         @Override

+ 9 - 5
server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java

@@ -117,7 +117,9 @@ public class FsRepositoryTests extends ESTestCase {
                 new UnassignedInfo(UnassignedInfo.Reason.EXISTING_INDEX_RESTORED, ""));
                 new UnassignedInfo(UnassignedInfo.Reason.EXISTING_INDEX_RESTORED, ""));
             routing = ShardRoutingHelper.initialize(routing, localNode.getId(), 0);
             routing = ShardRoutingHelper.initialize(routing, localNode.getId(), 0);
             RecoveryState state = new RecoveryState(routing, localNode, null);
             RecoveryState state = new RecoveryState(routing, localNode, null);
-            runGeneric(threadPool, () -> repository.restoreShard(store, snapshotId, indexId, shardId, state));
+            final PlainActionFuture<Void> futureA = PlainActionFuture.newFuture();
+            runGeneric(threadPool, () -> repository.restoreShard(store, snapshotId, indexId, shardId, state, futureA));
+            futureA.actionGet();
             assertTrue(state.getIndex().recoveredBytes() > 0);
             assertTrue(state.getIndex().recoveredBytes() > 0);
             assertEquals(0, state.getIndex().reusedFileCount());
             assertEquals(0, state.getIndex().reusedFileCount());
             assertEquals(indexCommit.getFileNames().size(), state.getIndex().recoveredFileCount());
             assertEquals(indexCommit.getFileNames().size(), state.getIndex().recoveredFileCount());
@@ -138,14 +140,16 @@ public class FsRepositoryTests extends ESTestCase {
 
 
             // roll back to the first snap and then incrementally restore
             // roll back to the first snap and then incrementally restore
             RecoveryState firstState = new RecoveryState(routing, localNode, null);
             RecoveryState firstState = new RecoveryState(routing, localNode, null);
-            runGeneric(threadPool, () ->
-                repository.restoreShard(store, snapshotId, indexId, shardId, firstState));
+            final PlainActionFuture<Void> futureB =  PlainActionFuture.newFuture();
+            runGeneric(threadPool, () -> repository.restoreShard(store, snapshotId, indexId, shardId, firstState, futureB));
+            futureB.actionGet();
             assertEquals("should reuse everything except of .liv and .si",
             assertEquals("should reuse everything except of .liv and .si",
                 commitFileNames.size()-2, firstState.getIndex().reusedFileCount());
                 commitFileNames.size()-2, firstState.getIndex().reusedFileCount());
 
 
             RecoveryState secondState = new RecoveryState(routing, localNode, null);
             RecoveryState secondState = new RecoveryState(routing, localNode, null);
-            runGeneric(threadPool, () ->
-                repository.restoreShard(store, incSnapshotId, indexId, shardId, secondState));
+            final PlainActionFuture<Void> futureC = PlainActionFuture.newFuture();
+            runGeneric(threadPool, () -> repository.restoreShard(store, incSnapshotId, indexId, shardId, secondState, futureC));
+            futureC.actionGet();
             assertEquals(secondState.getIndex().reusedFileCount(), commitFileNames.size()-2);
             assertEquals(secondState.getIndex().reusedFileCount(), commitFileNames.size()-2);
             assertEquals(secondState.getIndex().recoveredFileCount(), 2);
             assertEquals(secondState.getIndex().recoveredFileCount(), 2);
             List<RecoveryState.File> recoveredFiles =
             List<RecoveryState.File> recoveredFiles =

+ 4 - 1
test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java

@@ -805,11 +805,14 @@ public abstract class IndexShardTestCase extends ESTestCase {
             new RecoverySource.SnapshotRecoverySource(UUIDs.randomBase64UUID(), snapshot, version, index);
             new RecoverySource.SnapshotRecoverySource(UUIDs.randomBase64UUID(), snapshot, version, index);
         final ShardRouting shardRouting = newShardRouting(shardId, node.getId(), true, ShardRoutingState.INITIALIZING, recoverySource);
         final ShardRouting shardRouting = newShardRouting(shardId, node.getId(), true, ShardRoutingState.INITIALIZING, recoverySource);
         shard.markAsRecovering("from snapshot", new RecoveryState(shardRouting, node, null));
         shard.markAsRecovering("from snapshot", new RecoveryState(shardRouting, node, null));
+        final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
         repository.restoreShard(shard.store(),
         repository.restoreShard(shard.store(),
             snapshot.getSnapshotId(),
             snapshot.getSnapshotId(),
             indexId,
             indexId,
             shard.shardId(),
             shard.shardId(),
-            shard.recoveryState());
+            shard.recoveryState(),
+            future);
+        future.actionGet();
     }
     }
 
 
     /**
     /**

+ 124 - 113
x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java

@@ -291,26 +291,29 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit
     }
     }
 
 
     @Override
     @Override
-    public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId, RecoveryState recoveryState) {
-        // TODO: Add timeouts to network calls / the restore process.
-        createEmptyStore(store);
-        ShardId shardId = store.shardId();
-
-        final Map<String, String> ccrMetaData = store.indexSettings().getIndexMetaData().getCustomData(Ccr.CCR_CUSTOM_METADATA_KEY);
-        final String leaderIndexName = ccrMetaData.get(Ccr.CCR_CUSTOM_METADATA_LEADER_INDEX_NAME_KEY);
-        final String leaderUUID = ccrMetaData.get(Ccr.CCR_CUSTOM_METADATA_LEADER_INDEX_UUID_KEY);
-        final Index leaderIndex = new Index(leaderIndexName, leaderUUID);
-        final ShardId leaderShardId = new ShardId(leaderIndex, shardId.getId());
-
-        final Client remoteClient = getRemoteClusterClient();
-
-        final String retentionLeaseId =
+    public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId, RecoveryState recoveryState,
+                             ActionListener<Void> listener) {
+        // TODO: Instead of blocking in the restore logic and completing the listener synchronously, make the logic below async
+        ActionListener.completeWith(listener, () -> {
+            // TODO: Add timeouts to network calls / the restore process.
+            createEmptyStore(store);
+            ShardId shardId = store.shardId();
+
+            final Map<String, String> ccrMetaData = store.indexSettings().getIndexMetaData().getCustomData(Ccr.CCR_CUSTOM_METADATA_KEY);
+            final String leaderIndexName = ccrMetaData.get(Ccr.CCR_CUSTOM_METADATA_LEADER_INDEX_NAME_KEY);
+            final String leaderUUID = ccrMetaData.get(Ccr.CCR_CUSTOM_METADATA_LEADER_INDEX_UUID_KEY);
+            final Index leaderIndex = new Index(leaderIndexName, leaderUUID);
+            final ShardId leaderShardId = new ShardId(leaderIndex, shardId.getId());
+
+            final Client remoteClient = getRemoteClusterClient();
+
+            final String retentionLeaseId =
                 retentionLeaseId(localClusterName, shardId.getIndex(), remoteClusterAlias, leaderIndex);
                 retentionLeaseId(localClusterName, shardId.getIndex(), remoteClusterAlias, leaderIndex);
 
 
-        acquireRetentionLeaseOnLeader(shardId, retentionLeaseId, leaderShardId, remoteClient);
+            acquireRetentionLeaseOnLeader(shardId, retentionLeaseId, leaderShardId, remoteClient);
 
 
-        // schedule renewals to run during the restore
-        final Scheduler.Cancellable renewable = threadPool.scheduleWithFixedDelay(
+            // schedule renewals to run during the restore
+            final Scheduler.Cancellable renewable = threadPool.scheduleWithFixedDelay(
                 () -> {
                 () -> {
                     logger.trace("{} background renewal of retention lease [{}] during restore", shardId, retentionLeaseId);
                     logger.trace("{} background renewal of retention lease [{}] during restore", shardId, retentionLeaseId);
                     final ThreadContext threadContext = threadPool.getThreadContext();
                     final ThreadContext threadContext = threadPool.getThreadContext();
@@ -318,38 +321,40 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit
                         // we have to execute under the system context so that if security is enabled the renewal is authorized
                         // we have to execute under the system context so that if security is enabled the renewal is authorized
                         threadContext.markAsSystemContext();
                         threadContext.markAsSystemContext();
                         CcrRetentionLeases.asyncRenewRetentionLease(
                         CcrRetentionLeases.asyncRenewRetentionLease(
-                                leaderShardId,
-                                retentionLeaseId,
-                                RETAIN_ALL,
-                                remoteClient,
-                                ActionListener.wrap(
-                                        r -> {},
-                                        e -> {
-                                            final Throwable cause = ExceptionsHelper.unwrapCause(e);
-                                            assert cause instanceof ElasticsearchSecurityException == false : cause;
-                                            if (cause instanceof RetentionLeaseInvalidRetainingSeqNoException == false) {
-                                                logger.warn(new ParameterizedMessage(
-                                                    "{} background renewal of retention lease [{}] failed during restore", shardId,
-                                                    retentionLeaseId), cause);
-                                            }
-                                        }));
+                            leaderShardId,
+                            retentionLeaseId,
+                            RETAIN_ALL,
+                            remoteClient,
+                            ActionListener.wrap(
+                                r -> {},
+                                e -> {
+                                    final Throwable cause = ExceptionsHelper.unwrapCause(e);
+                                    assert cause instanceof ElasticsearchSecurityException == false : cause;
+                                    if (cause instanceof RetentionLeaseInvalidRetainingSeqNoException == false) {
+                                        logger.warn(new ParameterizedMessage(
+                                            "{} background renewal of retention lease [{}] failed during restore", shardId,
+                                            retentionLeaseId), cause);
+                                    }
+                                }));
                     }
                     }
                 },
                 },
                 CcrRetentionLeases.RETENTION_LEASE_RENEW_INTERVAL_SETTING.get(store.indexSettings().getNodeSettings()),
                 CcrRetentionLeases.RETENTION_LEASE_RENEW_INTERVAL_SETTING.get(store.indexSettings().getNodeSettings()),
                 Ccr.CCR_THREAD_POOL_NAME);
                 Ccr.CCR_THREAD_POOL_NAME);
 
 
-        // TODO: There should be some local timeout. And if the remote cluster returns an unknown session
-        //  response, we should be able to retry by creating a new session.
-        try (RestoreSession restoreSession = openSession(metadata.name(), remoteClient, leaderShardId, shardId, recoveryState)) {
-            restoreSession.restoreFiles(store);
-            updateMappings(remoteClient, leaderIndex, restoreSession.mappingVersion, client, shardId.getIndex());
-        } catch (Exception e) {
-            throw new IndexShardRestoreFailedException(shardId, "failed to restore snapshot [" + snapshotId + "]", e);
-        } finally {
-            logger.trace("{} canceling background renewal of retention lease [{}] at the end of restore", shardId,
-                retentionLeaseId);
-            renewable.cancel();
-        }
+            // TODO: There should be some local timeout. And if the remote cluster returns an unknown session
+            //  response, we should be able to retry by creating a new session.
+            try (RestoreSession restoreSession = openSession(metadata.name(), remoteClient, leaderShardId, shardId, recoveryState)) {
+                restoreSession.restoreFiles(store);
+                updateMappings(remoteClient, leaderIndex, restoreSession.mappingVersion, client, shardId.getIndex());
+            } catch (Exception e) {
+                throw new IndexShardRestoreFailedException(shardId, "failed to restore snapshot [" + snapshotId + "]", e);
+            } finally {
+                logger.trace("{} canceling background renewal of retention lease [{}] at the end of restore", shardId,
+                    retentionLeaseId);
+                renewable.cancel();
+            }
+            return null;
+        });
     }
     }
 
 
     private void createEmptyStore(Store store) {
     private void createEmptyStore(Store store) {
@@ -465,86 +470,92 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit
                 fileInfos.add(new FileInfo(fileMetaData.name(), fileMetaData, fileSize));
                 fileInfos.add(new FileInfo(fileMetaData.name(), fileMetaData, fileSize));
             }
             }
             SnapshotFiles snapshotFiles = new SnapshotFiles(LATEST, fileInfos);
             SnapshotFiles snapshotFiles = new SnapshotFiles(LATEST, fileInfos);
-            restore(snapshotFiles, store);
+            final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
+            restore(snapshotFiles, store, future);
+            future.actionGet();
         }
         }
 
 
         @Override
         @Override
-        protected void restoreFiles(List<FileInfo> filesToRecover, Store store) {
-            logger.trace("[{}] starting CCR restore of {} files", shardId, filesToRecover);
-            final PlainActionFuture<Void> restoreFilesFuture = new PlainActionFuture<>();
-            final List<StoreFileMetaData> mds = filesToRecover.stream().map(FileInfo::metadata).collect(Collectors.toList());
-            final MultiFileTransfer<FileChunk> multiFileTransfer = new MultiFileTransfer<>(
-                logger, threadPool.getThreadContext(), restoreFilesFuture, ccrSettings.getMaxConcurrentFileChunks(), mds) {
-
-                final MultiFileWriter multiFileWriter = new MultiFileWriter(store, recoveryState.getIndex(), "", logger, () -> {});
-                long offset = 0;
-
-                @Override
-                protected void onNewFile(StoreFileMetaData md) {
-                    offset = 0;
-                }
-
-                @Override
-                protected FileChunk nextChunkRequest(StoreFileMetaData md) {
-                    final int bytesRequested = Math.toIntExact(Math.min(ccrSettings.getChunkSize().getBytes(), md.length() - offset));
-                    offset += bytesRequested;
-                    return new FileChunk(md, bytesRequested, offset == md.length());
-                }
-
-                @Override
-                protected void executeChunkRequest(FileChunk request, ActionListener<Void> listener) {
-                    final ActionListener<GetCcrRestoreFileChunkAction.GetCcrRestoreFileChunkResponse> threadedListener
-                        = new ThreadedActionListener<>(logger, threadPool, ThreadPool.Names.GENERIC, ActionListener.wrap(
+        protected void restoreFiles(List<FileInfo> filesToRecover, Store store, ActionListener<Void> listener) {
+            ActionListener.completeWith(listener, () -> {
+                logger.trace("[{}] starting CCR restore of {} files", shardId, filesToRecover);
+                final PlainActionFuture<Void> restoreFilesFuture = new PlainActionFuture<>();
+                final List<StoreFileMetaData> mds = filesToRecover.stream().map(FileInfo::metadata).collect(Collectors.toList());
+                final MultiFileTransfer<FileChunk> multiFileTransfer = new MultiFileTransfer<>(
+                    logger, threadPool.getThreadContext(), restoreFilesFuture, ccrSettings.getMaxConcurrentFileChunks(), mds) {
+
+                    final MultiFileWriter multiFileWriter = new MultiFileWriter(store, recoveryState.getIndex(), "", logger, () -> {
+                    });
+                    long offset = 0;
+
+                    @Override
+                    protected void onNewFile(StoreFileMetaData md) {
+                        offset = 0;
+                    }
+
+                    @Override
+                    protected FileChunk nextChunkRequest(StoreFileMetaData md) {
+                        final int bytesRequested = Math.toIntExact(Math.min(ccrSettings.getChunkSize().getBytes(), md.length() - offset));
+                        offset += bytesRequested;
+                        return new FileChunk(md, bytesRequested, offset == md.length());
+                    }
+
+                    @Override
+                    protected void executeChunkRequest(FileChunk request, ActionListener<Void> listener) {
+                        final ActionListener<GetCcrRestoreFileChunkAction.GetCcrRestoreFileChunkResponse> threadedListener
+                            = new ThreadedActionListener<>(logger, threadPool, ThreadPool.Names.GENERIC, ActionListener.wrap(
                             r -> {
                             r -> {
                                 writeFileChunk(request.md, r);
                                 writeFileChunk(request.md, r);
                                 listener.onResponse(null);
                                 listener.onResponse(null);
                             }, listener::onFailure), false);
                             }, listener::onFailure), false);
 
 
-                    remoteClient.execute(GetCcrRestoreFileChunkAction.INSTANCE,
-                        new GetCcrRestoreFileChunkRequest(node, sessionUUID, request.md.name(), request.bytesRequested),
-                        ListenerTimeouts.wrapWithTimeout(threadPool, threadedListener, ccrSettings.getRecoveryActionTimeout(),
-                            ThreadPool.Names.GENERIC, GetCcrRestoreFileChunkAction.NAME));
-                }
-
-                private void writeFileChunk(StoreFileMetaData md,
-                                            GetCcrRestoreFileChunkAction.GetCcrRestoreFileChunkResponse r) throws Exception {
-                    final int actualChunkSize = r.getChunk().length();
-                    logger.trace("[{}] [{}] got response for file [{}], offset: {}, length: {}",
-                        shardId, snapshotId, md.name(), r.getOffset(), actualChunkSize);
-                    final long nanosPaused = ccrSettings.getRateLimiter().maybePause(actualChunkSize);
-                    throttleListener.accept(nanosPaused);
-                    multiFileWriter.incRef();
-                    try (Releasable ignored = multiFileWriter::decRef) {
-                        final boolean lastChunk = r.getOffset() + actualChunkSize >= md.length();
-                        multiFileWriter.writeFileChunk(md, r.getOffset(), r.getChunk(), lastChunk);
-                    } catch (Exception e) {
-                        handleError(md, e);
-                        throw e;
+                        remoteClient.execute(GetCcrRestoreFileChunkAction.INSTANCE,
+                            new GetCcrRestoreFileChunkRequest(node, sessionUUID, request.md.name(), request.bytesRequested),
+                            ListenerTimeouts.wrapWithTimeout(threadPool, threadedListener, ccrSettings.getRecoveryActionTimeout(),
+                                ThreadPool.Names.GENERIC, GetCcrRestoreFileChunkAction.NAME));
+                    }
+
+                    private void writeFileChunk(StoreFileMetaData md,
+                        GetCcrRestoreFileChunkAction.GetCcrRestoreFileChunkResponse r) throws Exception {
+                        final int actualChunkSize = r.getChunk().length();
+                        logger.trace("[{}] [{}] got response for file [{}], offset: {}, length: {}",
+                            shardId, snapshotId, md.name(), r.getOffset(), actualChunkSize);
+                        final long nanosPaused = ccrSettings.getRateLimiter().maybePause(actualChunkSize);
+                        throttleListener.accept(nanosPaused);
+                        multiFileWriter.incRef();
+                        try (Releasable ignored = multiFileWriter::decRef) {
+                            final boolean lastChunk = r.getOffset() + actualChunkSize >= md.length();
+                            multiFileWriter.writeFileChunk(md, r.getOffset(), r.getChunk(), lastChunk);
+                        } catch (Exception e) {
+                            handleError(md, e);
+                            throw e;
+                        }
                     }
                     }
-                }
-
-                @Override
-                protected void handleError(StoreFileMetaData md, Exception e) throws Exception {
-                    final IOException corruptIndexException;
-                    if ((corruptIndexException = ExceptionsHelper.unwrapCorruption(e)) != null) {
-                        try {
-                            store.markStoreCorrupted(corruptIndexException);
-                        } catch (IOException ioe) {
-                            logger.warn("store cannot be marked as corrupted", e);
+
+                    @Override
+                    protected void handleError(StoreFileMetaData md, Exception e) throws Exception {
+                        final IOException corruptIndexException;
+                        if ((corruptIndexException = ExceptionsHelper.unwrapCorruption(e)) != null) {
+                            try {
+                                store.markStoreCorrupted(corruptIndexException);
+                            } catch (IOException ioe) {
+                                logger.warn("store cannot be marked as corrupted", e);
+                            }
+                            throw corruptIndexException;
                         }
                         }
-                        throw corruptIndexException;
+                        throw e;
                     }
                     }
-                    throw e;
-                }
-
-                @Override
-                public void close() {
-                    multiFileWriter.close();
-                }
-            };
-            multiFileTransfer.start();
-            restoreFilesFuture.actionGet();
-            logger.trace("[{}] completed CCR restore", shardId);
+
+                    @Override
+                    public void close() {
+                        multiFileWriter.close();
+                    }
+                };
+                multiFileTransfer.start();
+                restoreFilesFuture.actionGet();
+                logger.trace("[{}] completed CCR restore", shardId);
+                return null;
+            });
         }
         }
 
 
         @Override
         @Override

+ 11 - 6
x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java

@@ -451,11 +451,12 @@ public class ShardFollowTaskReplicationTests extends ESIndexLevelReplicationTest
                 ShardRouting routing = ShardRoutingHelper.newWithRestoreSource(primary.routingEntry(),
                 ShardRouting routing = ShardRoutingHelper.newWithRestoreSource(primary.routingEntry(),
                     new RecoverySource.SnapshotRecoverySource(UUIDs.randomBase64UUID(), snapshot, Version.CURRENT, "test"));
                     new RecoverySource.SnapshotRecoverySource(UUIDs.randomBase64UUID(), snapshot, Version.CURRENT, "test"));
                 primary.markAsRecovering("remote recovery from leader", new RecoveryState(routing, localNode, null));
                 primary.markAsRecovering("remote recovery from leader", new RecoveryState(routing, localNode, null));
+                final PlainActionFuture<Boolean> future = PlainActionFuture.newFuture();
                 primary.restoreFromRepository(new RestoreOnlyRepository(index.getName()) {
                 primary.restoreFromRepository(new RestoreOnlyRepository(index.getName()) {
                     @Override
                     @Override
                     public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
                     public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
-                                             RecoveryState recoveryState) {
-                        try {
+                                             RecoveryState recoveryState, ActionListener<Void> listener) {
+                        ActionListener.completeWith(listener, () -> {
                             IndexShard leader = leaderGroup.getPrimary();
                             IndexShard leader = leaderGroup.getPrimary();
                             Lucene.cleanLuceneIndex(primary.store().directory());
                             Lucene.cleanLuceneIndex(primary.store().directory());
                             try (Engine.IndexCommitRef sourceCommit = leader.acquireSafeIndexCommit()) {
                             try (Engine.IndexCommitRef sourceCommit = leader.acquireSafeIndexCommit()) {
@@ -465,11 +466,15 @@ public class ShardFollowTaskReplicationTests extends ESIndexLevelReplicationTest
                                         leader.store().directory(), md.name(), md.name(), IOContext.DEFAULT);
                                         leader.store().directory(), md.name(), md.name(), IOContext.DEFAULT);
                                 }
                                 }
                             }
                             }
-                        } catch (Exception ex) {
-                            throw new AssertionError(ex);
-                        }
+                            return null;
+                        });
                     }
                     }
-                });
+                }, future);
+                try {
+                    future.actionGet();
+                } catch (Exception ex) {
+                    throw new AssertionError(ex);
+                }
             }
             }
         };
         };
     }
     }

+ 9 - 7
x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/index/engine/FollowEngineIndexShardTests.java

@@ -9,6 +9,7 @@ import org.apache.lucene.store.IOContext;
 import org.elasticsearch.Version;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.RecoverySource;
@@ -125,11 +126,12 @@ public class FollowEngineIndexShardTests extends IndexShardTestCase {
 
 
         DiscoveryNode localNode = new DiscoveryNode("foo", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
         DiscoveryNode localNode = new DiscoveryNode("foo", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
         target.markAsRecovering("store", new RecoveryState(routing, localNode, null));
         target.markAsRecovering("store", new RecoveryState(routing, localNode, null));
-        assertTrue(target.restoreFromRepository(new RestoreOnlyRepository("test") {
+        final PlainActionFuture<Boolean> future = PlainActionFuture.newFuture();
+        target.restoreFromRepository(new RestoreOnlyRepository("test") {
             @Override
             @Override
             public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
             public void restoreShard(Store store, SnapshotId snapshotId, IndexId indexId, ShardId snapshotShardId,
-                                     RecoveryState recoveryState) {
-                try {
+                                     RecoveryState recoveryState, ActionListener<Void> listener) {
+                ActionListener.completeWith(listener, () -> {
                     cleanLuceneIndex(targetStore.directory());
                     cleanLuceneIndex(targetStore.directory());
                     for (String file : sourceStore.directory().listAll()) {
                     for (String file : sourceStore.directory().listAll()) {
                         if (file.equals("write.lock") || file.startsWith("extra")) {
                         if (file.equals("write.lock") || file.startsWith("extra")) {
@@ -137,11 +139,11 @@ public class FollowEngineIndexShardTests extends IndexShardTestCase {
                         }
                         }
                         targetStore.directory().copyFrom(sourceStore.directory(), file, file, IOContext.DEFAULT);
                         targetStore.directory().copyFrom(sourceStore.directory(), file, file, IOContext.DEFAULT);
                     }
                     }
-                } catch (Exception ex) {
-                    throw new RuntimeException(ex);
-                }
+                    return null;
+                });
             }
             }
-        }));
+        }, future);
+        assertTrue(future.actionGet());
         assertThat(target.getLocalCheckpoint(), equalTo(0L));
         assertThat(target.getLocalCheckpoint(), equalTo(0L));
         assertThat(target.seqNoStats().getMaxSeqNo(), equalTo(2L));
         assertThat(target.seqNoStats().getMaxSeqNo(), equalTo(2L));
         assertThat(target.seqNoStats().getGlobalCheckpoint(), equalTo(0L));
         assertThat(target.seqNoStats().getGlobalCheckpoint(), equalTo(0L));

+ 5 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/snapshots/SourceOnlySnapshotShardTests.java

@@ -233,8 +233,11 @@ public class SourceOnlySnapshotShardTests extends IndexShardTestCase {
         restoredShard.mapperService().merge(shard.indexSettings().getIndexMetaData(), MapperService.MergeReason.MAPPING_RECOVERY);
         restoredShard.mapperService().merge(shard.indexSettings().getIndexMetaData(), MapperService.MergeReason.MAPPING_RECOVERY);
         DiscoveryNode discoveryNode = new DiscoveryNode("node_g", buildNewFakeTransportAddress(), Version.CURRENT);
         DiscoveryNode discoveryNode = new DiscoveryNode("node_g", buildNewFakeTransportAddress(), Version.CURRENT);
         restoredShard.markAsRecovering("test from snap", new RecoveryState(restoredShard.routingEntry(), discoveryNode, null));
         restoredShard.markAsRecovering("test from snap", new RecoveryState(restoredShard.routingEntry(), discoveryNode, null));
-        runAsSnapshot(shard.getThreadPool(), () ->
-            assertTrue(restoredShard.restoreFromRepository(repository)));
+        runAsSnapshot(shard.getThreadPool(), () -> {
+            final PlainActionFuture<Boolean> future = PlainActionFuture.newFuture();
+            restoredShard.restoreFromRepository(repository, future);
+            assertTrue(future.actionGet());
+        });
         assertEquals(restoredShard.recoveryState().getStage(), RecoveryState.Stage.DONE);
         assertEquals(restoredShard.recoveryState().getStage(), RecoveryState.Stage.DONE);
         assertEquals(restoredShard.recoveryState().getTranslog().recoveredOperations(), 0);
         assertEquals(restoredShard.recoveryState().getTranslog().recoveredOperations(), 0);
         assertEquals(IndexShardState.POST_RECOVERY, restoredShard.state());
         assertEquals(IndexShardState.POST_RECOVERY, restoredShard.state());