
More Efficient Ordering of Shard Upload Execution (#42791)

* Change the upload order of snapshots to work file by file in parallel on the snapshot pool instead of merely shard-by-shard
* Inspired by #39657
Armin Braun 6 years ago
parent
commit
4cf5ffac34
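
The essence of the change is to fan each shard's file uploads out to the snapshot thread pool and aggregate their completions into a single per-shard result. A minimal sketch of that pattern, using the `GroupedActionListener` constructor as it appears in this diff (the class name and `uploadFile` helper are hypothetical):

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.GroupedActionListener;

import java.util.List;
import java.util.concurrent.Executor;

// Hypothetical sketch of the fan-out pattern this commit introduces: one task per
// file on the snapshot pool, aggregated into a single per-shard completion.
class ParallelUploadSketch {

    void uploadAll(List<String> files, Executor snapshotExecutor, ActionListener<Void> shardListener) {
        if (files.isEmpty()) {
            shardListener.onResponse(null); // nothing to upload, complete immediately
            return;
        }
        // Fires once onResponse has been invoked files.size() times, or on the first failure.
        final GroupedActionListener<Void> filesListener = new GroupedActionListener<>(
            ActionListener.wrap(ignored -> shardListener.onResponse(null), shardListener::onFailure),
            files.size());
        for (String file : files) {
            snapshotExecutor.execute(() -> {
                try {
                    uploadFile(file); // hypothetical per-file blob upload
                    filesListener.onResponse(null);
                } catch (Exception e) {
                    filesListener.onFailure(e);
                }
            });
        }
    }

    private void uploadFile(String file) {
        // placeholder: stream the file to the repository's blob container
    }
}
```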

+ 32 - 0
server/src/main/java/org/elasticsearch/action/ActionListener.java

@@ -22,6 +22,7 @@ package org.elasticsearch.action;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.common.CheckedConsumer;
 import org.elasticsearch.common.CheckedFunction;
+import org.elasticsearch.common.CheckedRunnable;
 import org.elasticsearch.common.CheckedSupplier;
 
 import java.util.ArrayList;
@@ -226,6 +227,37 @@ public interface ActionListener<Response> {
         };
     }
 
+    /**
+     * Wraps a given listener and returns a new listener which executes the provided {@code runBefore}
+     * callback before the listener is notified via either {@code #onResponse} or {@code #onFailure}.
+     * If the callback throws an exception then it will be passed to the listener's {@code #onFailure} and its {@code #onResponse} will
+     * not be executed.
+     */
+    static <Response> ActionListener<Response> runBefore(ActionListener<Response> delegate, CheckedRunnable<?> runBefore) {
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(Response response) {
+                try {
+                    runBefore.run();
+                } catch (Exception ex) {
+                    delegate.onFailure(ex);
+                    return;
+                }
+                delegate.onResponse(response);
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                try {
+                    runBefore.run();
+                } catch (Exception ex) {
+                    e.addSuppressed(ex);
+                }
+                delegate.onFailure(e);
+            }
+        };
+    }
+
     /**
      * Wraps a given listener and returns a new listener which makes sure {@link #onResponse(Object)}
      * and {@link #onFailure(Exception)} of the provided listener will be called at most once.

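The new helper is used later in this commit as `ActionListener.runBefore(listener, snapshotRef::close)` to release a resource before the caller is notified. A minimal usage sketch, assuming only that the wrapped resource is a `Closeable`:

```java
import org.elasticsearch.action.ActionListener;

import java.io.Closeable;

class RunBeforeSketch {
    // Close `resource` before the caller's listener observes success or failure;
    // if closing throws on the success path, the listener sees onFailure instead.
    static <T> ActionListener<T> closingBefore(ActionListener<T> listener, Closeable resource) {
        return ActionListener.runBefore(listener, resource::close);
    }
}
```
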
+ 2 - 4
server/src/main/java/org/elasticsearch/repositories/FilterRepository.java

@@ -121,13 +121,11 @@ public class FilterRepository implements Repository {
         return in.isReadOnly();
     }
 
-
     @Override
     public void snapshotShard(Store store, MapperService mapperService, SnapshotId snapshotId, IndexId indexId,
-                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus) {
-        in.snapshotShard(store, mapperService, snapshotId, indexId, snapshotIndexCommit, snapshotStatus);
+                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener) {
+        in.snapshotShard(store, mapperService, snapshotId, indexId, snapshotIndexCommit, snapshotStatus, listener);
     }
-
     @Override
     public void restoreShard(Store store, SnapshotId snapshotId,
                              Version version, IndexId indexId, ShardId snapshotShardId, RecoveryState recoveryState) {

+ 3 - 2
server/src/main/java/org/elasticsearch/repositories/Repository.java

@@ -50,7 +50,7 @@ import java.util.function.Function;
  * <ul>
  * <li>Master calls {@link #initializeSnapshot(SnapshotId, List, org.elasticsearch.cluster.metadata.MetaData)}
  * with list of indices that will be included into the snapshot</li>
- * <li>Data nodes call {@link Repository#snapshotShard(Store, MapperService, SnapshotId, IndexId, IndexCommit, IndexShardSnapshotStatus)}
+ * <li>Data nodes call {@link Repository#snapshotShard}
  * for each shard</li>
  * <li>When all shard calls return master calls {@link #finalizeSnapshot} with possible list of failures</li>
  * </ul>
@@ -204,9 +204,10 @@ public interface Repository extends LifecycleComponent {
      * @param indexId             id for the index being snapshotted
      * @param snapshotIndexCommit commit point
      * @param snapshotStatus      snapshot status
+     * @param listener            listener invoked on completion
      */
     void snapshotShard(Store store, MapperService mapperService, SnapshotId snapshotId, IndexId indexId, IndexCommit snapshotIndexCommit,
-                       IndexShardSnapshotStatus snapshotStatus);
+                       IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener);
 
     /**
      * Restores snapshot of the shard.

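Call sites that need the old blocking behaviour can hand in a `PlainActionFuture` and wait on it, which is exactly what the test changes below do. A minimal sketch of that bridge (the method name is hypothetical):

```java
import org.apache.lucene.index.IndexCommit;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus;
import org.elasticsearch.index.store.Store;
import org.elasticsearch.repositories.IndexId;
import org.elasticsearch.repositories.Repository;
import org.elasticsearch.snapshots.SnapshotId;

class BlockingSnapshotSketch {
    // Run the now-asynchronous snapshotShard(...) and block until it completes,
    // rethrowing any failure that was delivered to the listener.
    static void snapshotShardBlocking(Repository repository, Store store, MapperService mapperService,
                                      SnapshotId snapshotId, IndexId indexId, IndexCommit commit,
                                      IndexShardSnapshotStatus status) {
        final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
        repository.snapshotShard(store, mapperService, snapshotId, indexId, commit, status, future);
        future.actionGet();
    }
}
```
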
+ 127 - 103
server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java

@@ -32,6 +32,7 @@ import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRunnable;
+import org.elasticsearch.action.StepListener;
 import org.elasticsearch.action.support.GroupedActionListener;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.MetaData;
@@ -109,6 +110,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.Executor;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.index.snapshots.blobstore.BlobStoreIndexShardSnapshot.FileInfo.canonicalName;
@@ -883,9 +885,15 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
 
     @Override
     public void snapshotShard(Store store, MapperService mapperService, SnapshotId snapshotId, IndexId indexId,
-                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus) {
+                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener) {
         final ShardId shardId = store.shardId();
         final long startTime = threadPool.absoluteTimeInMillis();
+        final StepListener<Void> snapshotDoneListener = new StepListener<>();
+        snapshotDoneListener.whenComplete(listener::onResponse, e -> {
+            snapshotStatus.moveToFailed(threadPool.absoluteTimeInMillis(), ExceptionsHelper.detailedMessage(e));
+            listener.onFailure(e instanceof IndexShardSnapshotFailedException ? (IndexShardSnapshotFailedException) e
+                : new IndexShardSnapshotFailedException(store.shardId(), e));
+        });
         try {
             logger.debug("[{}] [{}] snapshot to [{}] ...", shardId, snapshotId, metadata.name());
 
@@ -907,132 +915,145 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
             }
 
             final List<BlobStoreIndexShardSnapshot.FileInfo> indexCommitPointFiles = new ArrayList<>();
+            ArrayList<BlobStoreIndexShardSnapshot.FileInfo> filesToSnapshot = new ArrayList<>();
             store.incRef();
+            final Collection<String> fileNames;
+            final Store.MetadataSnapshot metadataFromStore;
             try {
-                ArrayList<BlobStoreIndexShardSnapshot.FileInfo> filesToSnapshot = new ArrayList<>();
-                final Store.MetadataSnapshot metadata;
                 // TODO apparently we don't use the MetadataSnapshot#.recoveryDiff(...) here but we should
-                final Collection<String> fileNames;
                 try {
                     logger.trace(
                         "[{}] [{}] Loading store metadata using index commit [{}]", shardId, snapshotId, snapshotIndexCommit);
-                    metadata = store.getMetadata(snapshotIndexCommit);
+                    metadataFromStore = store.getMetadata(snapshotIndexCommit);
                     fileNames = snapshotIndexCommit.getFileNames();
                 } catch (IOException e) {
                     throw new IndexShardSnapshotFailedException(shardId, "Failed to get store file metadata", e);
                 }
-                int indexIncrementalFileCount = 0;
-                int indexTotalNumberOfFiles = 0;
-                long indexIncrementalSize = 0;
-                long indexTotalFileCount = 0;
-                for (String fileName : fileNames) {
-                    if (snapshotStatus.isAborted()) {
-                        logger.debug("[{}] [{}] Aborted on the file [{}], exiting", shardId, snapshotId, fileName);
-                        throw new IndexShardSnapshotFailedException(shardId, "Aborted");
-                    }
+            } finally {
+                store.decRef();
+            }
+            int indexIncrementalFileCount = 0;
+            int indexTotalNumberOfFiles = 0;
+            long indexIncrementalSize = 0;
+            long indexTotalFileCount = 0;
+            for (String fileName : fileNames) {
+                if (snapshotStatus.isAborted()) {
+                    logger.debug("[{}] [{}] Aborted on the file [{}], exiting", shardId, snapshotId, fileName);
+                    throw new IndexShardSnapshotFailedException(shardId, "Aborted");
+                }
 
-                    logger.trace("[{}] [{}] Processing [{}]", shardId, snapshotId, fileName);
-                    final StoreFileMetaData md = metadata.get(fileName);
-                    BlobStoreIndexShardSnapshot.FileInfo existingFileInfo = null;
-                    List<BlobStoreIndexShardSnapshot.FileInfo> filesInfo = snapshots.findPhysicalIndexFiles(fileName);
-                    if (filesInfo != null) {
-                        for (BlobStoreIndexShardSnapshot.FileInfo fileInfo : filesInfo) {
-                            if (fileInfo.isSame(md)) {
-                                // a commit point file with the same name, size and checksum was already copied to repository
-                                // we will reuse it for this snapshot
-                                existingFileInfo = fileInfo;
-                                break;
-                            }
+                logger.trace("[{}] [{}] Processing [{}]", shardId, snapshotId, fileName);
+                final StoreFileMetaData md = metadataFromStore.get(fileName);
+                BlobStoreIndexShardSnapshot.FileInfo existingFileInfo = null;
+                List<BlobStoreIndexShardSnapshot.FileInfo> filesInfo = snapshots.findPhysicalIndexFiles(fileName);
+                if (filesInfo != null) {
+                    for (BlobStoreIndexShardSnapshot.FileInfo fileInfo : filesInfo) {
+                        if (fileInfo.isSame(md)) {
+                            // a commit point file with the same name, size and checksum was already copied to repository
+                            // we will reuse it for this snapshot
+                            existingFileInfo = fileInfo;
+                            break;
                         }
                     }
-
-                    indexTotalFileCount += md.length();
-                    indexTotalNumberOfFiles++;
-
-                    if (existingFileInfo == null) {
-                        indexIncrementalFileCount++;
-                        indexIncrementalSize += md.length();
-                        // create a new FileInfo
-                        BlobStoreIndexShardSnapshot.FileInfo snapshotFileInfo =
-                            new BlobStoreIndexShardSnapshot.FileInfo(DATA_BLOB_PREFIX + UUIDs.randomBase64UUID(), md, chunkSize());
-                        indexCommitPointFiles.add(snapshotFileInfo);
-                        filesToSnapshot.add(snapshotFileInfo);
-                    } else {
-                        indexCommitPointFiles.add(existingFileInfo);
-                    }
                 }
 
-                snapshotStatus.moveToStarted(startTime, indexIncrementalFileCount,
-                    indexTotalNumberOfFiles, indexIncrementalSize, indexTotalFileCount);
-
-                for (BlobStoreIndexShardSnapshot.FileInfo snapshotFileInfo : filesToSnapshot) {
-                    try {
-                        snapshotFile(snapshotFileInfo, indexId, shardId, snapshotId, snapshotStatus, store);
-                    } catch (IOException e) {
-                        throw new IndexShardSnapshotFailedException(shardId, "Failed to perform snapshot (index files)", e);
-                    }
+                indexTotalFileCount += md.length();
+                indexTotalNumberOfFiles++;
+
+                if (existingFileInfo == null) {
+                    indexIncrementalFileCount++;
+                    indexIncrementalSize += md.length();
+                    // create a new FileInfo
+                    BlobStoreIndexShardSnapshot.FileInfo snapshotFileInfo =
+                        new BlobStoreIndexShardSnapshot.FileInfo(DATA_BLOB_PREFIX + UUIDs.randomBase64UUID(), md, chunkSize());
+                    indexCommitPointFiles.add(snapshotFileInfo);
+                    filesToSnapshot.add(snapshotFileInfo);
+                } else {
+                    indexCommitPointFiles.add(existingFileInfo);
                 }
-            } finally {
-                store.decRef();
             }
 
-            final IndexShardSnapshotStatus.Copy lastSnapshotStatus = snapshotStatus.moveToFinalize(snapshotIndexCommit.getGeneration());
+            snapshotStatus.moveToStarted(startTime, indexIncrementalFileCount,
+                indexTotalNumberOfFiles, indexIncrementalSize, indexTotalFileCount);
 
-            // now create and write the commit point
-            final BlobStoreIndexShardSnapshot snapshot = new BlobStoreIndexShardSnapshot(snapshotId.getName(),
-                lastSnapshotStatus.getIndexVersion(),
-                indexCommitPointFiles,
-                lastSnapshotStatus.getStartTime(),
-                threadPool.absoluteTimeInMillis() - lastSnapshotStatus.getStartTime(),
-                lastSnapshotStatus.getIncrementalFileCount(),
-                lastSnapshotStatus.getIncrementalSize()
-            );
+            assert indexIncrementalFileCount == filesToSnapshot.size();
 
-            logger.trace("[{}] [{}] writing shard snapshot file", shardId, snapshotId);
-            try {
-                indexShardSnapshotFormat.write(snapshot, shardContainer, snapshotId.getUUID());
-            } catch (IOException e) {
-                throw new IndexShardSnapshotFailedException(shardId, "Failed to write commit point", e);
-            }
+            final StepListener<Collection<Void>> allFilesUploadedListener = new StepListener<>();
+            allFilesUploadedListener.whenComplete(v -> {
+                final IndexShardSnapshotStatus.Copy lastSnapshotStatus =
+                    snapshotStatus.moveToFinalize(snapshotIndexCommit.getGeneration());
 
-            // delete all files that are not referenced by any commit point
-            // build a new BlobStoreIndexShardSnapshot, that includes this one and all the saved ones
-            List<SnapshotFiles> newSnapshotsList = new ArrayList<>();
-            newSnapshotsList.add(new SnapshotFiles(snapshot.snapshot(), snapshot.indexFiles()));
-            for (SnapshotFiles point : snapshots) {
-                newSnapshotsList.add(point);
-            }
-            final String indexGeneration = Long.toString(fileListGeneration + 1);
-            final List<String> blobsToDelete;
-            try {
-                final BlobStoreIndexShardSnapshots updatedSnapshots = new BlobStoreIndexShardSnapshots(newSnapshotsList);
-                indexShardSnapshotsFormat.writeAtomic(updatedSnapshots, shardContainer, indexGeneration);
-                // Delete all previous index-N blobs
-                blobsToDelete =
-                    blobs.keySet().stream().filter(blob -> blob.startsWith(SNAPSHOT_INDEX_PREFIX)).collect(Collectors.toList());
-                assert blobsToDelete.stream().mapToLong(b -> Long.parseLong(b.replaceFirst(SNAPSHOT_INDEX_PREFIX, "")))
-                    .max().orElse(-1L) < Long.parseLong(indexGeneration)
-                    : "Tried to delete an index-N blob newer than the current generation [" + indexGeneration + "] when deleting index-N" +
-                    " blobs " + blobsToDelete;
-            } catch (IOException e) {
-                throw new IndexShardSnapshotFailedException(shardId,
-                    "Failed to finalize snapshot creation [" + snapshotId + "] with shard index ["
-                        + indexShardSnapshotsFormat.blobName(indexGeneration) + "]", e);
+                // now create and write the commit point
+                final BlobStoreIndexShardSnapshot snapshot = new BlobStoreIndexShardSnapshot(snapshotId.getName(),
+                    lastSnapshotStatus.getIndexVersion(),
+                    indexCommitPointFiles,
+                    lastSnapshotStatus.getStartTime(),
+                    threadPool.absoluteTimeInMillis() - lastSnapshotStatus.getStartTime(),
+                    lastSnapshotStatus.getIncrementalFileCount(),
+                    lastSnapshotStatus.getIncrementalSize()
+                );
+
+                logger.trace("[{}] [{}] writing shard snapshot file", shardId, snapshotId);
+                try {
+                    indexShardSnapshotFormat.write(snapshot, shardContainer, snapshotId.getUUID());
+                } catch (IOException e) {
+                    throw new IndexShardSnapshotFailedException(shardId, "Failed to write commit point", e);
+                }
+                // delete all files that are not referenced by any commit point
+                // build a new BlobStoreIndexShardSnapshot, that includes this one and all the saved ones
+                List<SnapshotFiles> newSnapshotsList = new ArrayList<>();
+                newSnapshotsList.add(new SnapshotFiles(snapshot.snapshot(), snapshot.indexFiles()));
+                for (SnapshotFiles point : snapshots) {
+                    newSnapshotsList.add(point);
+                }
+                final String indexGeneration = Long.toString(fileListGeneration + 1);
+                final List<String> blobsToDelete;
+                try {
+                    final BlobStoreIndexShardSnapshots updatedSnapshots = new BlobStoreIndexShardSnapshots(newSnapshotsList);
+                    indexShardSnapshotsFormat.writeAtomic(updatedSnapshots, shardContainer, indexGeneration);
+                    // Delete all previous index-N blobs
+                    blobsToDelete =
+                        blobs.keySet().stream().filter(blob -> blob.startsWith(SNAPSHOT_INDEX_PREFIX)).collect(Collectors.toList());
+                    assert blobsToDelete.stream().mapToLong(b -> Long.parseLong(b.replaceFirst(SNAPSHOT_INDEX_PREFIX, "")))
+                        .max().orElse(-1L) < Long.parseLong(indexGeneration)
+                        : "Tried to delete an index-N blob newer than the current generation [" + indexGeneration
+                        + "] when deleting index-N blobs " + blobsToDelete;
+                } catch (IOException e) {
+                    throw new IndexShardSnapshotFailedException(shardId,
+                        "Failed to finalize snapshot creation [" + snapshotId + "] with shard index ["
+                            + indexShardSnapshotsFormat.blobName(indexGeneration) + "]", e);
+                }
+                try {
+                    shardContainer.deleteBlobsIgnoringIfNotExists(blobsToDelete);
+                } catch (IOException e) {
+                    logger.warn(() -> new ParameterizedMessage("[{}][{}] failed to delete old index-N blobs during finalization",
+                        snapshotId, shardId), e);
+                }
+                snapshotStatus.moveToDone(threadPool.absoluteTimeInMillis());
+                snapshotDoneListener.onResponse(null);
+            }, snapshotDoneListener::onFailure);
+            if (indexIncrementalFileCount == 0) {
+                allFilesUploadedListener.onResponse(Collections.emptyList());
+                return;
             }
-            try {
-                shardContainer.deleteBlobsIgnoringIfNotExists(blobsToDelete);
-            } catch (IOException e) {
-                logger.warn(() -> new ParameterizedMessage("[{}][{}] failed to delete old index-N blobs during finalization",
-                    snapshotId, shardId), e);
+            final GroupedActionListener<Void> filesListener =
+                new GroupedActionListener<>(allFilesUploadedListener, indexIncrementalFileCount);
+            final Executor executor = threadPool.executor(ThreadPool.Names.SNAPSHOT);
+            for (BlobStoreIndexShardSnapshot.FileInfo snapshotFileInfo : filesToSnapshot) {
+                executor.execute(new ActionRunnable<>(filesListener) {
+                    @Override
+                    protected void doRun() {
+                        try {
+                            snapshotFile(snapshotFileInfo, indexId, shardId, snapshotId, snapshotStatus, store);
+                            filesListener.onResponse(null);
+                        } catch (IOException e) {
+                            throw new IndexShardSnapshotFailedException(shardId, "Failed to perform snapshot (index files)", e);
+                        }
+                    }
+                });
             }
-            snapshotStatus.moveToDone(threadPool.absoluteTimeInMillis());
         } catch (Exception e) {
-            snapshotStatus.moveToFailed(threadPool.absoluteTimeInMillis(), ExceptionsHelper.detailedMessage(e));
-            if (e instanceof IndexShardSnapshotFailedException) {
-                throw (IndexShardSnapshotFailedException) e;
-            } else {
-                throw new IndexShardSnapshotFailedException(store.shardId(), e);
-            }
+            snapshotDoneListener.onFailure(e);
         }
     }
 
@@ -1219,6 +1240,7 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
                               IndexShardSnapshotStatus snapshotStatus, Store store) throws IOException {
         final BlobContainer shardContainer = shardContainer(indexId, shardId);
         final String file = fileInfo.physicalName();
+        store.incRef();
         try (IndexInput indexInput = store.openVerifyingInput(file, IOContext.READONCE, fileInfo.metadata())) {
             for (int i = 0; i < fileInfo.numberOfParts(); i++) {
                 final long partBytes = fileInfo.partBytes(i);
@@ -1258,6 +1280,8 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
             failStoreIfCorrupted(store, t);
             snapshotStatus.addProcessedFile(0);
             throw t;
+        } finally {
+            store.decRef();
         }
     }
 

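The rewritten method chains two `StepListener` steps: one that fires when every file upload has completed, and one that signals the overall shard snapshot. A minimal sketch of that chaining, with the actual work replaced by hypothetical stubs:

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.StepListener;

import java.util.Collection;

class StepChainingSketch {

    void run(ActionListener<Void> done) {
        final StepListener<Void> snapshotDone = new StepListener<>();
        // Final step: forward success, or report the failure.
        snapshotDone.whenComplete(done::onResponse, done::onFailure);

        final StepListener<Collection<Void>> allFilesUploaded = new StepListener<>();
        // Intermediate step: once every upload has reported in, write the commit point.
        allFilesUploaded.whenComplete(uploads -> {
            writeCommitPoint();         // hypothetical finalization work
            snapshotDone.onResponse(null);
        }, snapshotDone::onFailure);

        startUploads(allFilesUploaded); // hypothetical fan-out, see BlobStoreRepository above
    }

    private void writeCommitPoint() {
        // placeholder: write the shard-level snapshot metadata
    }

    private void startUploads(ActionListener<Collection<Void>> listener) {
        listener.onResponse(java.util.Collections.emptyList()); // placeholder
    }
}
```
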
+ 52 - 69
server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java

@@ -23,7 +23,6 @@ import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequestValidationException;
@@ -53,9 +52,8 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
-import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.index.engine.Engine;
-import org.elasticsearch.index.engine.SnapshotFailedEngineException;
 import org.elasticsearch.index.shard.IndexEventListener;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.IndexShardState;
@@ -80,7 +78,6 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.Executor;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -298,46 +295,33 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
     }
 
     private void startNewShards(SnapshotsInProgress.Entry entry, Map<ShardId, IndexShardSnapshotStatus> startedShards) {
-        final Snapshot snapshot = entry.snapshot();
-        final Map<String, IndexId> indicesMap = entry.indices().stream().collect(Collectors.toMap(IndexId::getName, Function.identity()));
-        final Executor executor = threadPool.executor(ThreadPool.Names.SNAPSHOT);
-        for (final Map.Entry<ShardId, IndexShardSnapshotStatus> shardEntry : startedShards.entrySet()) {
-            final ShardId shardId = shardEntry.getKey();
-            final IndexId indexId = indicesMap.get(shardId.getIndexName());
-            assert indexId != null;
-            executor.execute(new AbstractRunnable() {
-
-                private final SetOnce<Exception> failure = new SetOnce<>();
-
-                @Override
-                public void doRun() {
-                    final IndexShard indexShard =
-                        indicesService.indexServiceSafe(shardId.getIndex()).getShardOrNull(shardId.id());
-                    snapshot(indexShard, snapshot, indexId, shardEntry.getValue());
-                }
-
-                @Override
-                public void onFailure(Exception e) {
-                    logger.warn(() -> new ParameterizedMessage("[{}][{}] failed to snapshot shard", shardId, snapshot), e);
-                    failure.set(e);
-                }
-
-                @Override
-                public void onRejection(Exception e) {
-                    failure.set(e);
-                }
-
-                @Override
-                public void onAfter() {
-                    final Exception exception = failure.get();
-                    if (exception != null) {
-                        notifyFailedSnapshotShard(snapshot, shardId, ExceptionsHelper.detailedMessage(exception));
-                    } else {
+        threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(() -> {
+            final Snapshot snapshot = entry.snapshot();
+            final Map<String, IndexId> indicesMap =
+                entry.indices().stream().collect(Collectors.toMap(IndexId::getName, Function.identity()));
+            for (final Map.Entry<ShardId, IndexShardSnapshotStatus> shardEntry : startedShards.entrySet()) {
+                final ShardId shardId = shardEntry.getKey();
+                final IndexShardSnapshotStatus snapshotStatus = shardEntry.getValue();
+                final IndexId indexId = indicesMap.get(shardId.getIndexName());
+                assert indexId != null;
+                snapshot(shardId, snapshot, indexId, snapshotStatus, new ActionListener<>() {
+                    @Override
+                    public void onResponse(final Void aVoid) {
+                        if (logger.isDebugEnabled()) {
+                            final IndexShardSnapshotStatus.Copy lastSnapshotStatus = snapshotStatus.asCopy();
+                            logger.debug("snapshot ({}) completed to {} with {}", snapshot, snapshot.getRepository(), lastSnapshotStatus);
+                        }
                         notifySuccessfulSnapshotShard(snapshot, shardId);
                     }
-                }
-            });
-        }
+
+                    @Override
+                    public void onFailure(Exception e) {
+                        logger.warn(() -> new ParameterizedMessage("[{}][{}] failed to snapshot shard", shardId, snapshot), e);
+                        notifyFailedSnapshotShard(snapshot, shardId, ExceptionsHelper.detailedMessage(e));
+                    }
+                });
+            }
+        });
     }
 
     /**
@@ -346,38 +330,37 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
      * @param snapshot       snapshot
      * @param snapshotStatus snapshot status
      */
-    private void snapshot(final IndexShard indexShard, final Snapshot snapshot, final IndexId indexId,
-                          final IndexShardSnapshotStatus snapshotStatus) {
-        final ShardId shardId = indexShard.shardId();
-        if (indexShard.routingEntry().primary() == false) {
-            throw new IndexShardSnapshotFailedException(shardId, "snapshot should be performed only on primary");
-        }
-        if (indexShard.routingEntry().relocating()) {
-            // do not snapshot when in the process of relocation of primaries so we won't get conflicts
-            throw new IndexShardSnapshotFailedException(shardId, "cannot snapshot while relocating");
-        }
+    private void snapshot(final ShardId shardId, final Snapshot snapshot, final IndexId indexId,
+                          final IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener) {
+        try {
+            final IndexShard indexShard = indicesService.indexServiceSafe(shardId.getIndex()).getShardOrNull(shardId.id());
+            if (indexShard.routingEntry().primary() == false) {
+                throw new IndexShardSnapshotFailedException(shardId, "snapshot should be performed only on primary");
+            }
+            if (indexShard.routingEntry().relocating()) {
+                // do not snapshot when in the process of relocation of primaries so we won't get conflicts
+                throw new IndexShardSnapshotFailedException(shardId, "cannot snapshot while relocating");
+            }
 
-        final IndexShardState indexShardState = indexShard.state();
-        if (indexShardState == IndexShardState.CREATED || indexShardState == IndexShardState.RECOVERING) {
-            // shard has just been created, or still recovering
-            throw new IndexShardSnapshotFailedException(shardId, "shard didn't fully recover yet");
-        }
+            final IndexShardState indexShardState = indexShard.state();
+            if (indexShardState == IndexShardState.CREATED || indexShardState == IndexShardState.RECOVERING) {
+                // shard has just been created, or still recovering
+                throw new IndexShardSnapshotFailedException(shardId, "shard didn't fully recover yet");
+            }
 
-        final Repository repository = repositoriesService.repository(snapshot.getRepository());
-        try {
-            // we flush first to make sure we get the latest writes snapshotted
-            try (Engine.IndexCommitRef snapshotRef = indexShard.acquireLastIndexCommit(true)) {
+            final Repository repository = repositoriesService.repository(snapshot.getRepository());
+            Engine.IndexCommitRef snapshotRef = null;
+            try {
+                // we flush first to make sure we get the latest writes snapshotted
+                snapshotRef = indexShard.acquireLastIndexCommit(true);
                 repository.snapshotShard(indexShard.store(), indexShard.mapperService(), snapshot.getSnapshotId(), indexId,
-                    snapshotRef.getIndexCommit(), snapshotStatus);
-                if (logger.isDebugEnabled()) {
-                    final IndexShardSnapshotStatus.Copy lastSnapshotStatus = snapshotStatus.asCopy();
-                    logger.debug("snapshot ({}) completed to {} with {}", snapshot, repository, lastSnapshotStatus);
-                }
+                    snapshotRef.getIndexCommit(), snapshotStatus, ActionListener.runBefore(listener, snapshotRef::close));
+            } catch (Exception e) {
+                IOUtils.close(snapshotRef);
+                throw e;
             }
-        } catch (SnapshotFailedEngineException | IndexShardSnapshotFailedException e) {
-            throw e;
         } catch (Exception e) {
-            throw new IndexShardSnapshotFailedException(shardId, "Failed to snapshot", e);
+            listener.onFailure(e);
         }
     }
 

+ 17 - 0
server/src/test/java/org/elasticsearch/action/ActionListenerTests.java

@@ -171,6 +171,23 @@ public class ActionListenerTests extends ESTestCase {
         }
     }
 
+    public void testRunBefore() {
+        {
+            AtomicBoolean afterSuccess = new AtomicBoolean();
+            ActionListener<Object> listener =
+                ActionListener.runBefore(ActionListener.wrap(r -> {}, e -> {}), () -> afterSuccess.set(true));
+            listener.onResponse(null);
+            assertThat(afterSuccess.get(), equalTo(true));
+        }
+        {
+            AtomicBoolean afterFailure = new AtomicBoolean();
+            ActionListener<Object> listener =
+                ActionListener.runBefore(ActionListener.wrap(r -> {}, e -> {}), () -> afterFailure.set(true));
+            listener.onFailure(null);
+            assertThat(afterFailure.get(), equalTo(true));
+        }
+    }
+
     public void testNotifyOnce() {
         AtomicInteger onResponseTimes = new AtomicInteger();
         AtomicInteger onFailureTimes = new AtomicInteger();

+ 1 - 1
server/src/test/java/org/elasticsearch/repositories/RepositoriesServiceTests.java

@@ -202,7 +202,7 @@ public class RepositoriesServiceTests extends ESTestCase {
 
         @Override
         public void snapshotShard(Store store, MapperService mapperService, SnapshotId snapshotId, IndexId indexId, IndexCommit
-            snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus) {
+            snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener) {
 
         }
 

+ 8 - 2
server/src/test/java/org/elasticsearch/repositories/fs/FsRepositoryTests.java

@@ -35,6 +35,7 @@ import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.IOSupplier;
 import org.apache.lucene.util.TestUtil;
 import org.elasticsearch.Version;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.RepositoryMetaData;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -99,10 +100,12 @@ public class FsRepositoryTests extends ESTestCase {
             IndexId indexId = new IndexId(idxSettings.getIndex().getName(), idxSettings.getUUID());
 
             IndexCommit indexCommit = Lucene.getIndexCommit(Lucene.readSegmentInfos(store.directory()), store.directory());
+            final PlainActionFuture<Void> future1 = PlainActionFuture.newFuture();
             runGeneric(threadPool, () -> {
                 IndexShardSnapshotStatus snapshotStatus = IndexShardSnapshotStatus.newInitializing();
                 repository.snapshotShard(store, null, snapshotId, indexId, indexCommit,
-                    snapshotStatus);
+                    snapshotStatus, future1);
+                future1.actionGet();
                 IndexShardSnapshotStatus.Copy copy = snapshotStatus.asCopy();
                 assertEquals(copy.getTotalFileCount(), copy.getIncrementalFileCount());
             });
@@ -124,9 +127,11 @@ public class FsRepositoryTests extends ESTestCase {
             SnapshotId incSnapshotId = new SnapshotId("test1", "test1");
             IndexCommit incIndexCommit = Lucene.getIndexCommit(Lucene.readSegmentInfos(store.directory()), store.directory());
             Collection<String> commitFileNames = incIndexCommit.getFileNames();
+            final PlainActionFuture<Void> future2 = PlainActionFuture.newFuture();
             runGeneric(threadPool, () -> {
                 IndexShardSnapshotStatus snapshotStatus = IndexShardSnapshotStatus.newInitializing();
-                repository.snapshotShard(store, null, incSnapshotId, indexId, incIndexCommit, snapshotStatus);
+                repository.snapshotShard(store, null, incSnapshotId, indexId, incIndexCommit, snapshotStatus, future2);
+                future2.actionGet();
                 IndexShardSnapshotStatus.Copy copy = snapshotStatus.asCopy();
                 assertEquals(2, copy.getIncrementalFileCount());
                 assertEquals(commitFileNames.size(), copy.getTotalFileCount());
@@ -198,4 +203,5 @@ public class FsRepositoryTests extends ESTestCase {
             return docs;
         }
     }
+
 }

+ 3 - 1
test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java

@@ -832,12 +832,14 @@ public abstract class IndexShardTestCase extends ESTestCase {
                                  final Snapshot snapshot,
                                  final Repository repository) throws IOException {
         final IndexShardSnapshotStatus snapshotStatus = IndexShardSnapshotStatus.newInitializing();
+        final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
         try (Engine.IndexCommitRef indexCommitRef = shard.acquireLastIndexCommit(true)) {
             Index index = shard.shardId().getIndex();
             IndexId indexId = new IndexId(index.getName(), index.getUUID());
 
             repository.snapshotShard(shard.store(), shard.mapperService(), snapshot.getSnapshotId(), indexId,
-                indexCommitRef.getIndexCommit(), snapshotStatus);
+                indexCommitRef.getIndexCommit(), snapshotStatus, future);
+            future.actionGet();
         }
 
         final IndexShardSnapshotStatus.Copy lastSnapshotStatus = snapshotStatus.asCopy();

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/index/shard/RestoreOnlyRepository.java

@@ -135,7 +135,7 @@ public abstract class RestoreOnlyRepository extends AbstractLifecycleComponent i
 
     @Override
     public void snapshotShard(Store store, MapperService mapperService, SnapshotId snapshotId, IndexId indexId,
-                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus) {
+                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener) {
     }
 
     @Override

+ 1 - 1
x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java

@@ -296,7 +296,7 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit
 
     @Override
     public void snapshotShard(Store store, MapperService mapperService, SnapshotId snapshotId, IndexId indexId,
-                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus) {
+                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener) {
         throw new UnsupportedOperationException("Unsupported for repository of type: " + TYPE);
     }
 

+ 26 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/snapshots/SourceOnlySnapshotRepository.java

@@ -15,6 +15,7 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.store.FilterDirectory;
 import org.apache.lucene.store.SimpleFSDirectory;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.MappingMetaData;
 import org.elasticsearch.cluster.metadata.MetaData;
@@ -24,6 +25,7 @@ import org.elasticsearch.common.collect.ImmutableOpenMap;
 import org.elasticsearch.common.lucene.search.Queries;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.env.ShardLock;
 import org.elasticsearch.index.engine.EngineFactory;
 import org.elasticsearch.index.engine.ReadOnlyEngine;
@@ -35,9 +37,11 @@ import org.elasticsearch.repositories.FilterRepository;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.repositories.Repository;
 
+import java.io.Closeable;
 import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
@@ -108,11 +112,13 @@ public final class SourceOnlySnapshotRepository extends FilterRepository {
 
     @Override
     public void snapshotShard(Store store, MapperService mapperService, SnapshotId snapshotId, IndexId indexId,
-                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus) {
+                              IndexCommit snapshotIndexCommit, IndexShardSnapshotStatus snapshotStatus, ActionListener<Void> listener) {
         if (mapperService.documentMapper() != null // if there is no mapping this is null
             && mapperService.documentMapper().sourceMapper().isComplete() == false) {
-            throw new IllegalStateException("Can't snapshot _source only on an index that has incomplete source ie. has _source disabled " +
-                "or filters the source");
+            listener.onFailure(
+                new IllegalStateException("Can't snapshot _source only on an index that has incomplete source ie. has _source disabled " +
+                    "or filters the source"));
+            return;
         }
         Directory unwrap = FilterDirectory.unwrap(store.directory());
         if (unwrap instanceof FSDirectory == false) {
@@ -121,7 +127,10 @@ public final class SourceOnlySnapshotRepository extends FilterRepository {
         Path dataPath = ((FSDirectory) unwrap).getDirectory().getParent();
         // TODO should we have a snapshot tmp directory per shard that is maintained by the system?
         Path snapPath = dataPath.resolve(SNAPSHOT_DIR_NAME);
-        try (FSDirectory directory = new SimpleFSDirectory(snapPath)) {
+        final List<Closeable> toClose = new ArrayList<>(3);
+        try {
+            FSDirectory directory = new SimpleFSDirectory(snapPath);
+            toClose.add(directory);
             Store tempStore = new Store(store.shardId(), store.indexSettings(), directory, new ShardLock(store.shardId()) {
                 @Override
                 protected void closeInternal() {
@@ -137,16 +146,20 @@ public final class SourceOnlySnapshotRepository extends FilterRepository {
             final long maxDoc = segmentInfos.totalMaxDoc();
             tempStore.bootstrapNewHistory(maxDoc, maxDoc);
             store.incRef();
-            try (DirectoryReader reader = DirectoryReader.open(tempStore.directory(),
-                Collections.singletonMap(BlockTreeTermsReader.FST_MODE_KEY, BlockTreeTermsReader.FSTLoadMode.OFF_HEAP.name()))) {
-                IndexCommit indexCommit = reader.getIndexCommit();
-                super.snapshotShard(tempStore, mapperService, snapshotId, indexId, indexCommit, snapshotStatus);
-            } finally {
-                store.decRef();
-            }
+            toClose.add(store::decRef);
+            DirectoryReader reader = DirectoryReader.open(tempStore.directory(),
+                Collections.singletonMap(BlockTreeTermsReader.FST_MODE_KEY, BlockTreeTermsReader.FSTLoadMode.OFF_HEAP.name()));
+            toClose.add(reader);
+            IndexCommit indexCommit = reader.getIndexCommit();
+            super.snapshotShard(tempStore, mapperService, snapshotId, indexId, indexCommit, snapshotStatus,
+                ActionListener.runBefore(listener, () -> IOUtils.close(toClose)));
         } catch (IOException e) {
-            // why on earth does this super method not declare IOException
-            throw new UncheckedIOException(e);
+            try {
+                IOUtils.close(toClose);
+            } catch (IOException ex) {
+                e.addSuppressed(ex);
+            }
+            listener.onFailure(e);
         }
     }
 

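The pattern above: collect everything that must be released into a list, then close it exactly once, either just before the listener fires (via `ActionListener.runBefore`) or on the synchronous failure path. A minimal sketch with a hypothetical resource and stub async work:

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.internal.io.IOUtils;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the deferred-close pattern: resources are released once,
// either just before the listener is notified or on the failure path below.
class DeferredCloseSketch {

    void run(ActionListener<Void> listener) {
        final List<Closeable> toClose = new ArrayList<>();
        try {
            toClose.add(openResource()); // hypothetical resource acquisition
            // the async work releases the resources before notifying the caller
            startAsyncWork(ActionListener.runBefore(listener, () -> IOUtils.close(toClose)));
        } catch (Exception e) {
            try {
                IOUtils.close(toClose);
            } catch (IOException ex) {
                e.addSuppressed(ex);
            }
            listener.onFailure(e);
        }
    }

    private Closeable openResource() {
        return () -> { }; // placeholder
    }

    private void startAsyncWork(ActionListener<Void> listener) {
        listener.onResponse(null); // placeholder
    }
}
```
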
+ 19 - 10
x-pack/plugin/core/src/test/java/org/elasticsearch/snapshots/SourceOnlySnapshotShardTests.java

@@ -96,12 +96,13 @@ public class SourceOnlySnapshotShardTests extends IndexShardTestCase {
         repository.start();
         try (Engine.IndexCommitRef snapshotRef = shard.acquireLastIndexCommit(true)) {
             IndexShardSnapshotStatus indexShardSnapshotStatus = IndexShardSnapshotStatus.newInitializing();
-            IllegalStateException illegalStateException = expectThrows(IllegalStateException.class, () ->
-                runAsSnapshot(shard.getThreadPool(),
-                    () -> repository.snapshotShard(shard.store(), shard.mapperService(), snapshotId, indexId,
-                        snapshotRef.getIndexCommit(), indexShardSnapshotStatus)));
-            assertEquals("Can't snapshot _source only on an index that has incomplete source ie. has _source disabled or filters the source"
-                , illegalStateException.getMessage());
+            final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
+            runAsSnapshot(shard.getThreadPool(), () -> repository.snapshotShard(shard.store(), shard.mapperService(), snapshotId, indexId,
+                snapshotRef.getIndexCommit(), indexShardSnapshotStatus, future));
+            IllegalStateException illegalStateException = expectThrows(IllegalStateException.class, future::actionGet);
+            assertEquals(
+                "Can't snapshot _source only on an index that has incomplete source ie. has _source disabled or filters the source",
+                illegalStateException.getMessage());
         }
         closeShards(shard);
     }
@@ -120,8 +121,10 @@ public class SourceOnlySnapshotShardTests extends IndexShardTestCase {
         try (Engine.IndexCommitRef snapshotRef = shard.acquireLastIndexCommit(true)) {
             IndexShardSnapshotStatus indexShardSnapshotStatus = IndexShardSnapshotStatus.newInitializing();
             SnapshotId snapshotId = new SnapshotId("test", "test");
+            final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
             runAsSnapshot(shard.getThreadPool(), () -> repository.snapshotShard(shard.store(), shard.mapperService(), snapshotId, indexId,
-                snapshotRef.getIndexCommit(), indexShardSnapshotStatus));
+                snapshotRef.getIndexCommit(), indexShardSnapshotStatus, future));
+            future.actionGet();
             IndexShardSnapshotStatus.Copy copy = indexShardSnapshotStatus.asCopy();
             assertEquals(copy.getTotalFileCount(), copy.getIncrementalFileCount());
             totalFileCount = copy.getTotalFileCount();
@@ -134,8 +137,10 @@ public class SourceOnlySnapshotShardTests extends IndexShardTestCase {
             SnapshotId snapshotId = new SnapshotId("test_1", "test_1");
 
             IndexShardSnapshotStatus indexShardSnapshotStatus = IndexShardSnapshotStatus.newInitializing();
+            final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
             runAsSnapshot(shard.getThreadPool(), () -> repository.snapshotShard(shard.store(), shard.mapperService(), snapshotId, indexId,
-                snapshotRef.getIndexCommit(), indexShardSnapshotStatus));
+                snapshotRef.getIndexCommit(), indexShardSnapshotStatus, future));
+            future.actionGet();
             IndexShardSnapshotStatus.Copy copy = indexShardSnapshotStatus.asCopy();
             // we processed the segments_N file plus _1.si, _1.fdx, _1.fnm, _1.fdt
             assertEquals(5, copy.getIncrementalFileCount());
@@ -148,8 +153,10 @@ public class SourceOnlySnapshotShardTests extends IndexShardTestCase {
             SnapshotId snapshotId = new SnapshotId("test_2", "test_2");
 
             IndexShardSnapshotStatus indexShardSnapshotStatus = IndexShardSnapshotStatus.newInitializing();
+            final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
             runAsSnapshot(shard.getThreadPool(), () -> repository.snapshotShard(shard.store(), shard.mapperService(), snapshotId, indexId,
-                snapshotRef.getIndexCommit(), indexShardSnapshotStatus));
+                snapshotRef.getIndexCommit(), indexShardSnapshotStatus, future));
+            future.actionGet();
             IndexShardSnapshotStatus.Copy copy = indexShardSnapshotStatus.asCopy();
             // we processed the segments_N file plus _1_1.liv
             assertEquals(2, copy.getIncrementalFileCount());
@@ -197,8 +204,10 @@ public class SourceOnlySnapshotShardTests extends IndexShardTestCase {
                 repository.initializeSnapshot(snapshotId, Arrays.asList(indexId),
                     MetaData.builder().put(shard.indexSettings()
                     .getIndexMetaData(), false).build());
+                final PlainActionFuture<Void> future = PlainActionFuture.newFuture();
                 repository.snapshotShard(shard.store(), shard.mapperService(), snapshotId, indexId, snapshotRef.getIndexCommit(),
-                    indexShardSnapshotStatus);
+                    indexShardSnapshotStatus, future);
+                future.actionGet();
             });
             IndexShardSnapshotStatus.Copy copy = indexShardSnapshotStatus.asCopy();
             assertEquals(copy.getTotalFileCount(), copy.getIncrementalFileCount());