Переглянути джерело

[cache] Support async RangeMissingHandler callbacks (#110587)

Change `fillCacheRange` method to accept a completion listener that must be called by `RangeMissingHandler` implementations when they finish fetching data. By doing so, we support asynchronously fetching the data from a third party storage. We also support asynchronous `SourceInputStreamFactory` for reading gaps from the storage.

Depends on #111177
Artem Prigoda 1 рік тому
батько
коміт
cb7a21e8ff

+ 71 - 30
x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java

@@ -646,13 +646,14 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
             // no need to allocate a new capturing lambda if the offset isn't adjusted
             return writer;
         }
-        return (channel, channelPos, streamFactory, relativePos, len, progressUpdater) -> writer.fillCacheRange(
+        return (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> writer.fillCacheRange(
             channel,
             channelPos,
             streamFactory,
             relativePos - writeOffset,
             len,
-            progressUpdater
+            progressUpdater,
+            completionListener
         );
     }
 
@@ -987,16 +988,17 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
                                 executor.execute(fillGapRunnable(gap, writer, null, refs.acquireListener()));
                             }
                         } else {
-                            final List<AbstractRunnable> gapFillingTasks = gaps.stream()
-                                .map(gap -> fillGapRunnable(gap, writer, streamFactory, refs.acquireListener()))
-                                .toList();
-                            executor.execute(() -> {
-                                try (streamFactory) {
+                            var gapFillingListener = refs.acquireListener();
+                            try (var gfRefs = new RefCountingRunnable(ActionRunnable.run(gapFillingListener, streamFactory::close))) {
+                                final List<Runnable> gapFillingTasks = gaps.stream()
+                                    .map(gap -> fillGapRunnable(gap, writer, streamFactory, gfRefs.acquireListener()))
+                                    .toList();
+                                executor.execute(() -> {
                                     // Fill the gaps in order. If a gap fails to fill for whatever reason, the task for filling the next
                                     // gap will still be executed.
                                     gapFillingTasks.forEach(Runnable::run);
-                                }
-                            });
+                                });
+                            }
                         }
                     }
                 }
@@ -1005,13 +1007,13 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
             }
         }
 
-        private AbstractRunnable fillGapRunnable(
+        private Runnable fillGapRunnable(
             SparseFileTracker.Gap gap,
             RangeMissingHandler writer,
             @Nullable SourceInputStreamFactory streamFactory,
             ActionListener<Void> listener
         ) {
-            return ActionRunnable.run(listener.delegateResponse((l, e) -> failGapAndListener(gap, l, e)), () -> {
+            return () -> ActionListener.run(listener, l -> {
                 var ioRef = io;
                 assert regionOwners.get(ioRef) == CacheFileRegion.this;
                 assert CacheFileRegion.this.hasReferences() : CacheFileRegion.this;
@@ -1022,10 +1024,15 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
                     streamFactory,
                     start,
                     Math.toIntExact(gap.end() - start),
-                    progress -> gap.onProgress(start + progress)
+                    progress -> gap.onProgress(start + progress),
+                    l.<Void>map(unused -> {
+                        assert regionOwners.get(ioRef) == CacheFileRegion.this;
+                        assert CacheFileRegion.this.hasReferences() : CacheFileRegion.this;
+                        writeCount.increment();
+                        gap.onCompletion();
+                        return null;
+                    }).delegateResponse((delegate, e) -> failGapAndListener(gap, delegate, e))
                 );
-                writeCount.increment();
-                gap.onCompletion();
             });
         }
 
@@ -1113,12 +1120,23 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
                     SourceInputStreamFactory streamFactory,
                     int relativePos,
                     int length,
-                    IntConsumer progressUpdater
+                    IntConsumer progressUpdater,
+                    ActionListener<Void> completionListener
                 ) throws IOException {
-                    writer.fillCacheRange(channel, channelPos, streamFactory, relativePos, length, progressUpdater);
-                    var elapsedTime = TimeUnit.NANOSECONDS.toMicros(relativeTimeInNanosSupplier.getAsLong() - startTime);
-                    SharedBlobCacheService.this.blobCacheMetrics.getCacheMissLoadTimes().record(elapsedTime);
-                    SharedBlobCacheService.this.blobCacheMetrics.getCacheMissCounter().increment();
+                    writer.fillCacheRange(
+                        channel,
+                        channelPos,
+                        streamFactory,
+                        relativePos,
+                        length,
+                        progressUpdater,
+                        completionListener.map(unused -> {
+                            var elapsedTime = TimeUnit.NANOSECONDS.toMicros(relativeTimeInNanosSupplier.getAsLong() - startTime);
+                            blobCacheMetrics.getCacheMissLoadTimes().record(elapsedTime);
+                            blobCacheMetrics.getCacheMissCounter().increment();
+                            return null;
+                        })
+                    );
                 }
             };
             if (rangeToRead.isEmpty()) {
@@ -1211,9 +1229,18 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
                         SourceInputStreamFactory streamFactory,
                         int relativePos,
                         int len,
-                        IntConsumer progressUpdater
+                        IntConsumer progressUpdater,
+                        ActionListener<Void> completionListener
                     ) throws IOException {
-                        delegate.fillCacheRange(channel, channelPos, streamFactory, relativePos - writeOffset, len, progressUpdater);
+                        delegate.fillCacheRange(
+                            channel,
+                            channelPos,
+                            streamFactory,
+                            relativePos - writeOffset,
+                            len,
+                            progressUpdater,
+                            completionListener
+                        );
                     }
                 };
             }
@@ -1226,14 +1253,25 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
                         SourceInputStreamFactory streamFactory,
                         int relativePos,
                         int len,
-                        IntConsumer progressUpdater
+                        IntConsumer progressUpdater,
+                        ActionListener<Void> completionListener
                     ) throws IOException {
                         assert assertValidRegionAndLength(fileRegion, channelPos, len);
-                        delegate.fillCacheRange(channel, channelPos, streamFactory, relativePos, len, progressUpdater);
-                        assert regionOwners.get(fileRegion.io) == fileRegion
-                            : "File chunk [" + fileRegion.regionKey + "] no longer owns IO [" + fileRegion.io + "]";
+                        delegate.fillCacheRange(
+                            channel,
+                            channelPos,
+                            streamFactory,
+                            relativePos,
+                            len,
+                            progressUpdater,
+                            Assertions.ENABLED ? ActionListener.runBefore(completionListener, () -> {
+                                assert regionOwners.get(fileRegion.io) == fileRegion
+                                    : "File chunk [" + fileRegion.regionKey + "] no longer owns IO [" + fileRegion.io + "]";
+                            }) : completionListener
+                        );
                     }
                 };
+
             }
             return adjustedWriter;
         }
@@ -1320,6 +1358,7 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
          * @param length of data to fetch
          * @param progressUpdater consumer to invoke with the number of copied bytes as they are written in cache.
          *                        This is used to notify waiting readers that data become available in cache.
+         * @param completionListener listener that has to be called when the callback method completes
          */
         void fillCacheRange(
             SharedBytes.IO channel,
@@ -1327,7 +1366,8 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
             @Nullable SourceInputStreamFactory streamFactory,
             int relativePos,
             int length,
-            IntConsumer progressUpdater
+            IntConsumer progressUpdater,
+            ActionListener<Void> completionListener
         ) throws IOException;
     }
 
@@ -1339,9 +1379,9 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
         /**
          * Create the input stream at the specified position.
          * @param relativePos the relative position in the remote storage to read from.
-         * @return the input stream ready to be read from.
+         * @param listener listener for the input stream ready to be read from.
          */
-        InputStream create(int relativePos) throws IOException;
+        void create(int relativePos, ActionListener<InputStream> listener) throws IOException;
     }
 
     private abstract static class DelegatingRangeMissingHandler implements RangeMissingHandler {
@@ -1363,9 +1403,10 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
             SourceInputStreamFactory streamFactory,
             int relativePos,
             int length,
-            IntConsumer progressUpdater
+            IntConsumer progressUpdater,
+            ActionListener<Void> completionListener
         ) throws IOException {
-            delegate.fillCacheRange(channel, channelPos, streamFactory, relativePos, length, progressUpdater);
+            delegate.fillCacheRange(channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener);
         }
     }
 

+ 151 - 69
x-pack/plugin/blob-cache/src/test/java/org/elasticsearch/blobcache/shared/SharedBlobCacheServiceTests.java

@@ -29,6 +29,7 @@ import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.StoppableExecutorServiceWrapper;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.CheckedRunnable;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.env.TestEnvironment;
@@ -72,6 +73,13 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
         return numPages * SharedBytes.PAGE_SIZE;
     }
 
+    private static <E extends Exception> void completeWith(ActionListener<Void> listener, CheckedRunnable<E> runnable) {
+        ActionListener.completeWith(listener, () -> {
+            runnable.run();
+            return null;
+        });
+    }
+
     public void testBasicEviction() throws IOException {
         Settings settings = Settings.builder()
             .put(NODE_NAME_SETTING.getKey(), "node")
@@ -115,7 +123,10 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                 ByteRange.of(0L, 1L),
                 ByteRange.of(0L, 1L),
                 (channel, channelPos, relativePos, length) -> 1,
-                (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> progressUpdater.accept(length),
+                (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                    completionListener,
+                    () -> progressUpdater.accept(length)
+                ),
                 taskQueue.getThreadPool().generic(),
                 bytesReadFuture
             );
@@ -552,11 +563,14 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                 cacheService.maybeFetchFullEntry(
                     cacheKey,
                     size,
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                        assert streamFactory == null : streamFactory;
-                        bytesRead.addAndGet(-length);
-                        progressUpdater.accept(length);
-                    },
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> {
+                            assert streamFactory == null : streamFactory;
+                            bytesRead.addAndGet(-length);
+                            progressUpdater.accept(length);
+                        }
+                    ),
                     bulkExecutor,
                     future
                 );
@@ -570,9 +584,15 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                 // a download that would use up all regions should not run
                 final var cacheKey = generateCacheKey();
                 assertEquals(2, cacheService.freeRegionCount());
-                var configured = cacheService.maybeFetchFullEntry(cacheKey, size(500), (ch, chPos, streamFactory, relPos, len, update) -> {
-                    throw new AssertionError("Should never reach here");
-                }, bulkExecutor, ActionListener.noop());
+                var configured = cacheService.maybeFetchFullEntry(
+                    cacheKey,
+                    size(500),
+                    (ch, chPos, streamFactory, relPos, len, update, completionListener) -> completeWith(completionListener, () -> {
+                        throw new AssertionError("Should never reach here");
+                    }),
+                    bulkExecutor,
+                    ActionListener.noop()
+                );
                 assertFalse(configured);
                 assertEquals(2, cacheService.freeRegionCount());
             }
@@ -613,9 +633,14 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                             (ActionListener<Void> listener) -> cacheService.maybeFetchFullEntry(
                                 cacheKey,
                                 size,
-                                (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> progressUpdater.accept(
-                                    length
-                                ),
+                                (
+                                    channel,
+                                    channelPos,
+                                    streamFactory,
+                                    relativePos,
+                                    length,
+                                    progressUpdater,
+                                    completionListener) -> completeWith(completionListener, () -> progressUpdater.accept(length)),
                                 bulkExecutor,
                                 listener
                             )
@@ -859,7 +884,10 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                 var entry = cacheService.get(cacheKey, regionSize, 0);
                 entry.populate(
                     ByteRange.of(0L, regionSize),
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> progressUpdater.accept(length),
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> progressUpdater.accept(length)
+                    ),
                     taskQueue.getThreadPool().generic(),
                     ActionListener.noop()
                 );
@@ -954,11 +982,14 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     cacheKey,
                     0,
                     blobLength,
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                        assert streamFactory == null : streamFactory;
-                        bytesRead.addAndGet(length);
-                        progressUpdater.accept(length);
-                    },
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> {
+                            assert streamFactory == null : streamFactory;
+                            bytesRead.addAndGet(length);
+                            progressUpdater.accept(length);
+                        }
+                    ),
                     bulkExecutor,
                     future
                 );
@@ -985,11 +1016,14 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                         cacheKey,
                         region,
                         blobLength,
-                        (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                            assert streamFactory == null : streamFactory;
-                            bytesRead.addAndGet(length);
-                            progressUpdater.accept(length);
-                        },
+                        (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                            completionListener,
+                            () -> {
+                                assert streamFactory == null : streamFactory;
+                                bytesRead.addAndGet(length);
+                                progressUpdater.accept(length);
+                            }
+                        ),
                         bulkExecutor,
                         listener
                     );
@@ -1010,13 +1044,16 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     cacheKey,
                     randomIntBetween(0, 10),
                     randomLongBetween(1L, regionSize),
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                        throw new AssertionError("should not be executed");
-                    },
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> {
+                            throw new AssertionError("should not be executed");
+                        }
+                    ),
                     bulkExecutor,
                     future
                 );
-                assertThat("Listener is immediately completed", future.isDone(), is(true));
+                assertThat("Listener is immediately completionListener", future.isDone(), is(true));
                 assertThat("Region already exists in cache", future.get(), is(false));
             }
             {
@@ -1032,11 +1069,14 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     cacheKey,
                     0,
                     blobLength,
-                    (channel, channelPos, ignore, relativePos, length, progressUpdater) -> {
-                        assert ignore == null : ignore;
-                        bytesRead.addAndGet(length);
-                        progressUpdater.accept(length);
-                    },
+                    (channel, channelPos, ignore, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> {
+                            assert ignore == null : ignore;
+                            bytesRead.addAndGet(length);
+                            progressUpdater.accept(length);
+                        }
+                    ),
                     bulkExecutor,
                     future
                 );
@@ -1110,12 +1150,15 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     region,
                     range,
                     blobLength,
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                        assertThat(range.start() + relativePos, equalTo(cacheService.getRegionStart(region) + regionRange.start()));
-                        assertThat(channelPos, equalTo(Math.toIntExact(regionRange.start())));
-                        assertThat(length, equalTo(Math.toIntExact(regionRange.length())));
-                        bytesCopied.addAndGet(length);
-                    },
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> {
+                            assertThat(range.start() + relativePos, equalTo(cacheService.getRegionStart(region) + regionRange.start()));
+                            assertThat(channelPos, equalTo(Math.toIntExact(regionRange.start())));
+                            assertThat(length, equalTo(Math.toIntExact(regionRange.length())));
+                            bytesCopied.addAndGet(length);
+                        }
+                    ),
                     bulkExecutor,
                     future
                 );
@@ -1150,7 +1193,10 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                         region,
                         ByteRange.of(0L, blobLength),
                         blobLength,
-                        (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> bytesCopied.addAndGet(length),
+                        (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                            completionListener,
+                            () -> bytesCopied.addAndGet(length)
+                        ),
                         bulkExecutor,
                         listener
                     );
@@ -1173,13 +1219,16 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     randomIntBetween(0, 10),
                     ByteRange.of(0L, blobLength),
                     blobLength,
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                        throw new AssertionError("should not be executed");
-                    },
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> {
+                            throw new AssertionError("should not be executed");
+                        }
+                    ),
                     bulkExecutor,
                     future
                 );
-                assertThat("Listener is immediately completed", future.isDone(), is(true));
+                assertThat("Listener is immediately completionListener", future.isDone(), is(true));
                 assertThat("Region already exists in cache", future.get(), is(false));
             }
             {
@@ -1196,7 +1245,10 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     0,
                     ByteRange.of(0L, blobLength),
                     blobLength,
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> bytesCopied.addAndGet(length),
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> bytesCopied.addAndGet(length)
+                    ),
                     bulkExecutor,
                     future
                 );
@@ -1237,10 +1289,18 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
             var entry = cacheService.get(cacheKey, blobLength, 0);
             AtomicLong bytesWritten = new AtomicLong(0L);
             final PlainActionFuture<Boolean> future1 = new PlainActionFuture<>();
-            entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                bytesWritten.addAndGet(length);
-                progressUpdater.accept(length);
-            }, taskQueue.getThreadPool().generic(), future1);
+            entry.populate(
+                ByteRange.of(0, regionSize - 1),
+                (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                    completionListener,
+                    () -> {
+                        bytesWritten.addAndGet(length);
+                        progressUpdater.accept(length);
+                    }
+                ),
+                taskQueue.getThreadPool().generic(),
+                future1
+            );
 
             assertThat(future1.isDone(), is(false));
             assertThat(taskQueue.hasRunnableTasks(), is(true));
@@ -1248,18 +1308,34 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
             // start populating the second region
             entry = cacheService.get(cacheKey, blobLength, 1);
             final PlainActionFuture<Boolean> future2 = new PlainActionFuture<>();
-            entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                bytesWritten.addAndGet(length);
-                progressUpdater.accept(length);
-            }, taskQueue.getThreadPool().generic(), future2);
+            entry.populate(
+                ByteRange.of(0, regionSize - 1),
+                (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                    completionListener,
+                    () -> {
+                        bytesWritten.addAndGet(length);
+                        progressUpdater.accept(length);
+                    }
+                ),
+                taskQueue.getThreadPool().generic(),
+                future2
+            );
 
             // start populating again the first region, listener should be called immediately
             entry = cacheService.get(cacheKey, blobLength, 0);
             final PlainActionFuture<Boolean> future3 = new PlainActionFuture<>();
-            entry.populate(ByteRange.of(0, regionSize - 1), (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> {
-                bytesWritten.addAndGet(length);
-                progressUpdater.accept(length);
-            }, taskQueue.getThreadPool().generic(), future3);
+            entry.populate(
+                ByteRange.of(0, regionSize - 1),
+                (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                    completionListener,
+                    () -> {
+                        bytesWritten.addAndGet(length);
+                        progressUpdater.accept(length);
+                    }
+                ),
+                taskQueue.getThreadPool().generic(),
+                future3
+            );
 
             assertThat(future3.isDone(), is(true));
             var written = future3.get(10L, TimeUnit.SECONDS);
@@ -1377,7 +1453,10 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     range,
                     range,
                     (channel, channelPos, relativePos, length) -> length,
-                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater) -> progressUpdater.accept(length),
+                    (channel, channelPos, streamFactory, relativePos, length, progressUpdater, completionListener) -> completeWith(
+                        completionListener,
+                        () -> progressUpdater.accept(length)
+                    ),
                     EsExecutors.DIRECT_EXECUTOR_SERVICE,
                     future
                 );
@@ -1394,8 +1473,8 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
             final var factoryClosed = new AtomicBoolean(false);
             final var dummyStreamFactory = new SourceInputStreamFactory() {
                 @Override
-                public InputStream create(int relativePos) {
-                    return null;
+                public void create(int relativePos, ActionListener<InputStream> listener) {
+                    listener.onResponse(null);
                 }
 
                 @Override
@@ -1420,17 +1499,20 @@ public class SharedBlobCacheServiceTests extends ESTestCase {
                     SourceInputStreamFactory streamFactory,
                     int relativePos,
                     int length,
-                    IntConsumer progressUpdater
+                    IntConsumer progressUpdater,
+                    ActionListener<Void> completion
                 ) throws IOException {
-                    if (invocationCounter.incrementAndGet() == 1) {
-                        final Thread witness = invocationThread.compareAndExchange(null, Thread.currentThread());
-                        assertThat(witness, nullValue());
-                    } else {
-                        assertThat(invocationThread.get(), sameInstance(Thread.currentThread()));
-                    }
-                    assertThat(streamFactory, sameInstance(dummyStreamFactory));
-                    assertThat(position.getAndSet(relativePos), lessThan(relativePos));
-                    progressUpdater.accept(length);
+                    completeWith(completion, () -> {
+                        if (invocationCounter.incrementAndGet() == 1) {
+                            final Thread witness = invocationThread.compareAndExchange(null, Thread.currentThread());
+                            assertThat(witness, nullValue());
+                        } else {
+                            assertThat(invocationThread.get(), sameInstance(Thread.currentThread()));
+                        }
+                        assertThat(streamFactory, sameInstance(dummyStreamFactory));
+                        assertThat(position.getAndSet(relativePos), lessThan(relativePos));
+                        progressUpdater.accept(length);
+                    });
                 }
             };
 

+ 33 - 26
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/input/FrozenIndexInput.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.blobcache.BlobCacheUtils;
 import org.elasticsearch.blobcache.common.ByteBufferReference;
 import org.elasticsearch.blobcache.common.ByteRange;
@@ -146,32 +147,38 @@ public final class FrozenIndexInput extends MetadataCachingIndexInput {
                 final int read = SharedBytes.readCacheFile(channel, pos, relativePos, len, byteBufferReference);
                 stats.addCachedBytesRead(read);
                 return read;
-            }, (channel, channelPos, streamFactory, relativePos, len, progressUpdater) -> {
-                assert streamFactory == null : streamFactory;
-                final long startTimeNanos = stats.currentTimeNanos();
-                try (InputStream input = openInputStreamFromBlobStore(rangeToWrite.start() + relativePos, len)) {
-                    assert ThreadPool.assertCurrentThreadPool(SearchableSnapshots.CACHE_FETCH_ASYNC_THREAD_POOL_NAME);
-                    logger.trace(
-                        "{}: writing channel {} pos {} length {} (details: {})",
-                        fileInfo.physicalName(),
-                        channelPos,
-                        relativePos,
-                        len,
-                        cacheFile
-                    );
-                    SharedBytes.copyToCacheFileAligned(
-                        channel,
-                        input,
-                        channelPos,
-                        relativePos,
-                        len,
-                        progressUpdater,
-                        writeBuffer.get().clear()
-                    );
-                    final long endTimeNanos = stats.currentTimeNanos();
-                    stats.addCachedBytesWritten(len, endTimeNanos - startTimeNanos);
-                }
-            });
+            },
+                (channel, channelPos, streamFactory, relativePos, len, progressUpdater, completionListener) -> ActionListener.completeWith(
+                    completionListener,
+                    () -> {
+                        assert streamFactory == null : streamFactory;
+                        final long startTimeNanos = stats.currentTimeNanos();
+                        try (InputStream input = openInputStreamFromBlobStore(rangeToWrite.start() + relativePos, len)) {
+                            assert ThreadPool.assertCurrentThreadPool(SearchableSnapshots.CACHE_FETCH_ASYNC_THREAD_POOL_NAME);
+                            logger.trace(
+                                "{}: writing channel {} pos {} length {} (details: {})",
+                                fileInfo.physicalName(),
+                                channelPos,
+                                relativePos,
+                                len,
+                                cacheFile
+                            );
+                            SharedBytes.copyToCacheFileAligned(
+                                channel,
+                                input,
+                                channelPos,
+                                relativePos,
+                                len,
+                                progressUpdater,
+                                writeBuffer.get().clear()
+                            );
+                            final long endTimeNanos = stats.currentTimeNanos();
+                            stats.addCachedBytesWritten(len, endTimeNanos - startTimeNanos);
+                            return null;
+                        }
+                    }
+                )
+            );
             assert bytesRead == length : bytesRead + " vs " + length;
             byteBufferReference.finish(bytesRead);
         } finally {