
Revert "Make peer recovery send file chunks async (#44040)"

This reverts commit 30b7545995337f7186c20c699da7c569eea3d751.
Nhat Nguyen, 6 years ago (commit 16eb9ad531)

+ 0 - 210
server/src/main/java/org/elasticsearch/indices/recovery/MultiFileTransfer.java

@@ -1,210 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.indices.recovery;
-
-import org.apache.logging.log4j.Logger;
-import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.elasticsearch.Assertions;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.collect.Tuple;
-import org.elasticsearch.common.util.concurrent.AsyncIOProcessor;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.core.internal.io.IOUtils;
-import org.elasticsearch.index.seqno.LocalCheckpointTracker;
-import org.elasticsearch.index.store.StoreFileMetaData;
-
-import java.io.Closeable;
-import java.io.IOException;
-import java.util.Iterator;
-import java.util.List;
-import java.util.function.Consumer;
-
-import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
-import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
-
-/**
- * File chunks are sent/requested sequentially by at most one thread at any time. However, the sender/requestor does not wait for the
- * response before processing the next file chunk request, which reduces recovery time, especially over secure, compressed, or
- * high-latency connections.
- * <p>
- * The sender/requestor can send up to {@code maxConcurrentFileChunks} file chunk requests without waiting for responses. Since the recovery
- * target can receive file chunks out of order, it has to buffer those file chunks in memory and only flush to disk when there's no gap.
- * To ensure the recovery target never buffers more than {@code maxConcurrentFileChunks} file chunks, we allow the sender/requestor to send
- * only up to {@code maxConcurrentFileChunks} file chunk requests from the last flushed (and acknowledged) file chunk. We leverage the local
- * checkpoint tracker for this purpose. We generate a new sequence number and assign it to each file chunk request before sending; then mark
- * that sequence number as processed when we receive a response for the corresponding file chunk request. With the local checkpoint tracker,
- * we know that the last acknowledged and flushed file chunk is the one whose {@code requestSeqId} equals the local checkpoint, because the
- * recovery target can flush all file chunks up to the local checkpoint.
- * <p>
- * When the number of un-replied file chunk requests reaches the limit (i.e. the gap between the max_seq_no and the local checkpoint is
- * no smaller than {@code maxConcurrentFileChunks}), the sending/requesting thread aborts its execution. The process is then resumed by
- * one of the networking threads that receive and handle the responses of the currently pending file chunk requests. This continues
- * until all chunk requests have been sent and responded to.
- */
-abstract class MultiFileTransfer<Request extends MultiFileTransfer.ChunkRequest> implements Closeable {
-    private Status status = Status.PROCESSING;
-    private final Logger logger;
-    private final ActionListener<Void> listener;
-    private final LocalCheckpointTracker requestSeqIdTracker = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED);
-    private final AsyncIOProcessor<FileChunkResponseItem> processor;
-    private final int maxConcurrentFileChunks;
-    private StoreFileMetaData currentFile = null;
-    private final Iterator<StoreFileMetaData> remainingFiles;
-    private Tuple<StoreFileMetaData, Request> readAheadRequest = null;
-
-    protected MultiFileTransfer(Logger logger, ThreadContext threadContext, ActionListener<Void> listener,
-                                int maxConcurrentFileChunks, List<StoreFileMetaData> files) {
-        this.logger = logger;
-        this.maxConcurrentFileChunks = maxConcurrentFileChunks;
-        this.listener = listener;
-        this.processor = new AsyncIOProcessor<>(logger, maxConcurrentFileChunks, threadContext) {
-            @Override
-            protected void write(List<Tuple<FileChunkResponseItem, Consumer<Exception>>> items) {
-                handleItems(items);
-            }
-        };
-        this.remainingFiles = files.iterator();
-    }
-
-    public final void start() {
-        addItem(UNASSIGNED_SEQ_NO, null, null); // put a dummy item to start the processor
-    }
-
-    private void addItem(long requestSeqId, StoreFileMetaData md, Exception failure) {
-        processor.put(new FileChunkResponseItem(requestSeqId, md, failure), e -> { assert e == null : e; });
-    }
-
-    private void handleItems(List<Tuple<FileChunkResponseItem, Consumer<Exception>>> items) {
-        if (status != Status.PROCESSING) {
-            assert status == Status.FAILED : "must not receive any response after the transfer was completed";
-            // These exceptions will be ignored as we record only the first failure; we log them for debugging purposes.
-            items.stream().filter(item -> item.v1().failure != null).forEach(item ->
-                logger.debug(new ParameterizedMessage("failed to transfer a file chunk request {}", item.v1().md), item.v1().failure));
-            return;
-        }
-        try {
-            for (Tuple<FileChunkResponseItem, Consumer<Exception>> item : items) {
-                final FileChunkResponseItem resp = item.v1();
-                if (resp.requestSeqId == UNASSIGNED_SEQ_NO) {
-                    continue; // not an actual item
-                }
-                requestSeqIdTracker.markSeqNoAsProcessed(resp.requestSeqId);
-                if (resp.failure != null) {
-                    handleError(resp.md, resp.failure);
-                    throw resp.failure;
-                }
-            }
-            while (requestSeqIdTracker.getMaxSeqNo() - requestSeqIdTracker.getProcessedCheckpoint() < maxConcurrentFileChunks) {
-                final Tuple<StoreFileMetaData, Request> request = readAheadRequest != null ? readAheadRequest : getNextRequest();
-                readAheadRequest = null;
-                if (request == null) {
-                    assert currentFile == null && remainingFiles.hasNext() == false;
-                    if (requestSeqIdTracker.getMaxSeqNo() == requestSeqIdTracker.getProcessedCheckpoint()) {
-                        onCompleted(null);
-                    }
-                    return;
-                }
-                final long requestSeqId = requestSeqIdTracker.generateSeqNo();
-                sendChunkRequest(request.v2(), ActionListener.wrap(
-                    r -> addItem(requestSeqId, request.v1(), null),
-                    e -> addItem(requestSeqId, request.v1(), e)));
-            }
-            // While we are waiting for the responses, we can prepare the next request in advance
-            // so we can send it immediately when the responses arrive, reducing the transfer time.
-            if (readAheadRequest == null) {
-                readAheadRequest = getNextRequest();
-            }
-        } catch (Exception e) {
-            onCompleted(e);
-        }
-    }
-
-    private void onCompleted(Exception failure) {
-        if (Assertions.ENABLED && status != Status.PROCESSING) {
-            throw new AssertionError("invalid status: expected [" + Status.PROCESSING + "] actual [" + status + "]", failure);
-        }
-        status = failure == null ? Status.SUCCESS : Status.FAILED;
-        try {
-            IOUtils.close(failure, this);
-        } catch (Exception e) {
-            listener.onFailure(e);
-            return;
-        }
-        listener.onResponse(null);
-    }
-
-    private Tuple<StoreFileMetaData, Request> getNextRequest() throws Exception {
-        try {
-            if (currentFile == null) {
-                if (remainingFiles.hasNext()) {
-                    currentFile = remainingFiles.next();
-                    onNewFile(currentFile);
-                } else {
-                    return null;
-                }
-            }
-            final StoreFileMetaData md = currentFile;
-            final Request request = nextChunkRequest(md);
-            if (request.lastChunk()) {
-                currentFile = null;
-            }
-            return Tuple.tuple(md, request);
-        } catch (Exception e) {
-            handleError(currentFile, e);
-            throw e;
-        }
-    }
-
-    /**
-     * This method is called when we start sending/requesting a new file. Subclasses should override
-     * this method to reset the file offset or close the previous file and open a new file if needed.
-     */
-    protected abstract void onNewFile(StoreFileMetaData md) throws IOException;
-
-    protected abstract Request nextChunkRequest(StoreFileMetaData md) throws IOException;
-
-    protected abstract void sendChunkRequest(Request request, ActionListener<Void> listener);
-
-    protected abstract void handleError(StoreFileMetaData md, Exception e) throws Exception;
-
-    private static class FileChunkResponseItem {
-        final long requestSeqId;
-        final StoreFileMetaData md;
-        final Exception failure;
-
-        FileChunkResponseItem(long requestSeqId, StoreFileMetaData md, Exception failure) {
-            this.requestSeqId = requestSeqId;
-            this.md = md;
-            this.failure = failure;
-        }
-    }
-
-    protected interface ChunkRequest {
-        /**
-         * @return {@code true} if this chunk request is the last chunk of the current file
-         */
-        boolean lastChunk();
-    }
-
-    private enum Status {
-        PROCESSING,
-        SUCCESS,
-        FAILED
-    }
-}
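
The class deleted above implements a sliding-window throttle: every chunk request is tagged with a fresh sequence number from a LocalCheckpointTracker, and sending pauses once the gap between the max seq_no and the processed checkpoint reaches maxConcurrentFileChunks. A minimal, self-contained sketch of that window follows; ChunkWindow, maxInFlight, and markProcessed are illustrative names for this sketch, not the Elasticsearch API:

    import java.util.concurrent.atomic.AtomicLong;

    // Illustrative sliding-window throttle: at most maxInFlight un-acked chunk requests.
    class ChunkWindow {
        private final int maxInFlight;
        private final AtomicLong maxSeqNo = new AtomicLong(-1); // analogue of NO_OPS_PERFORMED
        private long processedCheckpoint = -1;                  // highest fully acknowledged seqNo

        ChunkWindow(int maxInFlight) {
            this.maxInFlight = maxInFlight;
        }

        long generateSeqNo() {
            return maxSeqNo.incrementAndGet(); // one seqNo per chunk request
        }

        synchronized boolean canSend() {
            // mirrors the loop guard: maxSeqNo - processedCheckpoint < maxConcurrentFileChunks
            return maxSeqNo.get() - processedCheckpoint < maxInFlight;
        }

        synchronized void markProcessed(long seqNo) {
            // simplified: assumes in-order acks; the real tracker only advances the
            // checkpoint once every lower seqNo has been marked processed
            processedCheckpoint = Math.max(processedCheckpoint, seqNo);
        }
    }

That gap-aware checkpoint is what lets the recovery target bound its out-of-order buffer to maxConcurrentFileChunks chunks.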

+ 1 - 1
server/src/main/java/org/elasticsearch/indices/recovery/PeerRecoverySourceService.java

@@ -176,7 +176,7 @@ public class PeerRecoverySourceService implements IndexEventListener {
                 final RemoteRecoveryTargetHandler recoveryTarget =
                     new RemoteRecoveryTargetHandler(request.recoveryId(), request.shardId(), transportService,
                         request.targetNode(), recoverySettings, throttleTime -> shard.recoveryStats().addThrottleTime(throttleTime));
-                handler = new RecoverySourceHandler(shard, recoveryTarget, shard.getThreadPool(), request,
+                handler = new RecoverySourceHandler(shard, recoveryTarget, request,
                     Math.toIntExact(recoverySettings.getChunkSize().getBytes()), recoverySettings.getMaxConcurrentFileChunks());
                 return handler;
             }

+ 59 - 87
server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java

@@ -37,7 +37,7 @@ import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.common.CheckedSupplier;
 import org.elasticsearch.common.StopWatch;
 import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.lease.Releasables;
 import org.elasticsearch.common.logging.Loggers;
@@ -49,6 +49,7 @@ import org.elasticsearch.common.util.concurrent.FutureUtils;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.index.engine.Engine;
 import org.elasticsearch.index.engine.RecoveryEngineException;
+import org.elasticsearch.index.seqno.LocalCheckpointTracker;
 import org.elasticsearch.index.seqno.RetentionLeases;
 import org.elasticsearch.index.seqno.SequenceNumbers;
 import org.elasticsearch.index.shard.IndexShard;
@@ -60,12 +61,11 @@ import org.elasticsearch.index.store.StoreFileMetaData;
 import org.elasticsearch.index.translog.Translog;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.RemoteTransportException;
-import org.elasticsearch.transport.Transports;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.io.InputStream;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
@@ -73,10 +73,13 @@ import java.util.Locale;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 import java.util.function.IntSupplier;
 import java.util.stream.StreamSupport;
 
+import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
+
 /**
  * RecoverySourceHandler handles the three phases of shard recovery, which is
  * everything relating to copying the segment files as well as sending translog
@@ -99,15 +102,12 @@ public class RecoverySourceHandler {
     private final int chunkSizeInBytes;
     private final RecoveryTargetHandler recoveryTarget;
     private final int maxConcurrentFileChunks;
-    private final ThreadPool threadPool;
     private final CancellableThreads cancellableThreads = new CancellableThreads();
-    private final List<Closeable> resources = new CopyOnWriteArrayList<>();
 
-    public RecoverySourceHandler(IndexShard shard, RecoveryTargetHandler recoveryTarget, ThreadPool threadPool,
-                                 StartRecoveryRequest request, int fileChunkSizeInBytes, int maxConcurrentFileChunks) {
+    public RecoverySourceHandler(final IndexShard shard, RecoveryTargetHandler recoveryTarget, final StartRecoveryRequest request,
+                                 final int fileChunkSizeInBytes, final int maxConcurrentFileChunks) {
         this.shard = shard;
         this.recoveryTarget = recoveryTarget;
-        this.threadPool = threadPool;
         this.request = request;
         this.shardId = this.request.shardId().id();
         this.logger = Loggers.getLogger(getClass(), request.shardId(), "recover to " + request.targetNode().getName());
@@ -123,6 +123,7 @@ public class RecoverySourceHandler {
      * performs the recovery from the local engine to the target
      */
     public void recoverToTarget(ActionListener<RecoveryResponse> listener) {
+        final List<Closeable> resources = new CopyOnWriteArrayList<>();
         final Closeable releaseResources = () -> IOUtils.close(resources);
         final ActionListener<RecoveryResponse> wrappedListener = ActionListener.notifyOnce(listener);
         try {
@@ -403,17 +404,15 @@ public class RecoverySourceHandler {
                     phase1FileNames.size(), new ByteSizeValue(totalSizeInBytes),
                     phase1ExistingFileNames.size(), new ByteSizeValue(existingTotalSizeInBytes));
                 final StepListener<Void> sendFileInfoStep = new StepListener<>();
-                final StepListener<Void> sendFilesStep = new StepListener<>();
                 final StepListener<Void> cleanFilesStep = new StepListener<>();
                 cancellableThreads.execute(() ->
                     recoveryTarget.receiveFileInfo(phase1FileNames, phase1FileSizes, phase1ExistingFileNames,
                         phase1ExistingFileSizes, translogOps.getAsInt(), sendFileInfoStep));
 
-                sendFileInfoStep.whenComplete(r ->
-                    sendFiles(store, phase1Files.toArray(new StoreFileMetaData[0]), translogOps, sendFilesStep), listener::onFailure);
-
-                sendFilesStep.whenComplete(r ->
-                    cleanFiles(store, recoverySourceMetadata, translogOps, globalCheckpoint, cleanFilesStep), listener::onFailure);
+                sendFileInfoStep.whenComplete(r -> {
+                    sendFiles(store, phase1Files.toArray(new StoreFileMetaData[0]), translogOps);
+                    cleanFiles(store, recoverySourceMetadata, translogOps, globalCheckpoint, cleanFilesStep);
+                }, listener::onFailure);
 
                 final long totalSize = totalSizeInBytes;
                 final long existingTotalSize = existingTotalSizeInBytes;
@@ -572,7 +571,6 @@ public class RecoverySourceHandler {
             final long mappingVersionOnPrimary,
             final ActionListener<Long> listener) throws IOException {
         assert ThreadPool.assertCurrentMethodIsNotCalledRecursively();
-        assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[send translog]");
         final List<Translog.Operation> operations = nextBatch.get();
         // send the leftover operations or if no operations were sent, request the target to respond with its local checkpoint
         if (operations.isEmpty() == false || firstBatch) {
@@ -671,80 +669,54 @@ public class RecoverySourceHandler {
                 '}';
     }
 
-    private static class FileChunk implements MultiFileTransfer.ChunkRequest {
-        final StoreFileMetaData md;
-        final BytesReference content;
-        final long position;
-        final boolean lastChunk;
-
-        FileChunk(StoreFileMetaData md, BytesReference content, long position, boolean lastChunk) {
-            this.md = md;
-            this.content = content;
-            this.position = position;
-            this.lastChunk = lastChunk;
-        }
-
-        @Override
-        public boolean lastChunk() {
-            return lastChunk;
-        }
-    }
-
-    void sendFiles(Store store, StoreFileMetaData[] files, IntSupplier translogOps, ActionListener<Void> listener) {
+    void sendFiles(Store store, StoreFileMetaData[] files, IntSupplier translogOps) throws Exception {
         ArrayUtil.timSort(files, Comparator.comparingLong(StoreFileMetaData::length)); // send smallest first
-
-        final MultiFileTransfer<FileChunk> multiFileSender =
-            new MultiFileTransfer<>(logger, threadPool.getThreadContext(), listener, maxConcurrentFileChunks, Arrays.asList(files)) {
-
-                final byte[] buffer = new byte[chunkSizeInBytes];
-                InputStreamIndexInput currentInput = null;
-                long offset = 0;
-
-                @Override
-                protected void onNewFile(StoreFileMetaData md) throws IOException {
-                    offset = 0;
-                    IOUtils.close(currentInput, () -> currentInput = null);
-                    final IndexInput indexInput = store.directory().openInput(md.name(), IOContext.READONCE);
-                    currentInput = new InputStreamIndexInput(indexInput, md.length()) {
-                        @Override
-                        public void close() throws IOException {
-                            IOUtils.close(indexInput, super::close); // InputStreamIndexInput's close is a noop
-                        }
-                    };
-                }
-
-                @Override
-                protected FileChunk nextChunkRequest(StoreFileMetaData md) throws IOException {
-                    assert Transports.assertNotTransportThread("read file chunk");
+        final LocalCheckpointTracker requestSeqIdTracker = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED);
+        final AtomicReference<Tuple<StoreFileMetaData, Exception>> error = new AtomicReference<>();
+        final byte[] buffer = new byte[chunkSizeInBytes];
+        for (final StoreFileMetaData md : files) {
+            if (error.get() != null) {
+                break;
+            }
+            try (IndexInput indexInput = store.directory().openInput(md.name(), IOContext.READONCE);
+                 InputStream in = new InputStreamIndexInput(indexInput, md.length())) {
+                long position = 0;
+                int bytesRead;
+                while ((bytesRead = in.read(buffer, 0, buffer.length)) != -1) {
+                    final BytesArray content = new BytesArray(buffer, 0, bytesRead);
+                    final boolean lastChunk = position + content.length() == md.length();
+                    final long requestSeqId = requestSeqIdTracker.generateSeqNo();
+                    cancellableThreads.execute(
+                        () -> requestSeqIdTracker.waitForProcessedOpsToComplete(requestSeqId - maxConcurrentFileChunks));
                     cancellableThreads.checkForCancel();
-                    final int bytesRead = currentInput.read(buffer);
-                    if (bytesRead == -1) {
-                        throw new CorruptIndexException("file truncated; length=" + md.length() + " offset=" + offset, md.name());
+                    if (error.get() != null) {
+                        break;
                     }
-                    final boolean lastChunk = offset + bytesRead == md.length();
-                    final FileChunk chunk = new FileChunk(md, new BytesArray(buffer, 0, bytesRead), offset, lastChunk);
-                    offset += bytesRead;
-                    return chunk;
-                }
-
-                @Override
-                protected void sendChunkRequest(FileChunk request, ActionListener<Void> listener) {
-                    cancellableThreads.execute(() -> recoveryTarget.writeFileChunk(
-                        request.md, request.position, request.content, request.lastChunk, translogOps.getAsInt(), listener));
-                }
-
-                @Override
-                protected void handleError(StoreFileMetaData md, Exception e) throws Exception {
-                    handleErrorOnSendFiles(store, e, new StoreFileMetaData[]{md});
-                }
-
-                @Override
-                public void close() throws IOException {
-                    IOUtils.close(currentInput, () -> currentInput = null);
+                    final long requestFilePosition = position;
+                    cancellableThreads.executeIO(() ->
+                        recoveryTarget.writeFileChunk(md, requestFilePosition, content, lastChunk, translogOps.getAsInt(),
+                            ActionListener.wrap(
+                                r -> requestSeqIdTracker.markSeqNoAsProcessed(requestSeqId),
+                                e -> {
+                                    error.compareAndSet(null, Tuple.tuple(md, e));
+                                    requestSeqIdTracker.markSeqNoAsProcessed(requestSeqId);
+                                }
+                            )));
+                    position += content.length();
                 }
-            };
-        resources.add(multiFileSender);
-        multiFileSender.start();
+            } catch (Exception e) {
+                error.compareAndSet(null, Tuple.tuple(md, e));
+                break;
+            }
+        }
+        // When we terminate exceptionally, we don't wait for the outstanding requests as we don't use their results anyway.
+        // This allows us to end quickly and eliminate the complexity of handling requestSeqIds in case of error.
+        if (error.get() == null) {
+            cancellableThreads.execute(() -> requestSeqIdTracker.waitForProcessedOpsToComplete(requestSeqIdTracker.getMaxSeqNo()));
+        }
+        if (error.get() != null) {
+            handleErrorOnSendFiles(store, error.get().v2(), new StoreFileMetaData[]{error.get().v1()});
+        }
     }
 
     private void cleanFiles(Store store, Store.MetadataSnapshot sourceMetadata, IntSupplier translogOps,
@@ -768,7 +740,6 @@ public class RecoverySourceHandler {
 
     private void handleErrorOnSendFiles(Store store, Exception e, StoreFileMetaData[] mds) throws Exception {
         final IOException corruptIndexException = ExceptionsHelper.unwrapCorruption(e);
-        assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[handle error on send/clean files]");
         if (corruptIndexException != null) {
             Exception localException = null;
             for (StoreFileMetaData md : mds) {
@@ -792,8 +763,9 @@ public class RecoverySourceHandler {
                     shardId, request.targetNode(), mds), corruptIndexException);
                 throw remoteException;
             }
+        } else {
+            throw e;
         }
-        throw e;
     }
 
     protected void failEngine(IOException cause) {
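
The restored synchronous sendFiles enforces the same bound without an AsyncIOProcessor: before dispatching the chunk carrying requestSeqId N, it blocks in waitForProcessedOpsToComplete(N - maxConcurrentFileChunks). With a window of 5, for example, seqIds 0 through 4 go out freely, and the sender then waits for seqId 0 before sending seqId 5. A Semaphore shows the same shape in isolation; the sketch below is an analogue under that substitution (ThrottledSender and deliver are made-up names), not the Elasticsearch code:

    import java.util.List;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.Semaphore;

    // Semaphore analogue of the seqId-based throttle in the synchronous sendFiles loop.
    class ThrottledSender {
        private final int maxConcurrentFileChunks;
        private final Semaphore window;

        ThrottledSender(int maxConcurrentFileChunks) {
            this.maxConcurrentFileChunks = maxConcurrentFileChunks;
            this.window = new Semaphore(maxConcurrentFileChunks);
        }

        void sendAll(List<byte[]> chunks) throws InterruptedException {
            for (byte[] chunk : chunks) {
                window.acquire(); // blocks once maxConcurrentFileChunks requests are un-acked
                CompletableFuture
                    .runAsync(() -> deliver(chunk))            // stand-in for writeFileChunk
                    .whenComplete((r, e) -> window.release()); // the ack frees a window slot
            }
            // drain, like the final waitForProcessedOpsToComplete(getMaxSeqNo()) in sendFiles
            window.acquire(maxConcurrentFileChunks);
        }

        private void deliver(byte[] chunk) {
            // network call elided in this sketch
        }
    }

One difference worth noting: a semaphore frees a slot on any acknowledgement, while the seqId tracker anchors the window at the lowest un-acknowledged request, which is the property the recovery target relies on to cap its buffer.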

+ 1 - 1
server/src/main/java/org/elasticsearch/indices/recovery/RemoteRecoveryTargetHandler.java

@@ -181,7 +181,7 @@ public class RemoteRecoveryTargetHandler implements RecoveryTargetHandler {
                 * would be in to restart file copy again (new deltas) if too many translog ops are piling up.
                  */
                 throttleTimeInNanos), fileChunkRequestOptions, new ActionListenerResponseHandler<>(
-                    ActionListener.map(listener, r -> null), in -> TransportResponse.Empty.INSTANCE, ThreadPool.Names.GENERIC));
+                    ActionListener.map(listener, r -> null), in -> TransportResponse.Empty.INSTANCE));
     }
 
 }

+ 119 - 84
server/src/test/java/org/elasticsearch/indices/recovery/RecoverySourceHandlerTests.java

@@ -32,11 +32,11 @@ import org.apache.lucene.index.Term;
 import org.apache.lucene.store.BaseDirectoryWrapper;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.IOContext;
-import org.apache.lucene.util.SetOnce;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefIterator;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.LatchedActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -77,7 +77,6 @@ import org.elasticsearch.test.CorruptionUtils;
 import org.elasticsearch.test.DummyShardLock;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.IndexSettingsModule;
-import org.elasticsearch.threadpool.FixedExecutorBuilder;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.junit.After;
@@ -94,11 +93,10 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.IntSupplier;
 import java.util.zip.CRC32;
 
@@ -107,7 +105,7 @@ import static java.util.Collections.emptySet;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
-import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.core.IsNull.notNullValue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.anyObject;
@@ -123,19 +121,10 @@ public class RecoverySourceHandlerTests extends ESTestCase {
     private final ClusterSettings service = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
 
     private ThreadPool threadPool;
-    private Executor recoveryExecutor;
 
     @Before
     public void setUpThreadPool() {
-        if (randomBoolean()) {
-            threadPool = new TestThreadPool(getTestName());
-            recoveryExecutor = threadPool.generic();
-        } else {
-            // verify that both sending and receiving files can be completed with a single thread
-            threadPool = new TestThreadPool(getTestName(),
-                new FixedExecutorBuilder(Settings.EMPTY, "recovery_executor", between(1, 16), between(16, 128), "recovery_executor"));
-            recoveryExecutor = threadPool.executor("recovery_executor");
-        }
+        threadPool = new TestThreadPool(getTestName());
     }
 
     @After
@@ -144,7 +133,9 @@ public class RecoverySourceHandlerTests extends ESTestCase {
     }
 
     public void testSendFiles() throws Throwable {
-        final RecoverySettings recoverySettings = new RecoverySettings(Settings.EMPTY, service);
+        Settings settings = Settings.builder().put("indices.recovery.concurrent_streams", 1).
+            put("indices.recovery.concurrent_small_file_streams", 1).build();
+        final RecoverySettings recoverySettings = new RecoverySettings(settings, service);
         final StartRecoveryRequest request = getStartRecoveryRequest();
         Store store = newStore(createTempDir());
         Directory dir = store.directory();
@@ -165,22 +156,38 @@ public class RecoverySourceHandlerTests extends ESTestCase {
             metas.add(md);
         }
         Store targetStore = newStore(createTempDir());
-        MultiFileWriter multiFileWriter = new MultiFileWriter(targetStore, mock(RecoveryState.Index.class), "", logger, () -> {});
         RecoveryTargetHandler target = new TestRecoveryTargetHandler() {
+            IndexOutputOutputStream out;
             @Override
             public void writeFileChunk(StoreFileMetaData md, long position, BytesReference content, boolean lastChunk,
                                        int totalTranslogOps, ActionListener<Void> listener) {
-                ActionListener.completeWith(listener, () -> {
-                    multiFileWriter.writeFileChunk(md, position, content, lastChunk);
-                    return null;
-                });
+                try {
+                    if (position == 0) {
+                        out = new IndexOutputOutputStream(targetStore.createVerifyingOutput(md.name(), md, IOContext.DEFAULT)) {
+                            @Override
+                            public void close() throws IOException {
+                                super.close();
+                                targetStore.directory().sync(Collections.singleton(md.name())); // sync otherwise MDW will mess with it
+                            }
+                        };
+                    }
+                    final BytesRefIterator iterator = content.iterator();
+                    BytesRef scratch;
+                    while ((scratch = iterator.next()) != null) {
+                        out.write(scratch.bytes, scratch.offset, scratch.length);
+                    }
+                    if (lastChunk) {
+                        out.close();
+                    }
+                    listener.onResponse(null);
+                } catch (Exception e) {
+                    listener.onFailure(e);
+                }
             }
         };
-        RecoverySourceHandler handler = new RecoverySourceHandler(null, new AsyncRecoveryTarget(target, recoveryExecutor),
-            threadPool, request, Math.toIntExact(recoverySettings.getChunkSize().getBytes()), between(1, 5));
-        PlainActionFuture<Void> sendFilesFuture = new PlainActionFuture<>();
-        handler.sendFiles(store, metas.toArray(new StoreFileMetaData[0]), () -> 0, sendFilesFuture);
-        sendFilesFuture.actionGet();
+        RecoverySourceHandler handler = new RecoverySourceHandler(null, target, request,
+            Math.toIntExact(recoverySettings.getChunkSize().getBytes()), between(1, 5));
+        handler.sendFiles(store, metas.toArray(new StoreFileMetaData[0]), () -> 0);
         Store.MetadataSnapshot targetStoreMetadata = targetStore.getMetadata(null);
         Store.RecoveryDiff recoveryDiff = targetStoreMetadata.recoveryDiff(metadata);
         assertEquals(metas.size(), recoveryDiff.identical.size());
@@ -188,7 +195,7 @@ public class RecoverySourceHandlerTests extends ESTestCase {
         assertEquals(0, recoveryDiff.missing.size());
         IndexReader reader = DirectoryReader.open(targetStore.directory());
         assertEquals(numDocs, reader.maxDoc());
-        IOUtils.close(reader, store, multiFileWriter, targetStore);
+        IOUtils.close(reader, store, targetStore);
     }
 
     public StartRecoveryRequest getStartRecoveryRequest() throws IOException {
@@ -234,11 +241,10 @@ public class RecoverySourceHandlerTests extends ESTestCase {
                                                 RetentionLeases retentionLeases, long mappingVersion, ActionListener<Long> listener) {
                 shippedOps.addAll(operations);
                 checkpointOnTarget.set(randomLongBetween(checkpointOnTarget.get(), Long.MAX_VALUE));
-                listener.onResponse(checkpointOnTarget.get());
-            }
+                listener.onResponse(checkpointOnTarget.get());            }
         };
-        RecoverySourceHandler handler = new RecoverySourceHandler(shard, new AsyncRecoveryTarget(recoveryTarget, threadPool.generic()),
-            threadPool, request, fileChunkSizeInBytes, between(1, 10));
+        RecoverySourceHandler handler = new RecoverySourceHandler(
+            shard, new AsyncRecoveryTarget(recoveryTarget, threadPool.generic()), request, fileChunkSizeInBytes, between(1, 10));
         PlainActionFuture<RecoverySourceHandler.SendSnapshotResult> future = new PlainActionFuture<>();
         handler.phase2(startingSeqNo, endingSeqNo, newTranslogSnapshot(operations, Collections.emptyList()),
             randomNonNegativeLong(), randomNonNegativeLong(), RetentionLeases.EMPTY, randomNonNegativeLong(), future);
@@ -277,8 +283,8 @@ public class RecoverySourceHandlerTests extends ESTestCase {
                 }
             }
         };
-        RecoverySourceHandler handler = new RecoverySourceHandler(shard, new AsyncRecoveryTarget(recoveryTarget, threadPool.generic()),
-            threadPool, request, fileChunkSizeInBytes, between(1, 10));
+        RecoverySourceHandler handler = new RecoverySourceHandler(
+            shard, new AsyncRecoveryTarget(recoveryTarget, threadPool.generic()), request, fileChunkSizeInBytes, between(1, 10));
         PlainActionFuture<RecoverySourceHandler.SendSnapshotResult> future = new PlainActionFuture<>();
         final long startingSeqNo = randomLongBetween(0, ops.size() - 1L);
         final long endingSeqNo = randomLongBetween(startingSeqNo, ops.size() - 1L);
@@ -337,36 +343,52 @@ public class RecoverySourceHandlerTests extends ESTestCase {
             (p.getFileName().toString().equals("write.lock") ||
                 p.getFileName().toString().startsWith("extra")) == false));
         Store targetStore = newStore(createTempDir(), false);
-        MultiFileWriter multiFileWriter = new MultiFileWriter(targetStore, mock(RecoveryState.Index.class), "", logger, () -> {});
         RecoveryTargetHandler target = new TestRecoveryTargetHandler() {
+            IndexOutputOutputStream out;
              @Override
             public void writeFileChunk(StoreFileMetaData md, long position, BytesReference content, boolean lastChunk,
                                        int totalTranslogOps, ActionListener<Void> listener) {
-                 ActionListener.completeWith(listener, () -> {
-                     multiFileWriter.writeFileChunk(md, position, content, lastChunk);
-                     return null;
-                 });
+                try {
+                    if (position == 0) {
+                        out = new IndexOutputOutputStream(targetStore.createVerifyingOutput(md.name(), md, IOContext.DEFAULT)) {
+                            @Override
+                            public void close() throws IOException {
+                                super.close();
+                                targetStore.directory().sync(Collections.singleton(md.name())); // sync otherwise MDW will mess with it
+                            }
+                        };
+                    }
+                    final BytesRefIterator iterator = content.iterator();
+                    BytesRef scratch;
+                    while ((scratch = iterator.next()) != null) {
+                        out.write(scratch.bytes, scratch.offset, scratch.length);
+                    }
+                    if (lastChunk) {
+                        out.close();
+                    }
+                    listener.onResponse(null);
+                } catch (Exception e) {
+                    IOUtils.closeWhileHandlingException(out, () -> listener.onFailure(e));
+                }
             }
         };
-        RecoverySourceHandler handler = new RecoverySourceHandler(null, new AsyncRecoveryTarget(target, recoveryExecutor), threadPool,
-            request, Math.toIntExact(recoverySettings.getChunkSize().getBytes()), between(1, 8)) {
+        RecoverySourceHandler handler = new RecoverySourceHandler(null, target, request,
+            Math.toIntExact(recoverySettings.getChunkSize().getBytes()), between(1, 8)) {
             @Override
             protected void failEngine(IOException cause) {
                 assertFalse(failedEngine.get());
                 failedEngine.set(true);
             }
         };
-        SetOnce<Exception> sendFilesError = new SetOnce<>();
-        CountDownLatch latch = new CountDownLatch(1);
-        handler.sendFiles(store, metas.toArray(new StoreFileMetaData[0]), () -> 0,
-            new LatchedActionListener<>(ActionListener.wrap(r -> sendFilesError.set(null), e -> sendFilesError.set(e)), latch));
-        latch.await();
-        assertThat(sendFilesError.get(), instanceOf(IOException.class));
-        assertNotNull(ExceptionsHelper.unwrapCorruption(sendFilesError.get()));
+
+        try {
+            handler.sendFiles(store, metas.toArray(new StoreFileMetaData[0]), () -> 0);
+            fail("corrupted index");
+        } catch (IOException ex) {
+            assertNotNull(ExceptionsHelper.unwrapCorruption(ex));
+        }
         assertTrue(failedEngine.get());
-        // ensure all chunk requests have been completed; otherwise some files on the target are left open.
-        IOUtils.close(() -> terminate(threadPool), () -> threadPool = null);
-        IOUtils.close(store, multiFileWriter, targetStore);
+        IOUtils.close(store, targetStore);
     }
 
 
@@ -405,24 +427,28 @@ public class RecoverySourceHandlerTests extends ESTestCase {
                 }
             }
         };
-        RecoverySourceHandler handler = new RecoverySourceHandler(null, new AsyncRecoveryTarget(target, recoveryExecutor), threadPool,
-            request, Math.toIntExact(recoverySettings.getChunkSize().getBytes()), between(1, 10)) {
+        RecoverySourceHandler handler = new RecoverySourceHandler(null, target, request,
+            Math.toIntExact(recoverySettings.getChunkSize().getBytes()), between(1, 10)) {
             @Override
             protected void failEngine(IOException cause) {
                 assertFalse(failedEngine.get());
                 failedEngine.set(true);
             }
         };
-        PlainActionFuture<Void> sendFilesFuture = new PlainActionFuture<>();
-        handler.sendFiles(store, metas.toArray(new StoreFileMetaData[0]), () -> 0, sendFilesFuture);
-        Exception ex = expectThrows(Exception.class, sendFilesFuture::actionGet);
-        final IOException unwrappedCorruption = ExceptionsHelper.unwrapCorruption(ex);
-        if (throwCorruptedIndexException) {
-            assertNotNull(unwrappedCorruption);
-            assertEquals(ex.getMessage(), "[File corruption occurred on recovery but checksums are ok]");
-        } else {
-            assertNull(unwrappedCorruption);
-            assertEquals(ex.getMessage(), "boom");
+        try {
+            handler.sendFiles(store, metas.toArray(new StoreFileMetaData[0]), () -> 0);
+            fail("exception index");
+        } catch (RuntimeException ex) {
+            final IOException unwrappedCorruption = ExceptionsHelper.unwrapCorruption(ex);
+            if (throwCorruptedIndexException) {
+                assertNotNull(unwrappedCorruption);
+                assertEquals(ex.getMessage(), "[File corruption occurred on recovery but checksums are ok]");
+            } else {
+                assertNull(unwrappedCorruption);
+                assertEquals(ex.getMessage(), "boom");
+            }
+        } catch (CorruptIndexException ex) {
+            fail("not expected here");
         }
         assertFalse(failedEngine.get());
         IOUtils.close(store);
@@ -446,7 +472,6 @@ public class RecoverySourceHandlerTests extends ESTestCase {
         final RecoverySourceHandler handler = new RecoverySourceHandler(
                 shard,
                 mock(RecoveryTargetHandler.class),
-                threadPool,
                 request,
                 Math.toIntExact(recoverySettings.getChunkSize().getBytes()),
                 between(1, 8)) {
@@ -525,13 +550,19 @@ public class RecoverySourceHandlerTests extends ESTestCase {
         };
         final int maxConcurrentChunks = between(1, 8);
         final int chunkSize = between(1, 32);
-        final RecoverySourceHandler handler = new RecoverySourceHandler(shard, recoveryTarget, threadPool, getStartRecoveryRequest(),
+        final RecoverySourceHandler handler = new RecoverySourceHandler(shard, recoveryTarget, getStartRecoveryRequest(),
             chunkSize, maxConcurrentChunks);
         Store store = newStore(createTempDir(), false);
         List<StoreFileMetaData> files = generateFiles(store, between(1, 10), () -> between(1, chunkSize * 20));
         int totalChunks = files.stream().mapToInt(md -> ((int) md.length() + chunkSize - 1) / chunkSize).sum();
-        PlainActionFuture<Void> sendFilesFuture = new PlainActionFuture<>();
-        handler.sendFiles(store, files.toArray(new StoreFileMetaData[0]), () -> 0, sendFilesFuture);
+        Thread sender = new Thread(() -> {
+            try {
+                handler.sendFiles(store, files.toArray(new StoreFileMetaData[0]), () -> 0);
+            } catch (Exception ex) {
+                throw new AssertionError(ex);
+            }
+        });
+        sender.start();
         assertBusy(() -> {
             assertThat(sentChunks.get(), equalTo(Math.min(totalChunks, maxConcurrentChunks)));
             assertThat(unrepliedChunks, hasSize(sentChunks.get()));
@@ -563,11 +594,13 @@ public class RecoverySourceHandlerTests extends ESTestCase {
                 assertThat(unrepliedChunks, hasSize(expectedUnrepliedChunks));
             });
         }
-        sendFilesFuture.actionGet();
+        sender.join();
         store.close();
     }
 
     public void testSendFileChunksStopOnError() throws Exception {
+        final IndexShard shard = mock(IndexShard.class);
+        when(shard.state()).thenReturn(IndexShardState.STARTED);
         final List<FileChunkResponse> unrepliedChunks = new CopyOnWriteArrayList<>();
         final AtomicInteger sentChunks = new AtomicInteger();
         final TestRecoveryTargetHandler recoveryTarget = new TestRecoveryTargetHandler() {
@@ -583,23 +616,23 @@ public class RecoverySourceHandlerTests extends ESTestCase {
         };
         final int maxConcurrentChunks = between(1, 4);
         final int chunkSize = between(1, 16);
-        final RecoverySourceHandler handler = new RecoverySourceHandler(null, new AsyncRecoveryTarget(recoveryTarget, recoveryExecutor),
-            threadPool, getStartRecoveryRequest(), chunkSize, maxConcurrentChunks);
+        final RecoverySourceHandler handler = new RecoverySourceHandler(shard, recoveryTarget, getStartRecoveryRequest(),
+            chunkSize, maxConcurrentChunks);
         Store store = newStore(createTempDir(), false);
         List<StoreFileMetaData> files = generateFiles(store, between(1, 10), () -> between(1, chunkSize * 20));
         int totalChunks = files.stream().mapToInt(md -> ((int) md.length() + chunkSize - 1) / chunkSize).sum();
-        SetOnce<Exception> sendFilesError = new SetOnce<>();
-        CountDownLatch sendFilesLatch = new CountDownLatch(1);
-        handler.sendFiles(store, files.toArray(new StoreFileMetaData[0]), () -> 0,
-            new LatchedActionListener<>(ActionListener.wrap(r -> sendFilesError.set(null), e -> sendFilesError.set(e)), sendFilesLatch));
+        AtomicReference<Exception> error = new AtomicReference<>();
+        Thread sender = new Thread(() -> {
+            try {
+                handler.sendFiles(store, files.toArray(new StoreFileMetaData[0]), () -> 0);
+            } catch (Exception ex) {
+                error.set(ex);
+            }
+        });
+        sender.start();
         assertBusy(() -> assertThat(sentChunks.get(), equalTo(Math.min(totalChunks, maxConcurrentChunks))));
         List<FileChunkResponse> failedChunks = randomSubsetOf(between(1, unrepliedChunks.size()), unrepliedChunks);
-        CountDownLatch replyLatch = new CountDownLatch(failedChunks.size());
-        failedChunks.forEach(c -> {
-            c.listener.onFailure(new IllegalStateException("test chunk exception"));
-            replyLatch.countDown();
-        });
-        replyLatch.await();
+        failedChunks.forEach(c -> c.listener.onFailure(new RuntimeException("test chunk exception")));
         unrepliedChunks.removeAll(failedChunks);
         unrepliedChunks.forEach(c -> {
             if (randomBoolean()) {
@@ -608,10 +641,12 @@ public class RecoverySourceHandlerTests extends ESTestCase {
                 c.listener.onResponse(null);
             }
         });
-        sendFilesLatch.await();
-        assertThat(sendFilesError.get(), instanceOf(IllegalStateException.class));
-        assertThat(sendFilesError.get().getMessage(), containsString("test chunk exception"));
+        assertBusy(() -> {
+            assertThat(error.get(), notNullValue());
+            assertThat(error.get().getMessage(), containsString("test chunk exception"));
+        });
         assertThat("no more chunks should be sent", sentChunks.get(), equalTo(Math.min(totalChunks, maxConcurrentChunks)));
+        sender.join();
         store.close();
     }
 
@@ -619,7 +654,7 @@ public class RecoverySourceHandlerTests extends ESTestCase {
         IndexShard shard = mock(IndexShard.class);
         when(shard.state()).thenReturn(IndexShardState.STARTED);
         RecoverySourceHandler handler = new RecoverySourceHandler(
-            shard, new TestRecoveryTargetHandler(), threadPool, getStartRecoveryRequest(), between(1, 16), between(1, 4));
+            shard, new TestRecoveryTargetHandler(), getStartRecoveryRequest(), between(1, 16), between(1, 4));
 
         String syncId = UUIDs.randomBase64UUID();
         int numDocs = between(0, 1000);

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java

@@ -635,7 +635,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
         final StartRecoveryRequest request = new StartRecoveryRequest(replica.shardId(), targetAllocationId,
             pNode, rNode, snapshot, replica.routingEntry().primary(), 0, startingSeqNo);
         final RecoverySourceHandler recovery = new RecoverySourceHandler(primary,
-            new AsyncRecoveryTarget(recoveryTarget, threadPool.generic()), threadPool,
+            new AsyncRecoveryTarget(recoveryTarget, threadPool.generic()),
             request, Math.toIntExact(ByteSizeUnit.MB.toBytes(1)), between(1, 8));
         primary.updateShardState(primary.routingEntry(), primary.getPendingPrimaryTerm(), null,
             currentClusterStateVersion.incrementAndGet(), inSyncIds, routingTable);

+ 1 - 0
test/framework/src/main/java/org/elasticsearch/indices/recovery/AsyncRecoveryTarget.java

@@ -83,6 +83,7 @@ public class AsyncRecoveryTarget implements RecoveryTargetHandler {
     @Override
     public void writeFileChunk(StoreFileMetaData fileMetaData, long position, BytesReference content,
                                boolean lastChunk, int totalTranslogOps, ActionListener<Void> listener) {
+        // TODO: remove this clone once we send file chunks async
         final BytesReference copy = new BytesArray(BytesRef.deepCopyOf(content.toBytesRef()));
         executor.execute(() -> target.writeFileChunk(fileMetaData, position, copy, lastChunk, totalTranslogOps, listener));
     }
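
The deep copy above matters because the synchronous sendFiles reuses a single byte[] buffer for every chunk: once writeFileChunk is bounced to another executor, the sender may overwrite the buffer with the next chunk's bytes before the task runs. A minimal stand-alone illustration of that aliasing hazard (hypothetical example, not recovery code):

    import java.util.Arrays;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    public class BufferAliasing {
        public static void main(String[] args) {
            ExecutorService executor = Executors.newSingleThreadExecutor();
            byte[] buffer = new byte[]{1};
            // Without a deep copy, the task reads whatever the buffer holds when it runs...
            executor.execute(() -> System.out.println(Arrays.toString(buffer)));
            buffer[0] = 2; // ...and the sender has already reused it for the next chunk
            executor.shutdown(); // may print [2] instead of [1]
        }
    }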