Browse Source

Release index commit promptly on snapshot abort (#96442)

Today we react to a shard snapshot abort when processing that shard on a
`SNAPSHOT` thread, but those threads may be busy handling tasks related
to other shards for an extended period of time, which delays the release
of the resources held by the shard snapshot tasks. With this commit we
fast-track the abort handling using the machinery introduced in #96426.

Closes #95316
David Turner 2 years ago
parent
commit
d0a95a7991

+ 111 - 0
server/src/internalClusterTest/java/org/elasticsearch/snapshots/AbortedSnapshotIT.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.snapshots;
+
+import org.elasticsearch.cluster.SnapshotsInProgress;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus;
+import org.elasticsearch.indices.IndicesService;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+
+@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
+public class AbortedSnapshotIT extends AbstractSnapshotIntegTestCase {
+
+    public void testQueuedSnapshotReleasesCommitOnAbort() throws Exception {
+        internalCluster().startMasterOnlyNode();
+        // one-thread pool so we are able to fully block it later
+        final String dataNode = internalCluster().startDataOnlyNode(Settings.builder().put("thread_pool.snapshot.max", 1).build());
+        final String indexName = "test-index";
+        createIndexWithContent(indexName);
+
+        final var indicesService = internalCluster().getInstance(IndicesService.class, dataNode);
+        final var clusterService = indicesService.clusterService();
+        final var index = clusterService.state().metadata().index(indexName).getIndex();
+        final var store = indicesService.indexServiceSafe(index).getShard(0).store();
+        assertTrue(store.hasReferences());
+
+        final String repoName = "test-repo";
+        createRepository(repoName, "fs");
+
+        final var snapshotExecutor = internalCluster().getInstance(ThreadPool.class, dataNode).executor(ThreadPool.Names.SNAPSHOT);
+
+        final CyclicBarrier barrier = new CyclicBarrier(2);
+        final AtomicBoolean stopBlocking = new AtomicBoolean();
+        class BlockingTask implements Runnable {
+            @Override
+            public void run() {
+                safeAwait(barrier);
+                safeAwait(barrier);
+                if (stopBlocking.get() == false) {
+                    // enqueue another block to happen just after the currently-enqueued tasks
+                    snapshotExecutor.execute(BlockingTask.this);
+                }
+            }
+        }
+        snapshotExecutor.execute(new BlockingTask());
+        safeAwait(barrier); // wait for snapshot thread to be blocked
+
+        clusterAdmin().prepareCreateSnapshot(repoName, "snapshot-1").setWaitForCompletion(false).setPartial(true).get();
+        // resulting cluster state has been applied on all nodes, which means the first task for the SNAPSHOT pool is queued up
+
+        final var snapshot = clusterService.state()
+            .custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY)
+            .forRepo(repoName)
+            .get(0)
+            .snapshot();
+        final var snapshotShardsService = internalCluster().getInstance(SnapshotShardsService.class, dataNode);
+
+        // Run up to 3 snapshot tasks, which are (in order):
+        // 1. run BlobStoreRepository#snapshotShard
+        // 2. run BlobStoreRepository#doSnapshotShard (moves the shard to state STARTED)
+        // 3. process one file (there will be at least two, but the per-file tasks are enqueued one at a time by the throttling executor)
+
+        final var steps = between(0, 3);
+        for (int i = 0; i < steps; i++) {
+            safeAwait(barrier); // release snapshot thread so it can run the enqueued task
+            safeAwait(barrier); // wait for snapshot thread to be blocked again
+
+            final var shardStatuses = snapshotShardsService.currentSnapshotShards(snapshot);
+            assertEquals(1, shardStatuses.size());
+            final var shardStatus = shardStatuses.get(new ShardId(index, 0)).asCopy();
+            logger.info("--> {}", shardStatus);
+
+            if (i == 0) {
+                assertEquals(IndexShardSnapshotStatus.Stage.INIT, shardStatus.getStage());
+                assertEquals(0, shardStatus.getProcessedFileCount());
+                assertEquals(0, shardStatus.getTotalFileCount());
+            } else {
+                assertEquals(IndexShardSnapshotStatus.Stage.STARTED, shardStatus.getStage());
+                assertThat(shardStatus.getProcessedFileCount(), greaterThan(0));
+                assertThat(shardStatus.getProcessedFileCount(), lessThan(shardStatus.getTotalFileCount()));
+            }
+        }
+
+        assertTrue(store.hasReferences());
+        assertAcked(client().admin().indices().prepareDelete(indexName).get());
+
+        // this is the key assertion: we must release the store without needing any SNAPSHOT threads to make further progress
+        assertBusy(() -> assertFalse(store.hasReferences()));
+
+        stopBlocking.set(true);
+        safeAwait(barrier); // release snapshot thread
+
+        assertBusy(() -> assertTrue(clusterService.state().custom(SnapshotsInProgress.TYPE, SnapshotsInProgress.EMPTY).isEmpty()));
+    }
+
+}

+ 29 - 2
server/src/main/java/org/elasticsearch/index/snapshots/IndexShardSnapshotStatus.java

@@ -8,13 +8,18 @@
 
 package org.elasticsearch.index.snapshots;
 
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.repositories.ShardGeneration;
 import org.elasticsearch.repositories.ShardSnapshotResult;
 import org.elasticsearch.snapshots.AbortedSnapshotException;
 
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
 
 /**
  * Represent shard snapshot status
@@ -51,6 +56,15 @@ public class IndexShardSnapshotStatus {
         ABORTED
     }
 
+    /**
+     * Used to complete listeners added via {@link #addAbortListener} when the shard snapshot is either aborted or gets past the stages
+     * where an abort could have occurred.
+     */
+    public enum AbortStatus {
+        NO_ABORT,
+        ABORTED
+    }
+
     private final AtomicReference<Stage> stage;
     private final AtomicReference<ShardGeneration> generation;
     private final AtomicReference<ShardSnapshotResult> shardSnapshotResult; // only set in stage DONE
@@ -63,6 +77,7 @@ public class IndexShardSnapshotStatus {
     private long incrementalSize;
     private long processedSize;
     private String failure;
+    private final SubscribableListener<AbortStatus> abortListeners = new SubscribableListener<>();
 
     private IndexShardSnapshotStatus(
         final Stage stage,
@@ -118,7 +133,10 @@ public class IndexShardSnapshotStatus {
     public synchronized Copy moveToFinalize() {
         final var prevStage = stage.compareAndExchange(Stage.STARTED, Stage.FINALIZE);
         return switch (prevStage) {
-            case STARTED -> asCopy();
+            case STARTED -> {
+                abortListeners.onResponse(AbortStatus.NO_ABORT);
+                yield asCopy();
+            }
             case ABORTED -> throw new AbortedSnapshotException();
             default -> {
                 final var message = Strings.format(
@@ -146,14 +164,23 @@ public class IndexShardSnapshotStatus {
         }
     }
 
-    public synchronized void abortIfNotCompleted(final String failure) {
+    public void addAbortListener(ActionListener<AbortStatus> listener) {
+        abortListeners.addListener(listener);
+    }
+
+    public synchronized void abortIfNotCompleted(final String failure, Consumer<ActionListener<Releasable>> notifyRunner) {
         if (stage.compareAndSet(Stage.INIT, Stage.ABORTED) || stage.compareAndSet(Stage.STARTED, Stage.ABORTED)) {
             this.failure = failure;
+            notifyRunner.accept(abortListeners.map(r -> {
+                Releasables.closeExpectNoException(r);
+                return AbortStatus.ABORTED;
+            }));
         }
     }
 
     public synchronized void moveToFailed(final long endTime, final String failure) {
         if (stage.getAndSet(Stage.FAILURE) != Stage.FAILURE) {
+            abortListeners.onResponse(AbortStatus.NO_ABORT);
             this.totalTime = Math.max(0L, endTime - startTime);
             this.failure = failure;
         }

+ 1 - 1
server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java

@@ -2647,7 +2647,6 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
             return;
         }
         final Store store = context.store();
-        final IndexCommit snapshotIndexCommit = context.indexCommit();
         final ShardId shardId = store.shardId();
         final SnapshotId snapshotId = context.snapshotId();
         final IndexShardSnapshotStatus snapshotStatus = context.status();
@@ -2715,6 +2714,7 @@ public abstract class BlobStoreRepository extends AbstractLifecycleComponent imp
                 try (Releasable ignored = context.withCommitRef()) {
                     // TODO apparently we don't use the MetadataSnapshot#.recoveryDiff(...) here but we should
                     try {
+                        final IndexCommit snapshotIndexCommit = context.indexCommit();
                         logger.trace("[{}] [{}] Loading store metadata using index commit [{}]", shardId, snapshotId, snapshotIndexCommit);
                         metadataFromStore = store.getMetadata(snapshotIndexCommit);
                         fileNames = snapshotIndexCommit.getFileNames();

+ 70 - 13
server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java

@@ -24,9 +24,10 @@ import org.elasticsearch.cluster.SnapshotsInProgress.ShardState;
 import org.elasticsearch.cluster.SnapshotsInProgress.State;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.engine.Engine;
 import org.elasticsearch.index.seqno.SequenceNumbers;
@@ -83,6 +84,9 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
     // A map of snapshots to the shardIds that we already reported to the master as failed
     private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator;
 
+    // Runs the tasks that promptly notify shards of aborted snapshots so that resources can be released ASAP
+    private final ThrottledTaskRunner notifyOnAbortTaskRunner;
+
     public SnapshotShardsService(
         Settings settings,
         ClusterService clusterService,
@@ -100,6 +104,13 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
             // this is only useful on the nodes that can hold data
             clusterService.addListener(this);
         }
+
+        // Abort notifications may release the last store ref, closing the shard, so we do them in the background on a generic thread.
+        this.notifyOnAbortTaskRunner = new ThrottledTaskRunner(
+            "notify-on-abort",
+            threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(),
+            threadPool.generic()
+        );
     }
 
     @Override
@@ -149,14 +160,14 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
         // abort any snapshots occurring on the soon-to-be closed shard
         synchronized (shardSnapshots) {
             for (Map.Entry<Snapshot, Map<ShardId, IndexShardSnapshotStatus>> snapshotShards : shardSnapshots.entrySet()) {
-                Map<ShardId, IndexShardSnapshotStatus> shards = snapshotShards.getValue();
-                if (shards.containsKey(shardId)) {
+                final var indexShardSnapshotStatus = snapshotShards.getValue().get(shardId);
+                if (indexShardSnapshotStatus != null) {
                     logger.debug(
                         "[{}] shard closing, abort snapshotting for snapshot [{}]",
                         shardId,
                         snapshotShards.getKey().getSnapshotId()
                     );
-                    shards.get(shardId).abortIfNotCompleted("shard is closing, aborting");
+                    indexShardSnapshotStatus.abortIfNotCompleted("shard is closing, aborting", notifyOnAbortTaskRunner::enqueueTask);
                 }
             }
         }
@@ -191,7 +202,10 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
                 // state update, which is being processed here
                 it.remove();
                 for (IndexShardSnapshotStatus snapshotStatus : entry.getValue().values()) {
-                    snapshotStatus.abortIfNotCompleted("snapshot has been removed in cluster state, aborting");
+                    snapshotStatus.abortIfNotCompleted(
+                        "snapshot has been removed in cluster state, aborting",
+                        notifyOnAbortTaskRunner::enqueueTask
+                    );
                 }
             }
         }
@@ -258,7 +272,7 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
                             notifyFailedSnapshotShard(snapshot, sid, shard.getValue().reason(), shard.getValue().generation());
                         }
                     } else {
-                        snapshotStatus.abortIfNotCompleted("snapshot has been aborted");
+                        snapshotStatus.abortIfNotCompleted("snapshot has been aborted", notifyOnAbortTaskRunner::enqueueTask);
                     }
                 }
             }
@@ -365,9 +379,10 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
             }
 
             final Repository repository = repositoriesService.repository(snapshot.getRepository());
-            Engine.IndexCommitRef snapshotRef = null;
+            SnapshotIndexCommit snapshotIndexCommit = null;
             try {
-                snapshotRef = indexShard.acquireIndexCommitForSnapshot();
+                snapshotIndexCommit = new SnapshotIndexCommit(indexShard.acquireIndexCommitForSnapshot());
+                snapshotStatus.addAbortListener(makeAbortListener(indexShard.shardId(), snapshot, snapshotIndexCommit));
                 snapshotStatus.ensureNotAborted();
                 repository.snapshotShard(
                     new SnapshotShardContext(
@@ -375,21 +390,63 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
                         indexShard.mapperService(),
                         snapshot.getSnapshotId(),
                         indexId,
-                        new SnapshotIndexCommit(snapshotRef),
-                        getShardStateId(indexShard, snapshotRef.getIndexCommit()),
+                        snapshotIndexCommit,
+                        getShardStateId(indexShard, snapshotIndexCommit.indexCommit()),
                         snapshotStatus,
                         version,
                         entryStartTime,
                         listener
                     )
                 );
-            } catch (Exception e) {
-                IOUtils.close(snapshotRef);
-                throw e;
+                snapshotIndexCommit = null; // success
+            } finally {
+                if (snapshotIndexCommit != null) {
+                    snapshotIndexCommit.closingBefore(new ActionListener<Void>() {
+                        @Override
+                        public void onResponse(Void unused) {}
+
+                        @Override
+                        public void onFailure(Exception e) {
+                            // we're already failing exceptionally, and prefer to propagate the original exception instead of this one
+                            logger.warn(Strings.format("exception closing commit for [%s] in [%s]", indexShard.shardId(), snapshot), e);
+                        }
+                    }).onResponse(null);
+                }
             }
         });
     }
 
+    private static ActionListener<IndexShardSnapshotStatus.AbortStatus> makeAbortListener(
+        ShardId shardId,
+        Snapshot snapshot,
+        SnapshotIndexCommit snapshotIndexCommit
+    ) {
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(IndexShardSnapshotStatus.AbortStatus abortStatus) {
+                if (abortStatus == IndexShardSnapshotStatus.AbortStatus.ABORTED) {
+                    assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC, ThreadPool.Names.SNAPSHOT);
+                    snapshotIndexCommit.onAbort();
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                logger.error(() -> Strings.format("unexpected failure in %s", description()), e);
+                assert false : e;
+            }
+
+            @Override
+            public String toString() {
+                return description();
+            }
+
+            private String description() {
+                return Strings.format("abort listener for [%s] in [%s]", shardId, snapshot);
+            }
+        };
+    }
+
     /**
      * Generates an identifier from the current state of a shard that can be used to detect whether a shard's contents
      * have changed between two snapshots.