Pārlūkot izejas kodu

Unify nio read / write channel contexts (#28160)

This commit is related to #27260. Right now we have separate read and
write contexts for implementing specific protocol logic. However, some
protocols require a closer relationship between read and write
operations than is allowed by our current model. An example is HTTP
which might require a write if some problem with request parsing was
encountered.

Additionally, some protocols require close messages to be sent when a
channel is shutdown. This is also problematic in our current model,
where we assume that channels should simply be queued for close and
forgotten.

This commit transitions to a single ChannelContext which implements
all read, write, and close logic for protocols. It is the job of the
context to tell the selector when to close the channel. A channel can
still be manually queued for close with a selector. This is how server
channels are closed for now. And this route allows timeout mechanisms on
normal channel closes to be implemented.
Tim Brooks 7 gadi atpakaļ
vecāks
revīzija
4ea9ddb7d3
30 mainītis faili ar 1025 papildinājumiem un 890 dzēšanām
  1. 1 21
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java
  2. 169 0
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java
  3. 0 64
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesReadContext.java
  4. 0 111
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteContext.java
  5. 88 0
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java
  6. 81 0
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java
  7. 1 2
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java
  8. 5 0
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ESSelector.java
  9. 0 2
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java
  10. 26 38
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java
  11. 0 35
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadContext.java
  12. 54 1
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java
  13. 28 13
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java
  14. 8 6
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java
  15. 0 37
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteContext.java
  16. 8 66
      libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java
  17. 1 1
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java
  18. 337 0
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java
  19. 0 142
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesReadContextTests.java
  20. 0 212
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteContextTests.java
  21. 1 1
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java
  22. 1 1
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java
  23. 43 4
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java
  24. 69 50
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java
  25. 27 21
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java
  26. 24 37
      libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java
  27. 24 14
      plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java
  28. 6 1
      plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java
  29. 7 2
      plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java
  30. 16 8
      test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java

+ 1 - 21
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/AbstractNioChannel.java

@@ -26,7 +26,6 @@ import java.nio.channels.NetworkChannel;
 import java.nio.channels.SelectableChannel;
 import java.nio.channels.SelectionKey;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 
 /**
@@ -48,9 +47,6 @@ import java.util.function.BiConsumer;
 public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkChannel> implements NioChannel {
 
     final S socketChannel;
-    // This indicates if the channel has been scheduled to be closed. Read the closeFuture to determine if
-    // the channel close process has completed.
-    final AtomicBoolean isClosing = new AtomicBoolean(false);
 
     private final InetSocketAddress localAddress;
     private final CompletableFuture<Void> closeContext = new CompletableFuture<>();
@@ -73,21 +69,6 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
         return localAddress;
     }
 
-    /**
-     * Schedules a channel to be closed by the selector event loop with which it is registered.
-     * <p>
-     * If the channel is open and the state can be transitioned to closed, the close operation will
-     * be scheduled with the event loop.
-     * <p>
-     * If the channel is already set to closed, it is assumed that it is already scheduled to be closed.
-     */
-    @Override
-    public void close() {
-        if (isClosing.compareAndSet(false, true)) {
-            selector.queueChannelClose(this);
-        }
-    }
-
     /**
      * Closes the channel synchronously. This method should only be called from the selector thread.
      * <p>
@@ -95,8 +76,7 @@ public abstract class AbstractNioChannel<S extends SelectableChannel & NetworkCh
      */
     @Override
     public void closeFromSelector() throws IOException {
-        assert selector.isOnCurrentThread() : "Should only call from selector thread";
-        isClosing.set(true);
+        selector.assertOnSelectorThread();
         if (closeContext.isDone() == false) {
             try {
                 closeRawChannel();

+ 169 - 0
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java

@@ -0,0 +1,169 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.nio;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.util.LinkedList;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.BiConsumer;
+
+public class BytesChannelContext implements ChannelContext {
+
+    private final NioSocketChannel channel;
+    private final ReadConsumer readConsumer;
+    private final InboundChannelBuffer channelBuffer;
+    private final LinkedList<BytesWriteOperation> queued = new LinkedList<>();
+    private final AtomicBoolean isClosing = new AtomicBoolean(false);
+    private boolean peerClosed = false;
+    private boolean ioException = false;
+
+    public BytesChannelContext(NioSocketChannel channel, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) {
+        this.channel = channel;
+        this.readConsumer = readConsumer;
+        this.channelBuffer = channelBuffer;
+    }
+
+    @Override
+    public void channelRegistered() throws IOException {}
+
+    @Override
+    public int read() throws IOException {
+        if (channelBuffer.getRemaining() == 0) {
+            // Requiring one additional byte will ensure that a new page is allocated.
+            channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
+        }
+
+        int bytesRead;
+        try {
+            bytesRead = channel.read(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex()));
+        } catch (IOException ex) {
+            ioException = true;
+            throw ex;
+        }
+
+        if (bytesRead == -1) {
+            peerClosed = true;
+            return 0;
+        }
+
+        channelBuffer.incrementIndex(bytesRead);
+
+        int bytesConsumed = Integer.MAX_VALUE;
+        while (bytesConsumed > 0 && channelBuffer.getIndex() > 0) {
+            bytesConsumed = readConsumer.consumeReads(channelBuffer);
+            channelBuffer.release(bytesConsumed);
+        }
+
+        return bytesRead;
+    }
+
+    @Override
+    public void sendMessage(ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener) {
+        if (isClosing.get()) {
+            listener.accept(null, new ClosedChannelException());
+            return;
+        }
+
+        BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
+        SocketSelector selector = channel.getSelector();
+        if (selector.isOnCurrentThread() == false) {
+            selector.queueWrite(writeOperation);
+            return;
+        }
+
+        // TODO: Eval if we will allow writes from sendMessage
+        selector.queueWriteInChannelBuffer(writeOperation);
+    }
+
+    @Override
+    public void queueWriteOperation(WriteOperation writeOperation) {
+        channel.getSelector().assertOnSelectorThread();
+        queued.add((BytesWriteOperation) writeOperation);
+    }
+
+    @Override
+    public void flushChannel() throws IOException {
+        channel.getSelector().assertOnSelectorThread();
+        int ops = queued.size();
+        if (ops == 1) {
+            singleFlush(queued.pop());
+        } else if (ops > 1) {
+            multiFlush();
+        }
+    }
+
+    @Override
+    public boolean hasQueuedWriteOps() {
+        channel.getSelector().assertOnSelectorThread();
+        return queued.isEmpty() == false;
+    }
+
+    @Override
+    public void closeChannel() {
+        if (isClosing.compareAndSet(false, true)) {
+            channel.getSelector().queueChannelClose(channel);
+        }
+    }
+
+    @Override
+    public boolean selectorShouldClose() {
+        return peerClosed || ioException || isClosing.get();
+    }
+
+    @Override
+    public void closeFromSelector() {
+        channel.getSelector().assertOnSelectorThread();
+        // Set to true in order to reject new writes before queuing with selector
+        isClosing.set(true);
+        channelBuffer.close();
+        for (BytesWriteOperation op : queued) {
+            channel.getSelector().executeFailedListener(op.getListener(), new ClosedChannelException());
+        }
+        queued.clear();
+    }
+
+    private void singleFlush(BytesWriteOperation headOp) throws IOException {
+        try {
+            int written = channel.write(headOp.getBuffersToWrite());
+            headOp.incrementIndex(written);
+        } catch (IOException e) {
+            channel.getSelector().executeFailedListener(headOp.getListener(), e);
+            ioException = true;
+            throw e;
+        }
+
+        if (headOp.isFullyFlushed()) {
+            channel.getSelector().executeListener(headOp.getListener(), null);
+        } else {
+            queued.push(headOp);
+        }
+    }
+
+    private void multiFlush() throws IOException {
+        boolean lastOpCompleted = true;
+        while (lastOpCompleted && queued.isEmpty() == false) {
+            BytesWriteOperation op = queued.pop();
+            singleFlush(op);
+            lastOpCompleted = op.isFullyFlushed();
+        }
+    }
+}

+ 0 - 64
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesReadContext.java

@@ -1,64 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.nio;
-
-import java.io.IOException;
-
-public class BytesReadContext implements ReadContext {
-
-    private final NioSocketChannel channel;
-    private final ReadConsumer readConsumer;
-    private final InboundChannelBuffer channelBuffer;
-
-    public BytesReadContext(NioSocketChannel channel, ReadConsumer readConsumer, InboundChannelBuffer channelBuffer) {
-        this.channel = channel;
-        this.channelBuffer = channelBuffer;
-        this.readConsumer = readConsumer;
-    }
-
-    @Override
-    public int read() throws IOException {
-        if (channelBuffer.getRemaining() == 0) {
-            // Requiring one additional byte will ensure that a new page is allocated.
-            channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
-        }
-
-        int bytesRead = channel.read(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex()));
-
-        if (bytesRead == -1) {
-            return bytesRead;
-        }
-
-        channelBuffer.incrementIndex(bytesRead);
-
-        int bytesConsumed = Integer.MAX_VALUE;
-        while (bytesConsumed > 0) {
-            bytesConsumed = readConsumer.consumeReads(channelBuffer);
-            channelBuffer.release(bytesConsumed);
-        }
-
-        return bytesRead;
-    }
-
-    @Override
-    public void close() {
-        channelBuffer.close();
-    }
-}

+ 0 - 111
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteContext.java

@@ -1,111 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.nio;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.channels.ClosedChannelException;
-import java.util.LinkedList;
-import java.util.function.BiConsumer;
-
-public class BytesWriteContext implements WriteContext {
-
-    private final NioSocketChannel channel;
-    private final LinkedList<WriteOperation> queued = new LinkedList<>();
-
-    public BytesWriteContext(NioSocketChannel channel) {
-        this.channel = channel;
-    }
-
-    @Override
-    public void sendMessage(Object message, BiConsumer<Void, Throwable> listener) {
-        ByteBuffer[] buffers = (ByteBuffer[]) message;
-        if (channel.isWritable() == false) {
-            listener.accept(null, new ClosedChannelException());
-            return;
-        }
-
-        WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
-        SocketSelector selector = channel.getSelector();
-        if (selector.isOnCurrentThread() == false) {
-            selector.queueWrite(writeOperation);
-            return;
-        }
-
-        // TODO: Eval if we will allow writes from sendMessage
-        selector.queueWriteInChannelBuffer(writeOperation);
-    }
-
-    @Override
-    public void queueWriteOperations(WriteOperation writeOperation) {
-        assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to queue writes";
-        queued.add(writeOperation);
-    }
-
-    @Override
-    public void flushChannel() throws IOException {
-        assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to flush writes";
-        int ops = queued.size();
-        if (ops == 1) {
-            singleFlush(queued.pop());
-        } else if (ops > 1) {
-            multiFlush();
-        }
-    }
-
-    @Override
-    public boolean hasQueuedWriteOps() {
-        assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to access queued writes";
-        return queued.isEmpty() == false;
-    }
-
-    @Override
-    public void clearQueuedWriteOps(Exception e) {
-        assert channel.getSelector().isOnCurrentThread() : "Must be on selector thread to clear queued writes";
-        for (WriteOperation op : queued) {
-            channel.getSelector().executeFailedListener(op.getListener(), e);
-        }
-        queued.clear();
-    }
-
-    private void singleFlush(WriteOperation headOp) throws IOException {
-        try {
-            headOp.flush();
-        } catch (IOException e) {
-            channel.getSelector().executeFailedListener(headOp.getListener(), e);
-            throw e;
-        }
-
-        if (headOp.isFullyFlushed()) {
-            channel.getSelector().executeListener(headOp.getListener(), null);
-        } else {
-            queued.push(headOp);
-        }
-    }
-
-    private void multiFlush() throws IOException {
-        boolean lastOpCompleted = true;
-        while (lastOpCompleted && queued.isEmpty() == false) {
-            WriteOperation op = queued.pop();
-            singleFlush(op);
-            lastOpCompleted = op.isFullyFlushed();
-        }
-    }
-}

+ 88 - 0
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/BytesWriteOperation.java

@@ -0,0 +1,88 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.nio;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.function.BiConsumer;
+
+public class BytesWriteOperation implements WriteOperation {
+
+    private final NioSocketChannel channel;
+    private final BiConsumer<Void, Throwable> listener;
+    private final ByteBuffer[] buffers;
+    private final int[] offsets;
+    private final int length;
+    private int internalIndex;
+
+    public BytesWriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener) {
+        this.channel = channel;
+        this.listener = listener;
+        this.buffers = buffers;
+        this.offsets = new int[buffers.length];
+        int offset = 0;
+        for (int i = 0; i < buffers.length; i++) {
+            ByteBuffer buffer = buffers[i];
+            offsets[i] = offset;
+            offset += buffer.remaining();
+        }
+        length = offset;
+    }
+
+    @Override
+    public BiConsumer<Void, Throwable> getListener() {
+        return listener;
+    }
+
+    @Override
+    public NioSocketChannel getChannel() {
+        return channel;
+    }
+
+    public boolean isFullyFlushed() {
+        assert length >= internalIndex : "Should never have an index that is greater than the length [length=" + length + ", index="
+            + internalIndex + "]";
+        return internalIndex == length;
+    }
+
+    public void incrementIndex(int delta) {
+        internalIndex += delta;
+        assert length >= internalIndex : "Should never increment index past length [length=" + length + ", post-increment index="
+            + internalIndex + ", delta=" + delta + "]";
+    }
+
+    public ByteBuffer[] getBuffersToWrite() {
+        final int index = Arrays.binarySearch(offsets, internalIndex);
+        int offsetIndex = index < 0 ? (-(index + 1)) - 1 : index;
+
+        ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex];
+
+        ByteBuffer firstBuffer = buffers[offsetIndex].duplicate();
+        firstBuffer.position(internalIndex - offsets[offsetIndex]);
+        postIndexBuffers[0] = firstBuffer;
+        int j = 1;
+        for (int i = (offsetIndex + 1); i < buffers.length; ++i) {
+            postIndexBuffers[j++] = buffers[i].duplicate();
+        }
+
+        return postIndexBuffers;
+    }
+
+}

+ 81 - 0
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelContext.java

@@ -0,0 +1,81 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.nio;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.function.BiConsumer;
+
+/**
+ * This context should implement the specific logic for a channel. When a channel receives a notification
+ * that it is ready to perform certain operations (read, write, etc) the {@link ChannelContext} will be
+ * called. This context will need to implement all protocol related logic. Additionally, if any special
+ * close behavior is required, it should be implemented in this context.
+ *
+ * The only methods of the context that should ever be called from a non-selector thread are
+ * {@link #closeChannel()} and {@link #sendMessage(ByteBuffer[], BiConsumer)}.
+ */
+public interface ChannelContext {
+
+    void channelRegistered() throws IOException;
+
+    int read() throws IOException;
+
+    void sendMessage(ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener);
+
+    void queueWriteOperation(WriteOperation writeOperation);
+
+    void flushChannel() throws IOException;
+
+    boolean hasQueuedWriteOps();
+
+    /**
+     * Schedules a channel to be closed by the selector event loop with which it is registered.
+     * <p>
+     * If the channel is open and the state can be transitioned to closed, the close operation will
+     * be scheduled with the event loop.
+     * <p>
+     * If the channel is already set to closed, it is assumed that it is already scheduled to be closed.
+     * <p>
+     * Depending on the underlying protocol of the channel, a close operation might simply close the socket
+     * channel or may involve reading and writing messages.
+     */
+    void closeChannel();
+
+    /**
+     * This method indicates if a selector should close this channel.
+     *
+     * @return a boolean indicating if the selector should close
+     */
+    boolean selectorShouldClose();
+
+    /**
+     * This method cleans up any context resources that need to be released when a channel is closed. It
+     * should only be called by the selector thread.
+     *
+     * @throws IOException during channel / context close
+     */
+    void closeFromSelector() throws IOException;
+
+    @FunctionalInterface
+    interface ReadConsumer {
+        int consumeReads(InboundChannelBuffer channelBuffer) throws IOException;
+    }
+}

+ 1 - 2
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ChannelFactory.java

@@ -88,8 +88,7 @@ public abstract class ChannelFactory<ServerSocket extends NioServerSocketChannel
     private Socket internalCreateChannel(SocketSelector selector, SocketChannel rawChannel) throws IOException {
         try {
             Socket channel = createChannel(selector, rawChannel);
-            assert channel.getReadContext() != null : "read context should have been set on channel";
-            assert channel.getWriteContext() != null : "write context should have been set on channel";
+            assert channel.getContext() != null : "channel context should have been set on channel";
             assert channel.getExceptionContext() != null : "exception handler should have been set on channel";
             return channel;
         } catch (Exception e) {

+ 5 - 0
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ESSelector.java

@@ -163,6 +163,11 @@ public abstract class ESSelector implements Closeable {
         return Thread.currentThread() == thread;
     }
 
+    public void assertOnSelectorThread() {
+        assert isOnCurrentThread() : "Must be on selector thread to perform this operation. Currently on thread ["
+            + Thread.currentThread().getName() + "].";
+    }
+
     void wakeup() {
         // TODO: Do we need the wakeup optimizations that some other libraries use?
         selector.wakeup();

+ 0 - 2
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioChannel.java

@@ -32,8 +32,6 @@ public interface NioChannel {
 
     InetSocketAddress getLocalAddress();
 
-    void close();
-
     void closeFromSelector() throws IOException;
 
     void register() throws ClosedChannelException;

+ 26 - 38
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java

@@ -19,11 +19,13 @@
 
 package org.elasticsearch.nio;
 
+import org.elasticsearch.nio.utils.ExceptionsHelper;
+
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
-import java.nio.channels.ClosedChannelException;
 import java.nio.channels.SocketChannel;
+import java.util.ArrayList;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
@@ -34,8 +36,7 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
     private final CompletableFuture<Void> connectContext = new CompletableFuture<>();
     private final SocketSelector socketSelector;
     private final AtomicBoolean contextsSet = new AtomicBoolean(false);
-    private WriteContext writeContext;
-    private ReadContext readContext;
+    private ChannelContext context;
     private BiConsumer<NioSocketChannel, Exception> exceptionContext;
     private Exception connectException;
 
@@ -47,14 +48,21 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
 
     @Override
     public void closeFromSelector() throws IOException {
-        assert socketSelector.isOnCurrentThread() : "Should only call from selector thread";
-        // Even if the channel has already been closed we will clear any pending write operations just in case
-        if (writeContext.hasQueuedWriteOps()) {
-            writeContext.clearQueuedWriteOps(new ClosedChannelException());
+        getSelector().assertOnSelectorThread();
+        if (isOpen()) {
+            ArrayList<IOException> closingExceptions = new ArrayList<>(2);
+            try {
+                super.closeFromSelector();
+            } catch (IOException e) {
+                closingExceptions.add(e);
+            }
+            try {
+                context.closeFromSelector();
+            } catch (IOException e) {
+                closingExceptions.add(e);
+            }
+            ExceptionsHelper.rethrowAndSuppress(closingExceptions);
         }
-        readContext.close();
-
-        super.closeFromSelector();
     }
 
     @Override
@@ -62,6 +70,10 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
         return socketSelector;
     }
 
+    public int write(ByteBuffer buffer) throws IOException {
+        return socketChannel.write(buffer);
+    }
+
     public int write(ByteBuffer[] buffers) throws IOException {
         if (buffers.length == 1) {
             return socketChannel.write(buffers[0]);
@@ -82,33 +94,17 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
         }
     }
 
-    public int read(InboundChannelBuffer buffer) throws IOException {
-        int bytesRead = (int) socketChannel.read(buffer.sliceBuffersFrom(buffer.getIndex()));
-
-        if (bytesRead == -1) {
-            return bytesRead;
-        }
-
-        buffer.incrementIndex(bytesRead);
-        return bytesRead;
-    }
-
-    public void setContexts(ReadContext readContext, WriteContext writeContext, BiConsumer<NioSocketChannel, Exception> exceptionContext) {
+    public void setContexts(ChannelContext context, BiConsumer<NioSocketChannel, Exception> exceptionContext) {
         if (contextsSet.compareAndSet(false, true)) {
-            this.readContext = readContext;
-            this.writeContext = writeContext;
+            this.context = context;
             this.exceptionContext = exceptionContext;
         } else {
             throw new IllegalStateException("Contexts on this channel were already set. They should only be once.");
         }
     }
 
-    public WriteContext getWriteContext() {
-        return writeContext;
-    }
-
-    public ReadContext getReadContext() {
-        return readContext;
+    public ChannelContext getContext() {
+        return context;
     }
 
     public BiConsumer<NioSocketChannel, Exception> getExceptionContext() {
@@ -123,14 +119,6 @@ public class NioSocketChannel extends AbstractNioChannel<SocketChannel> {
         return isConnectComplete0();
     }
 
-    public boolean isWritable() {
-        return isClosing.get() == false;
-    }
-
-    public boolean isReadable() {
-        return isClosing.get() == false;
-    }
-
     /**
      * This method will attempt to complete the connection process for this channel. It should be called for
      * new channels or for a channel that has produced a OP_CONNECT event. If this method returns true then

+ 0 - 35
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/ReadContext.java

@@ -1,35 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.nio;
-
-import java.io.IOException;
-
-public interface ReadContext extends AutoCloseable {
-
-    int read() throws IOException;
-
-    @Override
-    void close();
-
-    @FunctionalInterface
-    interface ReadConsumer {
-        int consumeReads(InboundChannelBuffer channelBuffer) throws IOException;
-    }
-}

+ 54 - 1
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SelectionKeyUtils.java

@@ -26,28 +26,81 @@ public final class SelectionKeyUtils {
 
     private SelectionKeyUtils() {}
 
+    /**
+     * Adds an interest in writes for this channel while maintaining other interests.
+     *
+     * @param channel the channel
+     * @throws CancelledKeyException if the key was already cancelled
+     */
     public static void setWriteInterested(NioChannel channel) throws CancelledKeyException {
         SelectionKey selectionKey = channel.getSelectionKey();
         selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_WRITE);
     }
 
+    /**
+     * Removes an interest in writes for this channel while maintaining other interests.
+     *
+     * @param channel the channel
+     * @throws CancelledKeyException if the key was already cancelled
+     */
     public static void removeWriteInterested(NioChannel channel) throws CancelledKeyException {
         SelectionKey selectionKey = channel.getSelectionKey();
         selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_WRITE);
     }
 
+    /**
+     * Removes an interest in connects and reads for this channel while maintaining other interests.
+     *
+     * @param channel the channel
+     * @throws CancelledKeyException if the key was already cancelled
+     */
     public static void setConnectAndReadInterested(NioChannel channel) throws CancelledKeyException {
         SelectionKey selectionKey = channel.getSelectionKey();
         selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ);
     }
 
+    /**
+     * Removes an interest in connects, reads, and writes for this channel while maintaining other interests.
+     *
+     * @param channel the channel
+     * @throws CancelledKeyException if the key was already cancelled
+     */
+    public static void setConnectReadAndWriteInterested(NioChannel channel) throws CancelledKeyException {
+        SelectionKey selectionKey = channel.getSelectionKey();
+        selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_CONNECT | SelectionKey.OP_READ | SelectionKey.OP_WRITE);
+    }
+
+    /**
+     * Removes an interest in connects for this channel while maintaining other interests.
+     *
+     * @param channel the channel
+     * @throws CancelledKeyException if the key was already cancelled
+     */
     public static void removeConnectInterested(NioChannel channel) throws CancelledKeyException {
         SelectionKey selectionKey = channel.getSelectionKey();
         selectionKey.interestOps(selectionKey.interestOps() & ~SelectionKey.OP_CONNECT);
     }
 
-    public static void setAcceptInterested(NioServerSocketChannel channel) {
+    /**
+     * Adds an interest in accepts for this channel while maintaining other interests.
+     *
+     * @param channel the channel
+     * @throws CancelledKeyException if the key was already cancelled
+     */
+    public static void setAcceptInterested(NioServerSocketChannel channel) throws CancelledKeyException {
         SelectionKey selectionKey = channel.getSelectionKey();
         selectionKey.interestOps(selectionKey.interestOps() | SelectionKey.OP_ACCEPT);
     }
+
+
+    /**
+     * Checks for an interest in writes for this channel.
+     *
+     * @param channel the channel
+     * @return a boolean indicating if we are currently interested in writes for this channel
+     * @throws CancelledKeyException if the key was already cancelled
+     */
+    public static boolean isWriteInterested(NioSocketChannel channel) throws CancelledKeyException {
+        return (channel.getSelectionKey().interestOps() & SelectionKey.OP_WRITE) != 0;
+    }
 }

+ 28 - 13
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketEventHandler.java

@@ -43,8 +43,14 @@ public class SocketEventHandler extends EventHandler {
      *
      * @param channel that was registered
      */
-    protected void handleRegistration(NioSocketChannel channel) {
-        SelectionKeyUtils.setConnectAndReadInterested(channel);
+    protected void handleRegistration(NioSocketChannel channel) throws IOException {
+        ChannelContext context = channel.getContext();
+        context.channelRegistered();
+        if (context.hasQueuedWriteOps()) {
+            SelectionKeyUtils.setConnectReadAndWriteInterested(channel);
+        } else {
+            SelectionKeyUtils.setConnectAndReadInterested(channel);
+        }
     }
 
     /**
@@ -86,10 +92,7 @@ public class SocketEventHandler extends EventHandler {
      * @param channel that can be read
      */
     protected void handleRead(NioSocketChannel channel) throws IOException {
-        int bytesRead = channel.getReadContext().read();
-        if (bytesRead == -1) {
-            handleClose(channel);
-        }
+        channel.getContext().read();
     }
 
     /**
@@ -107,16 +110,11 @@ public class SocketEventHandler extends EventHandler {
      * This method is called when a channel signals it is ready to receive writes. All of the write logic
      * should occur in this call.
      *
-     * @param channel that can be read
+     * @param channel that can be written to
      */
     protected void handleWrite(NioSocketChannel channel) throws IOException {
-        WriteContext channelContext = channel.getWriteContext();
+        ChannelContext channelContext = channel.getContext();
         channelContext.flushChannel();
-        if (channelContext.hasQueuedWriteOps()) {
-            SelectionKeyUtils.setWriteInterested(channel);
-        } else {
-            SelectionKeyUtils.removeWriteInterested(channel);
-        }
     }
 
     /**
@@ -153,6 +151,23 @@ public class SocketEventHandler extends EventHandler {
         logger.warn(new ParameterizedMessage("exception while executing listener: {}", listener), exception);
     }
 
+    /**
+     * @param channel that was handled
+     */
+    protected void postHandling(NioSocketChannel channel) {
+        if (channel.getContext().selectorShouldClose()) {
+            handleClose(channel);
+        } else {
+            boolean currentlyWriteInterested = SelectionKeyUtils.isWriteInterested(channel);
+            boolean pendingWrites = channel.getContext().hasQueuedWriteOps();
+            if (currentlyWriteInterested == false && pendingWrites) {
+                SelectionKeyUtils.setWriteInterested(channel);
+            } else if (currentlyWriteInterested && pendingWrites == false) {
+                SelectionKeyUtils.removeWriteInterested(channel);
+            }
+        }
+    }
+
     private void exceptionCaught(NioSocketChannel channel, Exception e) {
         channel.getExceptionContext().accept(channel, e);
     }

+ 8 - 6
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/SocketSelector.java

@@ -64,6 +64,8 @@ public class SocketSelector extends ESSelector {
                 handleRead(nioSocketChannel);
             }
         }
+
+        eventHandler.postHandling(nioSocketChannel);
     }
 
     @Override
@@ -118,12 +120,12 @@ public class SocketSelector extends ESSelector {
      * @param writeOperation to be queued in a channel's buffer
      */
     public void queueWriteInChannelBuffer(WriteOperation writeOperation) {
-        assert isOnCurrentThread() : "Must be on selector thread";
+        assertOnSelectorThread();
         NioSocketChannel channel = writeOperation.getChannel();
-        WriteContext context = channel.getWriteContext();
+        ChannelContext context = channel.getContext();
         try {
             SelectionKeyUtils.setWriteInterested(channel);
-            context.queueWriteOperations(writeOperation);
+            context.queueWriteOperation(writeOperation);
         } catch (Exception e) {
             executeFailedListener(writeOperation.getListener(), e);
         }
@@ -137,7 +139,7 @@ public class SocketSelector extends ESSelector {
      * @param value to provide to listener
      */
     public <V> void executeListener(BiConsumer<V, Throwable> listener, V value) {
-        assert isOnCurrentThread() : "Must be on selector thread";
+        assertOnSelectorThread();
         try {
             listener.accept(value, null);
         } catch (Exception e) {
@@ -153,7 +155,7 @@ public class SocketSelector extends ESSelector {
      * @param exception to provide to listener
      */
     public <V> void executeFailedListener(BiConsumer<V, Throwable> listener, Exception exception) {
-        assert isOnCurrentThread() : "Must be on selector thread";
+        assertOnSelectorThread();
         try {
             listener.accept(null, exception);
         } catch (Exception e) {
@@ -180,7 +182,7 @@ public class SocketSelector extends ESSelector {
     private void handleQueuedWrites() {
         WriteOperation writeOperation;
         while ((writeOperation = queuedWrites.poll()) != null) {
-            if (writeOperation.getChannel().isWritable()) {
+            if (writeOperation.getChannel().isOpen()) {
                 queueWriteInChannelBuffer(writeOperation);
             } else {
                 executeFailedListener(writeOperation.getListener(), new ClosedChannelException());

+ 0 - 37
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteContext.java

@@ -1,37 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.nio;
-
-import java.io.IOException;
-import java.util.function.BiConsumer;
-
-public interface WriteContext {
-
-    void sendMessage(Object message, BiConsumer<Void, Throwable> listener);
-
-    void queueWriteOperations(WriteOperation writeOperation);
-
-    void flushChannel() throws IOException;
-
-    boolean hasQueuedWriteOps();
-
-    void clearQueuedWriteOps(Exception e);
-
-}

+ 8 - 66
libs/elasticsearch-nio/src/main/java/org/elasticsearch/nio/WriteOperation.java

@@ -19,74 +19,16 @@
 
 package org.elasticsearch.nio;
 
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Arrays;
 import java.util.function.BiConsumer;
 
-public class WriteOperation {
-
-    private final NioSocketChannel channel;
-    private final BiConsumer<Void, Throwable> listener;
-    private final ByteBuffer[] buffers;
-    private final int[] offsets;
-    private final int length;
-    private int internalIndex;
-
-    public WriteOperation(NioSocketChannel channel, ByteBuffer[] buffers, BiConsumer<Void, Throwable> listener) {
-        this.channel = channel;
-        this.listener = listener;
-        this.buffers = buffers;
-        this.offsets = new int[buffers.length];
-        int offset = 0;
-        for (int i = 0; i < buffers.length; i++) {
-            ByteBuffer buffer = buffers[i];
-            offsets[i] = offset;
-            offset += buffer.remaining();
-        }
-        length = offset;
-    }
-
-    public ByteBuffer[] getByteBuffers() {
-        return buffers;
-    }
-
-    public BiConsumer<Void, Throwable> getListener() {
-        return listener;
-    }
-
-    public NioSocketChannel getChannel() {
-        return channel;
-    }
-
-    public boolean isFullyFlushed() {
-        return internalIndex == length;
-    }
-
-    public int flush() throws IOException {
-        int written = channel.write(getBuffersToWrite());
-        internalIndex += written;
-        return written;
-    }
-
-    private ByteBuffer[] getBuffersToWrite() {
-        int offsetIndex = getOffsetIndex(internalIndex);
-
-        ByteBuffer[] postIndexBuffers = new ByteBuffer[buffers.length - offsetIndex];
-
-        ByteBuffer firstBuffer = buffers[offsetIndex].duplicate();
-        firstBuffer.position(internalIndex - offsets[offsetIndex]);
-        postIndexBuffers[0] = firstBuffer;
-        int j = 1;
-        for (int i = (offsetIndex + 1); i < buffers.length; ++i) {
-            postIndexBuffers[j++] = buffers[i].duplicate();
-        }
+/**
+ * This is a basic write operation that can be queued with a channel. The only requirements of a write
+ * operation is that is has a listener and a reference to its channel. The actual conversion of the write
+ * operation implementation to bytes will be performed by the {@link ChannelContext}.
+ */
+public interface WriteOperation {
 
-        return postIndexBuffers;
-    }
+    BiConsumer<Void, Throwable> getListener();
 
-    private int getOffsetIndex(int offset) {
-        final int i = Arrays.binarySearch(offsets, offset);
-        return i < 0 ? (-(i + 1)) - 1 : i;
-    }
+    NioSocketChannel getChannel();
 }

+ 1 - 1
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/AcceptorEventHandlerTests.java

@@ -80,7 +80,7 @@ public class AcceptorEventHandlerTests extends ESTestCase {
     @SuppressWarnings("unchecked")
     public void testHandleAcceptCallsServerAcceptCallback() throws IOException {
         NioSocketChannel childChannel = new NioSocketChannel(mock(SocketChannel.class), socketSelector);
-        childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class));
+        childChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
         when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel);
 
         handler.acceptChannel(channel);

+ 337 - 0
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java

@@ -0,0 +1,337 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.nio;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.Before;
+import org.mockito.ArgumentCaptor;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.util.function.BiConsumer;
+import java.util.function.Supplier;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.isNull;
+import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class BytesChannelContextTests extends ESTestCase {
+
+    private ChannelContext.ReadConsumer readConsumer;
+    private NioSocketChannel channel;
+    private BytesChannelContext context;
+    private InboundChannelBuffer channelBuffer;
+    private SocketSelector selector;
+    private BiConsumer<Void, Throwable> listener;
+    private int messageLength;
+
+    @Before
+    @SuppressWarnings("unchecked")
+    public void init() {
+        readConsumer = mock(ChannelContext.ReadConsumer.class);
+
+        messageLength = randomInt(96) + 20;
+        selector = mock(SocketSelector.class);
+        listener = mock(BiConsumer.class);
+        channel = mock(NioSocketChannel.class);
+        Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
+            new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {});
+        channelBuffer = new InboundChannelBuffer(pageSupplier);
+        context = new BytesChannelContext(channel, readConsumer, channelBuffer);
+
+        when(channel.getSelector()).thenReturn(selector);
+        when(selector.isOnCurrentThread()).thenReturn(true);
+    }
+
+    public void testSuccessfulRead() throws IOException {
+        byte[] bytes = createMessage(messageLength);
+
+        when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
+            ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
+            buffers[0].put(bytes);
+            return bytes.length;
+        });
+
+        when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0);
+
+        assertEquals(messageLength, context.read());
+
+        assertEquals(0, channelBuffer.getIndex());
+        assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
+        verify(readConsumer, times(1)).consumeReads(channelBuffer);
+    }
+
+    public void testMultipleReadsConsumed() throws IOException {
+        byte[] bytes = createMessage(messageLength * 2);
+
+        when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
+            ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
+            buffers[0].put(bytes);
+            return bytes.length;
+        });
+
+        when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0);
+
+        assertEquals(bytes.length, context.read());
+
+        assertEquals(0, channelBuffer.getIndex());
+        assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
+        verify(readConsumer, times(2)).consumeReads(channelBuffer);
+    }
+
+    public void testPartialRead() throws IOException {
+        byte[] bytes = createMessage(messageLength);
+
+        when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
+            ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
+            buffers[0].put(bytes);
+            return bytes.length;
+        });
+
+
+        when(readConsumer.consumeReads(channelBuffer)).thenReturn(0);
+
+        assertEquals(messageLength, context.read());
+
+        assertEquals(bytes.length, channelBuffer.getIndex());
+        verify(readConsumer, times(1)).consumeReads(channelBuffer);
+
+        when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0);
+
+        assertEquals(messageLength, context.read());
+
+        assertEquals(0, channelBuffer.getIndex());
+        assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity());
+        verify(readConsumer, times(2)).consumeReads(channelBuffer);
+    }
+
+    public void testReadThrowsIOException() throws IOException {
+        IOException ioException = new IOException();
+        when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException);
+
+        IOException ex = expectThrows(IOException.class, () -> context.read());
+        assertSame(ioException, ex);
+    }
+
+    public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException {
+        when(channel.read(any(ByteBuffer[].class))).thenThrow(new IOException());
+
+        assertFalse(context.selectorShouldClose());
+        expectThrows(IOException.class, () -> context.read());
+        assertTrue(context.selectorShouldClose());
+    }
+
+    public void testReadLessThanZeroMeansReadyForClose() throws IOException {
+        when(channel.read(any(ByteBuffer[].class))).thenReturn(-1);
+
+        assertEquals(0, context.read());
+
+        assertTrue(context.selectorShouldClose());
+    }
+
+    public void testCloseClosesChannelBuffer() throws IOException {
+        Runnable closer = mock(Runnable.class);
+        Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), closer);
+        InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
+        buffer.ensureCapacity(1);
+        BytesChannelContext context = new BytesChannelContext(channel, readConsumer, buffer);
+        context.closeFromSelector();
+        verify(closer).run();
+    }
+
+    public void testWriteFailsIfClosing() {
+        context.closeChannel();
+
+        ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
+        context.sendMessage(buffers, listener);
+
+        verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
+    }
+
+    public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception {
+        ArgumentCaptor<BytesWriteOperation> writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class);
+
+        when(selector.isOnCurrentThread()).thenReturn(false);
+
+        ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
+        context.sendMessage(buffers, listener);
+
+        verify(selector).queueWrite(writeOpCaptor.capture());
+        BytesWriteOperation writeOp = writeOpCaptor.getValue();
+
+        assertSame(listener, writeOp.getListener());
+        assertSame(channel, writeOp.getChannel());
+        assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
+    }
+
+    public void testSendMessageFromSameThreadIsQueuedInChannel() {
+        ArgumentCaptor<BytesWriteOperation> writeOpCaptor = ArgumentCaptor.forClass(BytesWriteOperation.class);
+
+        ByteBuffer[] buffers = {ByteBuffer.wrap(createMessage(10))};
+        context.sendMessage(buffers, listener);
+
+        verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture());
+        BytesWriteOperation writeOp = writeOpCaptor.getValue();
+
+        assertSame(listener, writeOp.getListener());
+        assertSame(channel, writeOp.getChannel());
+        assertEquals(buffers[0], writeOp.getBuffersToWrite()[0]);
+    }
+
+    public void testWriteIsQueuedInChannel() {
+        assertFalse(context.hasQueuedWriteOps());
+
+        ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
+        context.queueWriteOperation(new BytesWriteOperation(channel, buffer, listener));
+
+        assertTrue(context.hasQueuedWriteOps());
+    }
+
+    public void testWriteOpsClearedOnClose() throws Exception {
+        assertFalse(context.hasQueuedWriteOps());
+
+        ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
+        context.queueWriteOperation(new BytesWriteOperation(channel,  buffer, listener));
+
+        assertTrue(context.hasQueuedWriteOps());
+
+        context.closeFromSelector();
+
+        verify(selector).executeFailedListener(same(listener), any(ClosedChannelException.class));
+
+        assertFalse(context.hasQueuedWriteOps());
+    }
+
+    public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
+        assertFalse(context.hasQueuedWriteOps());
+
+        ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
+        BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
+        context.queueWriteOperation(writeOperation);
+
+        assertTrue(context.hasQueuedWriteOps());
+
+        when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
+        when(writeOperation.isFullyFlushed()).thenReturn(true);
+        when(writeOperation.getListener()).thenReturn(listener);
+        context.flushChannel();
+
+        verify(channel).write(buffers);
+        verify(selector).executeListener(listener, null);
+        assertFalse(context.hasQueuedWriteOps());
+    }
+
+    public void testPartialFlush() throws IOException {
+        assertFalse(context.hasQueuedWriteOps());
+
+        BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
+        context.queueWriteOperation(writeOperation);
+
+        assertTrue(context.hasQueuedWriteOps());
+
+        when(writeOperation.isFullyFlushed()).thenReturn(false);
+        context.flushChannel();
+
+        verify(listener, times(0)).accept(null, null);
+        assertTrue(context.hasQueuedWriteOps());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testMultipleWritesPartialFlushes() throws IOException {
+        assertFalse(context.hasQueuedWriteOps());
+
+        BiConsumer<Void, Throwable> listener2 = mock(BiConsumer.class);
+        BytesWriteOperation writeOperation1 = mock(BytesWriteOperation.class);
+        BytesWriteOperation writeOperation2 = mock(BytesWriteOperation.class);
+        when(writeOperation1.getListener()).thenReturn(listener);
+        when(writeOperation2.getListener()).thenReturn(listener2);
+        context.queueWriteOperation(writeOperation1);
+        context.queueWriteOperation(writeOperation2);
+
+        assertTrue(context.hasQueuedWriteOps());
+
+        when(writeOperation1.isFullyFlushed()).thenReturn(true);
+        when(writeOperation2.isFullyFlushed()).thenReturn(false);
+        context.flushChannel();
+
+        verify(selector).executeListener(listener, null);
+        verify(listener2, times(0)).accept(null, null);
+        assertTrue(context.hasQueuedWriteOps());
+
+        when(writeOperation2.isFullyFlushed()).thenReturn(true);
+
+        context.flushChannel();
+
+        verify(selector).executeListener(listener2, null);
+        assertFalse(context.hasQueuedWriteOps());
+    }
+
+    public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
+        assertFalse(context.hasQueuedWriteOps());
+
+        ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
+        BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
+        context.queueWriteOperation(writeOperation);
+
+        assertTrue(context.hasQueuedWriteOps());
+
+        IOException exception = new IOException();
+        when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
+        when(channel.write(buffers)).thenThrow(exception);
+        when(writeOperation.getListener()).thenReturn(listener);
+        expectThrows(IOException.class, () -> context.flushChannel());
+
+        verify(selector).executeFailedListener(listener, exception);
+        assertFalse(context.hasQueuedWriteOps());
+    }
+
+    public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException {
+        ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
+        BytesWriteOperation writeOperation = mock(BytesWriteOperation.class);
+        context.queueWriteOperation(writeOperation);
+
+        IOException exception = new IOException();
+        when(writeOperation.getBuffersToWrite()).thenReturn(buffers);
+        when(channel.write(buffers)).thenThrow(exception);
+
+        assertFalse(context.selectorShouldClose());
+        expectThrows(IOException.class, () -> context.flushChannel());
+        assertTrue(context.selectorShouldClose());
+    }
+
+    public void initiateCloseSchedulesCloseWithSelector() {
+        context.closeChannel();
+        verify(selector).queueChannelClose(channel);
+    }
+
+    private static byte[] createMessage(int length) {
+        byte[] bytes = new byte[length];
+        for (int i = 0; i < length; ++i) {
+            bytes[i] = randomByte();
+        }
+        return bytes;
+    }
+}

+ 0 - 142
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesReadContextTests.java

@@ -1,142 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.nio;
-
-import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.test.ESTestCase;
-import org.junit.Before;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.function.Supplier;
-
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-public class BytesReadContextTests extends ESTestCase {
-
-    private ReadContext.ReadConsumer readConsumer;
-    private NioSocketChannel channel;
-    private BytesReadContext readContext;
-    private InboundChannelBuffer channelBuffer;
-    private int messageLength;
-
-    @Before
-    public void init() {
-        readConsumer = mock(ReadContext.ReadConsumer.class);
-
-        messageLength = randomInt(96) + 20;
-        channel = mock(NioSocketChannel.class);
-        Supplier<InboundChannelBuffer.Page> pageSupplier = () ->
-            new InboundChannelBuffer.Page(ByteBuffer.allocate(BigArrays.BYTE_PAGE_SIZE), () -> {});
-        channelBuffer = new InboundChannelBuffer(pageSupplier);
-        readContext = new BytesReadContext(channel, readConsumer, channelBuffer);
-    }
-
-    public void testSuccessfulRead() throws IOException {
-        byte[] bytes = createMessage(messageLength);
-
-        when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
-            ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
-            buffers[0].put(bytes);
-            return bytes.length;
-        });
-
-        when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, 0);
-
-        assertEquals(messageLength, readContext.read());
-
-        assertEquals(0, channelBuffer.getIndex());
-        assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
-        verify(readConsumer, times(2)).consumeReads(channelBuffer);
-    }
-
-    public void testMultipleReadsConsumed() throws IOException {
-        byte[] bytes = createMessage(messageLength * 2);
-
-        when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
-            ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
-            buffers[0].put(bytes);
-            return bytes.length;
-        });
-
-        when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength, messageLength, 0);
-
-        assertEquals(bytes.length, readContext.read());
-
-        assertEquals(0, channelBuffer.getIndex());
-        assertEquals(BigArrays.BYTE_PAGE_SIZE - bytes.length, channelBuffer.getCapacity());
-        verify(readConsumer, times(3)).consumeReads(channelBuffer);
-    }
-
-    public void testPartialRead() throws IOException {
-        byte[] bytes = createMessage(messageLength);
-
-        when(channel.read(any(ByteBuffer[].class))).thenAnswer(invocationOnMock -> {
-            ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
-            buffers[0].put(bytes);
-            return bytes.length;
-        });
-
-
-        when(readConsumer.consumeReads(channelBuffer)).thenReturn(0, messageLength);
-
-        assertEquals(messageLength, readContext.read());
-
-        assertEquals(bytes.length, channelBuffer.getIndex());
-        verify(readConsumer, times(1)).consumeReads(channelBuffer);
-
-        when(readConsumer.consumeReads(channelBuffer)).thenReturn(messageLength * 2, 0);
-
-        assertEquals(messageLength, readContext.read());
-
-        assertEquals(0, channelBuffer.getIndex());
-        assertEquals(BigArrays.BYTE_PAGE_SIZE - (bytes.length * 2), channelBuffer.getCapacity());
-        verify(readConsumer, times(3)).consumeReads(channelBuffer);
-    }
-
-    public void testReadThrowsIOException() throws IOException {
-        IOException ioException = new IOException();
-        when(channel.read(any(ByteBuffer[].class))).thenThrow(ioException);
-
-        IOException ex = expectThrows(IOException.class, () -> readContext.read());
-        assertSame(ioException, ex);
-    }
-
-    public void closeClosesChannelBuffer() {
-        InboundChannelBuffer buffer = mock(InboundChannelBuffer.class);
-        BytesReadContext readContext = new BytesReadContext(channel, readConsumer, buffer);
-
-        readContext.close();
-
-        verify(buffer).close();
-    }
-
-    private static byte[] createMessage(int length) {
-        byte[] bytes = new byte[length];
-        for (int i = 0; i < length; ++i) {
-            bytes[i] = randomByte();
-        }
-        return bytes;
-    }
-}

+ 0 - 212
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/BytesWriteContextTests.java

@@ -1,212 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.nio;
-
-import org.elasticsearch.test.ESTestCase;
-import org.junit.Before;
-import org.mockito.ArgumentCaptor;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.channels.ClosedChannelException;
-import java.util.function.BiConsumer;
-
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.isNull;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-public class BytesWriteContextTests extends ESTestCase {
-
-    private SocketSelector selector;
-    private BiConsumer<Void, Throwable> listener;
-    private BytesWriteContext writeContext;
-    private NioSocketChannel channel;
-
-    @Before
-    @SuppressWarnings("unchecked")
-    public void setUp() throws Exception {
-        super.setUp();
-        selector = mock(SocketSelector.class);
-        listener = mock(BiConsumer.class);
-        channel = mock(NioSocketChannel.class);
-        writeContext = new BytesWriteContext(channel);
-
-        when(channel.getSelector()).thenReturn(selector);
-        when(selector.isOnCurrentThread()).thenReturn(true);
-    }
-
-    public void testWriteFailsIfChannelNotWritable() throws Exception {
-        when(channel.isWritable()).thenReturn(false);
-
-        ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
-        writeContext.sendMessage(buffers, listener);
-
-        verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
-    }
-
-    public void testSendMessageFromDifferentThreadIsQueuedWithSelector() throws Exception {
-        ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
-
-        when(selector.isOnCurrentThread()).thenReturn(false);
-        when(channel.isWritable()).thenReturn(true);
-
-        ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
-        writeContext.sendMessage(buffers, listener);
-
-        verify(selector).queueWrite(writeOpCaptor.capture());
-        WriteOperation writeOp = writeOpCaptor.getValue();
-
-        assertSame(listener, writeOp.getListener());
-        assertSame(channel, writeOp.getChannel());
-        assertEquals(buffers[0], writeOp.getByteBuffers()[0]);
-    }
-
-    public void testSendMessageFromSameThreadIsQueuedInChannel() throws Exception {
-        ArgumentCaptor<WriteOperation> writeOpCaptor = ArgumentCaptor.forClass(WriteOperation.class);
-
-        when(channel.isWritable()).thenReturn(true);
-
-        ByteBuffer[] buffers = {ByteBuffer.wrap(generateBytes(10))};
-        writeContext.sendMessage(buffers, listener);
-
-        verify(selector).queueWriteInChannelBuffer(writeOpCaptor.capture());
-        WriteOperation writeOp = writeOpCaptor.getValue();
-
-        assertSame(listener, writeOp.getListener());
-        assertSame(channel, writeOp.getChannel());
-        assertEquals(buffers[0], writeOp.getByteBuffers()[0]);
-    }
-
-    public void testWriteIsQueuedInChannel() throws Exception {
-        assertFalse(writeContext.hasQueuedWriteOps());
-
-        ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
-        writeContext.queueWriteOperations(new WriteOperation(channel, buffer, listener));
-
-        assertTrue(writeContext.hasQueuedWriteOps());
-    }
-
-    public void testWriteOpsCanBeCleared() throws Exception {
-        assertFalse(writeContext.hasQueuedWriteOps());
-
-        ByteBuffer[] buffer = {ByteBuffer.allocate(10)};
-        writeContext.queueWriteOperations(new WriteOperation(channel,  buffer, listener));
-
-        assertTrue(writeContext.hasQueuedWriteOps());
-
-        ClosedChannelException e = new ClosedChannelException();
-        writeContext.clearQueuedWriteOps(e);
-
-        verify(selector).executeFailedListener(listener, e);
-
-        assertFalse(writeContext.hasQueuedWriteOps());
-    }
-
-    public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
-        assertFalse(writeContext.hasQueuedWriteOps());
-
-        WriteOperation writeOperation = mock(WriteOperation.class);
-        writeContext.queueWriteOperations(writeOperation);
-
-        assertTrue(writeContext.hasQueuedWriteOps());
-
-        when(writeOperation.isFullyFlushed()).thenReturn(true);
-        when(writeOperation.getListener()).thenReturn(listener);
-        writeContext.flushChannel();
-
-        verify(writeOperation).flush();
-        verify(selector).executeListener(listener, null);
-        assertFalse(writeContext.hasQueuedWriteOps());
-    }
-
-    public void testPartialFlush() throws IOException {
-        assertFalse(writeContext.hasQueuedWriteOps());
-
-        WriteOperation writeOperation = mock(WriteOperation.class);
-        writeContext.queueWriteOperations(writeOperation);
-
-        assertTrue(writeContext.hasQueuedWriteOps());
-
-        when(writeOperation.isFullyFlushed()).thenReturn(false);
-        writeContext.flushChannel();
-
-        verify(listener, times(0)).accept(null, null);
-        assertTrue(writeContext.hasQueuedWriteOps());
-    }
-
-    @SuppressWarnings("unchecked")
-    public void testMultipleWritesPartialFlushes() throws IOException {
-        assertFalse(writeContext.hasQueuedWriteOps());
-
-        BiConsumer<Void, Throwable> listener2 = mock(BiConsumer.class);
-        WriteOperation writeOperation1 = mock(WriteOperation.class);
-        WriteOperation writeOperation2 = mock(WriteOperation.class);
-        when(writeOperation1.getListener()).thenReturn(listener);
-        when(writeOperation2.getListener()).thenReturn(listener2);
-        writeContext.queueWriteOperations(writeOperation1);
-        writeContext.queueWriteOperations(writeOperation2);
-
-        assertTrue(writeContext.hasQueuedWriteOps());
-
-        when(writeOperation1.isFullyFlushed()).thenReturn(true);
-        when(writeOperation2.isFullyFlushed()).thenReturn(false);
-        writeContext.flushChannel();
-
-        verify(selector).executeListener(listener, null);
-        verify(listener2, times(0)).accept(null, null);
-        assertTrue(writeContext.hasQueuedWriteOps());
-
-        when(writeOperation2.isFullyFlushed()).thenReturn(true);
-
-        writeContext.flushChannel();
-
-        verify(selector).executeListener(listener2, null);
-        assertFalse(writeContext.hasQueuedWriteOps());
-    }
-
-    public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
-        assertFalse(writeContext.hasQueuedWriteOps());
-
-        WriteOperation writeOperation = mock(WriteOperation.class);
-        writeContext.queueWriteOperations(writeOperation);
-
-        assertTrue(writeContext.hasQueuedWriteOps());
-
-        IOException exception = new IOException();
-        when(writeOperation.flush()).thenThrow(exception);
-        when(writeOperation.getListener()).thenReturn(listener);
-        expectThrows(IOException.class, () -> writeContext.flushChannel());
-
-        verify(selector).executeFailedListener(listener, exception);
-        assertFalse(writeContext.hasQueuedWriteOps());
-    }
-
-    private byte[] generateBytes(int n) {
-        n += 10;
-        byte[] bytes = new byte[n];
-        for (int i = 0; i < n; ++i) {
-            bytes[i] = randomByte();
-        }
-        return bytes;
-    }
-}

+ 1 - 1
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/ChannelFactoryTests.java

@@ -139,7 +139,7 @@ public class ChannelFactoryTests extends ESTestCase {
         @Override
         public NioSocketChannel createChannel(SocketSelector selector, SocketChannel channel) throws IOException {
             NioSocketChannel nioSocketChannel = new NioSocketChannel(channel, selector);
-            nioSocketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class));
+            nioSocketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
             return nioSocketChannel;
         }
 

+ 1 - 1
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioServerSocketChannelTests.java

@@ -82,7 +82,7 @@ public class NioServerSocketChannelTests extends ESTestCase {
 
         PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
         channel.addCloseListener(ActionListener.toBiConsumer(closeFuture));
-        channel.close();
+        selector.queueChannelClose(channel);
         closeFuture.actionGet();
 
 

+ 43 - 4
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/NioSocketChannelTests.java

@@ -35,6 +35,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -66,7 +67,7 @@ public class NioSocketChannelTests extends ESTestCase {
         CountDownLatch latch = new CountDownLatch(1);
 
         NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector);
-        socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class));
+        socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
         socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener<Void>() {
             @Override
             public void onResponse(Void o) {
@@ -86,7 +87,45 @@ public class NioSocketChannelTests extends ESTestCase {
 
         PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
         socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture));
-        socketChannel.close();
+        selector.queueChannelClose(socketChannel);
+        closeFuture.actionGet();
+
+        assertTrue(closedRawChannel.get());
+        assertFalse(socketChannel.isOpen());
+        latch.await();
+        assertTrue(isClosed.get());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testCloseContextExceptionDoesNotStopClose() throws Exception {
+        AtomicBoolean isClosed = new AtomicBoolean(false);
+        CountDownLatch latch = new CountDownLatch(1);
+
+        IOException ioException = new IOException();
+        NioSocketChannel socketChannel = new DoNotCloseChannel(mock(SocketChannel.class), selector);
+        ChannelContext context = mock(ChannelContext.class);
+        doThrow(ioException).when(context).closeFromSelector();
+        socketChannel.setContexts(context, mock(BiConsumer.class));
+        socketChannel.addCloseListener(ActionListener.toBiConsumer(new ActionListener<Void>() {
+            @Override
+            public void onResponse(Void o) {
+                isClosed.set(true);
+                latch.countDown();
+            }
+            @Override
+            public void onFailure(Exception e) {
+                isClosed.set(true);
+                latch.countDown();
+            }
+        }));
+
+        assertTrue(socketChannel.isOpen());
+        assertFalse(closedRawChannel.get());
+        assertFalse(isClosed.get());
+
+        PlainActionFuture<Void> closeFuture = PlainActionFuture.newFuture();
+        socketChannel.addCloseListener(ActionListener.toBiConsumer(closeFuture));
+        selector.queueChannelClose(socketChannel);
         closeFuture.actionGet();
 
         assertTrue(closedRawChannel.get());
@@ -100,7 +139,7 @@ public class NioSocketChannelTests extends ESTestCase {
         SocketChannel rawChannel = mock(SocketChannel.class);
         when(rawChannel.finishConnect()).thenReturn(true);
         NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector);
-        socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class));
+        socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
         selector.scheduleForRegistration(socketChannel);
 
         PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
@@ -117,7 +156,7 @@ public class NioSocketChannelTests extends ESTestCase {
         SocketChannel rawChannel = mock(SocketChannel.class);
         when(rawChannel.finishConnect()).thenThrow(new ConnectException());
         NioSocketChannel socketChannel = new DoNotCloseChannel(rawChannel, selector);
-        socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class), mock(BiConsumer.class));
+        socketChannel.setContexts(mock(ChannelContext.class), mock(BiConsumer.class));
         selector.scheduleForRegistration(socketChannel);
 
         PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();

+ 69 - 50
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketEventHandlerTests.java

@@ -28,8 +28,10 @@ import java.nio.channels.CancelledKeyException;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
 import java.util.function.BiConsumer;
+import java.util.function.Supplier;
 
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -39,7 +41,6 @@ public class SocketEventHandlerTests extends ESTestCase {
 
     private SocketEventHandler handler;
     private NioSocketChannel channel;
-    private ReadContext readContext;
     private SocketChannel rawChannel;
 
     @Before
@@ -50,21 +51,37 @@ public class SocketEventHandlerTests extends ESTestCase {
         handler = new SocketEventHandler(logger);
         rawChannel = mock(SocketChannel.class);
         channel = new DoNotRegisterChannel(rawChannel, socketSelector);
-        readContext = mock(ReadContext.class);
         when(rawChannel.finishConnect()).thenReturn(true);
 
-        channel.setContexts(readContext, new BytesWriteContext(channel), exceptionHandler);
+        Supplier<InboundChannelBuffer.Page> pageSupplier = () -> new InboundChannelBuffer.Page(ByteBuffer.allocate(1 << 14), () -> {});
+        InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
+        channel.setContexts(new BytesChannelContext(channel, mock(ChannelContext.ReadConsumer.class), buffer), exceptionHandler);
         channel.register();
         channel.finishConnect();
 
         when(socketSelector.isOnCurrentThread()).thenReturn(true);
     }
 
+    public void testRegisterCallsContext() throws IOException {
+        NioSocketChannel channel = mock(NioSocketChannel.class);
+        ChannelContext channelContext = mock(ChannelContext.class);
+        when(channel.getContext()).thenReturn(channelContext);
+        when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0));
+        handler.handleRegistration(channel);
+        verify(channelContext).channelRegistered();
+    }
+
     public void testRegisterAddsOP_CONNECTAndOP_READInterest() throws IOException {
         handler.handleRegistration(channel);
         assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT, channel.getSelectionKey().interestOps());
     }
 
+    public void testRegisterWithPendingWritesAddsOP_CONNECTAndOP_READAndOP_WRITEInterest() throws IOException {
+        channel.getContext().queueWriteOperation(mock(BytesWriteOperation.class));
+        handler.handleRegistration(channel);
+        assertEquals(SelectionKey.OP_READ | SelectionKey.OP_CONNECT | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps());
+    }
+
     public void testRegistrationExceptionCallsExceptionHandler() throws IOException {
         CancelledKeyException exception = new CancelledKeyException();
         handler.registrationException(channel, exception);
@@ -83,68 +100,76 @@ public class SocketEventHandlerTests extends ESTestCase {
         verify(exceptionHandler).accept(channel, exception);
     }
 
-    public void testHandleReadDelegatesToReadContext() throws IOException {
-        when(readContext.read()).thenReturn(1);
+    public void testHandleReadDelegatesToContext() throws IOException {
+        NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class));
+        ChannelContext context = mock(ChannelContext.class);
+        channel.setContexts(context, exceptionHandler);
 
+        when(context.read()).thenReturn(1);
         handler.handleRead(channel);
-
-        verify(readContext).read();
+        verify(context).read();
     }
 
-    public void testHandleReadMarksChannelForCloseIfPeerClosed() throws IOException {
-        NioSocketChannel nioSocketChannel = mock(NioSocketChannel.class);
-        when(nioSocketChannel.getReadContext()).thenReturn(readContext);
-        when(readContext.read()).thenReturn(-1);
-
-        handler.handleRead(nioSocketChannel);
-
-        verify(nioSocketChannel).closeFromSelector();
-    }
-
-    public void testReadExceptionCallsExceptionHandler() throws IOException {
+    public void testReadExceptionCallsExceptionHandler() {
         IOException exception = new IOException();
         handler.readException(channel, exception);
         verify(exceptionHandler).accept(channel, exception);
     }
 
-    @SuppressWarnings("unchecked")
-    public void testHandleWriteWithCompleteFlushRemovesOP_WRITEInterest() throws IOException {
-        SelectionKey selectionKey = channel.getSelectionKey();
-        setWriteAndRead(channel);
-        assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
+    public void testWriteExceptionCallsExceptionHandler() {
+        IOException exception = new IOException();
+        handler.writeException(channel, exception);
+        verify(exceptionHandler).accept(channel, exception);
+    }
 
-        ByteBuffer[] buffers = {ByteBuffer.allocate(1)};
-        channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, buffers, mock(BiConsumer.class)));
+    public void testPostHandlingCallWillCloseTheChannelIfReady() throws IOException {
+        NioSocketChannel channel = mock(NioSocketChannel.class);
+        ChannelContext context = mock(ChannelContext.class);
+        when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0));
 
-        when(rawChannel.write(buffers[0])).thenReturn(1);
-        handler.handleWrite(channel);
+        when(channel.getContext()).thenReturn(context);
+        when(context.selectorShouldClose()).thenReturn(true);
+        handler.postHandling(channel);
 
-        assertEquals(SelectionKey.OP_READ, selectionKey.interestOps());
+        verify(channel).closeFromSelector();
     }
 
-    @SuppressWarnings("unchecked")
-    public void testHandleWriteWithInCompleteFlushLeavesOP_WRITEInterest() throws IOException {
-        SelectionKey selectionKey = channel.getSelectionKey();
-        setWriteAndRead(channel);
-        assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
-
-        ByteBuffer[] buffers = {ByteBuffer.allocate(1)};
-        channel.getWriteContext().queueWriteOperations(new WriteOperation(channel, buffers, mock(BiConsumer.class)));
+    public void testPostHandlingCallWillNotCloseTheChannelIfNotReady() throws IOException {
+        NioSocketChannel channel = mock(NioSocketChannel.class);
+        ChannelContext context = mock(ChannelContext.class);
+        when(channel.getSelectionKey()).thenReturn(new TestSelectionKey(0));
 
-        when(rawChannel.write(buffers[0])).thenReturn(0);
-        handler.handleWrite(channel);
+        when(channel.getContext()).thenReturn(context);
+        when(context.selectorShouldClose()).thenReturn(false);
+        handler.postHandling(channel);
 
-        assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, selectionKey.interestOps());
+        verify(channel, times(0)).closeFromSelector();
     }
 
-    public void testHandleWriteWithNoOpsRemovesOP_WRITEInterest() throws IOException {
-        SelectionKey selectionKey = channel.getSelectionKey();
-        setWriteAndRead(channel);
+    public void testPostHandlingWillAddWriteIfNecessary() throws IOException {
+        NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class));
+        channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ));
+        ChannelContext context = mock(ChannelContext.class);
+        channel.setContexts(context, null);
+
+        when(context.hasQueuedWriteOps()).thenReturn(true);
+
+        assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps());
+        handler.postHandling(channel);
         assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps());
+    }
+
+    public void testPostHandlingWillRemoveWriteIfNecessary() throws IOException {
+        NioSocketChannel channel = new DoNotRegisterChannel(rawChannel, mock(SocketSelector.class));
+        channel.setSelectionKey(new TestSelectionKey(SelectionKey.OP_READ | SelectionKey.OP_WRITE));
+        ChannelContext context = mock(ChannelContext.class);
+        channel.setContexts(context, null);
 
-        handler.handleWrite(channel);
+        when(context.hasQueuedWriteOps()).thenReturn(false);
 
-        assertEquals(SelectionKey.OP_READ, selectionKey.interestOps());
+        assertEquals(SelectionKey.OP_READ | SelectionKey.OP_WRITE, channel.getSelectionKey().interestOps());
+        handler.postHandling(channel);
+        assertEquals(SelectionKey.OP_READ, channel.getSelectionKey().interestOps());
     }
 
     private void setWriteAndRead(NioChannel channel) {
@@ -152,10 +177,4 @@ public class SocketEventHandlerTests extends ESTestCase {
         SelectionKeyUtils.removeConnectInterested(channel);
         SelectionKeyUtils.setWriteInterested(channel);
     }
-
-    public void testWriteExceptionCallsExceptionHandler() throws IOException {
-        IOException exception = new IOException();
-        handler.writeException(channel, exception);
-        verify(exceptionHandler).accept(channel, exception);
-    }
 }

+ 27 - 21
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/SocketSelectorTests.java

@@ -49,7 +49,7 @@ public class SocketSelectorTests extends ESTestCase {
     private SocketEventHandler eventHandler;
     private NioSocketChannel channel;
     private TestSelectionKey selectionKey;
-    private WriteContext writeContext;
+    private ChannelContext channelContext;
     private BiConsumer<Void, Throwable> listener;
     private ByteBuffer[] buffers = {ByteBuffer.allocate(1)};
     private Selector rawSelector;
@@ -60,7 +60,7 @@ public class SocketSelectorTests extends ESTestCase {
         super.setUp();
         eventHandler = mock(SocketEventHandler.class);
         channel = mock(NioSocketChannel.class);
-        writeContext = mock(WriteContext.class);
+        channelContext = mock(ChannelContext.class);
         listener = mock(BiConsumer.class);
         selectionKey = new TestSelectionKey(0);
         selectionKey.attach(channel);
@@ -71,7 +71,7 @@ public class SocketSelectorTests extends ESTestCase {
 
         when(channel.isOpen()).thenReturn(true);
         when(channel.getSelectionKey()).thenReturn(selectionKey);
-        when(channel.getWriteContext()).thenReturn(writeContext);
+        when(channel.getContext()).thenReturn(channelContext);
         when(channel.isConnectComplete()).thenReturn(true);
         when(channel.getSelector()).thenReturn(socketSelector);
     }
@@ -129,75 +129,71 @@ public class SocketSelectorTests extends ESTestCase {
     public void testQueueWriteWhenNotRunning() throws Exception {
         socketSelector.close();
 
-        socketSelector.queueWrite(new WriteOperation(channel, buffers, listener));
+        socketSelector.queueWrite(new BytesWriteOperation(channel, buffers, listener));
 
         verify(listener).accept(isNull(Void.class), any(ClosedSelectorException.class));
     }
 
-    public void testQueueWriteChannelIsNoLongerWritable() throws Exception {
-        WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
+    public void testQueueWriteChannelIsClosed() throws Exception {
+        BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
         socketSelector.queueWrite(writeOperation);
 
-        when(channel.isWritable()).thenReturn(false);
+        when(channel.isOpen()).thenReturn(false);
         socketSelector.preSelect();
 
-        verify(writeContext, times(0)).queueWriteOperations(writeOperation);
+        verify(channelContext, times(0)).queueWriteOperation(writeOperation);
         verify(listener).accept(isNull(Void.class), any(ClosedChannelException.class));
     }
 
     public void testQueueWriteSelectionKeyThrowsException() throws Exception {
         SelectionKey selectionKey = mock(SelectionKey.class);
 
-        WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
+        BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
         CancelledKeyException cancelledKeyException = new CancelledKeyException();
         socketSelector.queueWrite(writeOperation);
 
-        when(channel.isWritable()).thenReturn(true);
         when(channel.getSelectionKey()).thenReturn(selectionKey);
         when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException);
         socketSelector.preSelect();
 
-        verify(writeContext, times(0)).queueWriteOperations(writeOperation);
+        verify(channelContext, times(0)).queueWriteOperation(writeOperation);
         verify(listener).accept(null, cancelledKeyException);
     }
 
     public void testQueueWriteSuccessful() throws Exception {
-        WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
+        BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
         socketSelector.queueWrite(writeOperation);
 
         assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0);
 
-        when(channel.isWritable()).thenReturn(true);
         socketSelector.preSelect();
 
-        verify(writeContext).queueWriteOperations(writeOperation);
+        verify(channelContext).queueWriteOperation(writeOperation);
         assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0);
     }
 
     public void testQueueDirectlyInChannelBufferSuccessful() throws Exception {
-        WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
+        BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
 
         assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) == 0);
 
-        when(channel.isWritable()).thenReturn(true);
         socketSelector.queueWriteInChannelBuffer(writeOperation);
 
-        verify(writeContext).queueWriteOperations(writeOperation);
+        verify(channelContext).queueWriteOperation(writeOperation);
         assertTrue((selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0);
     }
 
     public void testQueueDirectlyInChannelBufferSelectionKeyThrowsException() throws Exception {
         SelectionKey selectionKey = mock(SelectionKey.class);
 
-        WriteOperation writeOperation = new WriteOperation(channel, buffers, listener);
+        BytesWriteOperation writeOperation = new BytesWriteOperation(channel, buffers, listener);
         CancelledKeyException cancelledKeyException = new CancelledKeyException();
 
-        when(channel.isWritable()).thenReturn(true);
         when(channel.getSelectionKey()).thenReturn(selectionKey);
         when(selectionKey.interestOps(anyInt())).thenThrow(cancelledKeyException);
         socketSelector.queueWriteInChannelBuffer(writeOperation);
 
-        verify(writeContext, times(0)).queueWriteOperations(writeOperation);
+        verify(channelContext, times(0)).queueWriteOperation(writeOperation);
         verify(listener).accept(null, cancelledKeyException);
     }
 
@@ -285,6 +281,16 @@ public class SocketSelectorTests extends ESTestCase {
         verify(eventHandler).readException(channel, ioException);
     }
 
+    public void testWillCallPostHandleAfterChannelHandling() throws Exception {
+        selectionKey.setReadyOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ);
+
+        socketSelector.processKey(selectionKey);
+
+        verify(eventHandler).handleWrite(channel);
+        verify(eventHandler).handleRead(channel);
+        verify(eventHandler).postHandling(channel);
+    }
+
     public void testCleanup() throws Exception {
         NioSocketChannel unRegisteredChannel = mock(NioSocketChannel.class);
 
@@ -292,7 +298,7 @@ public class SocketSelectorTests extends ESTestCase {
 
         socketSelector.preSelect();
 
-        socketSelector.queueWrite(new WriteOperation(mock(NioSocketChannel.class), buffers, listener));
+        socketSelector.queueWrite(new BytesWriteOperation(mock(NioSocketChannel.class), buffers, listener));
         socketSelector.scheduleForRegistration(unRegisteredChannel);
 
         TestSelectionKey testSelectionKey = new TestSelectionKey(0);

+ 24 - 37
libs/elasticsearch-nio/src/test/java/org/elasticsearch/nio/WriteOperationTests.java

@@ -45,71 +45,58 @@ public class WriteOperationTests extends ESTestCase {
 
     }
 
-    public void testFlush() throws IOException {
+    public void testFullyFlushedMarker() {
         ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
-        WriteOperation writeOp = new WriteOperation(channel, buffers, listener);
+        BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener);
 
-
-        when(channel.write(any(ByteBuffer[].class))).thenReturn(10);
-
-        writeOp.flush();
+        writeOp.incrementIndex(10);
 
         assertTrue(writeOp.isFullyFlushed());
     }
 
-    public void testPartialFlush() throws IOException {
+    public void testPartiallyFlushedMarker() {
         ByteBuffer[] buffers = {ByteBuffer.allocate(10)};
-        WriteOperation writeOp = new WriteOperation(channel, buffers, listener);
-
-        when(channel.write(any(ByteBuffer[].class))).thenReturn(5);
+        BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener);
 
-        writeOp.flush();
+        writeOp.incrementIndex(5);
 
         assertFalse(writeOp.isFullyFlushed());
     }
 
     public void testMultipleFlushesWithCompositeBuffer() throws IOException {
         ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(15), ByteBuffer.allocate(3)};
-        WriteOperation writeOp = new WriteOperation(channel, buffers, listener);
+        BytesWriteOperation writeOp = new BytesWriteOperation(channel, buffers, listener);
 
         ArgumentCaptor<ByteBuffer[]> buffersCaptor = ArgumentCaptor.forClass(ByteBuffer[].class);
 
-        when(channel.write(buffersCaptor.capture())).thenReturn(5)
-            .thenReturn(5)
-            .thenReturn(2)
-            .thenReturn(15)
-            .thenReturn(1);
-
-        writeOp.flush();
-        assertFalse(writeOp.isFullyFlushed());
-        writeOp.flush();
+        writeOp.incrementIndex(5);
         assertFalse(writeOp.isFullyFlushed());
-        writeOp.flush();
-        assertFalse(writeOp.isFullyFlushed());
-        writeOp.flush();
-        assertFalse(writeOp.isFullyFlushed());
-        writeOp.flush();
-        assertTrue(writeOp.isFullyFlushed());
-
-        List<ByteBuffer[]> values = buffersCaptor.getAllValues();
-        ByteBuffer[] byteBuffers = values.get(0);
-        assertEquals(3, byteBuffers.length);
-        assertEquals(10, byteBuffers[0].remaining());
-
-        byteBuffers = values.get(1);
+        ByteBuffer[] byteBuffers = writeOp.getBuffersToWrite();
         assertEquals(3, byteBuffers.length);
         assertEquals(5, byteBuffers[0].remaining());
 
-        byteBuffers = values.get(2);
+        writeOp.incrementIndex(5);
+        assertFalse(writeOp.isFullyFlushed());
+        byteBuffers = writeOp.getBuffersToWrite();
         assertEquals(2, byteBuffers.length);
         assertEquals(15, byteBuffers[0].remaining());
 
-        byteBuffers = values.get(3);
+        writeOp.incrementIndex(2);
+        assertFalse(writeOp.isFullyFlushed());
+        byteBuffers = writeOp.getBuffersToWrite();
         assertEquals(2, byteBuffers.length);
         assertEquals(13, byteBuffers[0].remaining());
 
-        byteBuffers = values.get(4);
+        writeOp.incrementIndex(15);
+        assertFalse(writeOp.isFullyFlushed());
+        byteBuffers = writeOp.getBuffersToWrite();
         assertEquals(1, byteBuffers.length);
         assertEquals(1, byteBuffers[0].remaining());
+
+        writeOp.incrementIndex(1);
+        assertTrue(writeOp.isFullyFlushed());
+        byteBuffers = writeOp.getBuffersToWrite();
+        assertEquals(1, byteBuffers.length);
+        assertEquals(0, byteBuffers[0].remaining());
     }
 }

+ 24 - 14
plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java

@@ -33,13 +33,12 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.nio.AcceptingSelector;
 import org.elasticsearch.nio.AcceptorEventHandler;
-import org.elasticsearch.nio.BytesReadContext;
-import org.elasticsearch.nio.BytesWriteContext;
+import org.elasticsearch.nio.BytesChannelContext;
+import org.elasticsearch.nio.ChannelContext;
 import org.elasticsearch.nio.ChannelFactory;
 import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.NioGroup;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.ReadContext;
 import org.elasticsearch.nio.SocketEventHandler;
 import org.elasticsearch.nio.SocketSelector;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -72,12 +71,12 @@ public class NioTransport extends TcpTransport {
     public static final Setting<Integer> NIO_ACCEPTOR_COUNT =
         intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope);
 
-    private final PageCacheRecycler pageCacheRecycler;
+    protected final PageCacheRecycler pageCacheRecycler;
     private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
     private volatile NioGroup nioGroup;
     private volatile TcpChannelFactory clientChannelFactory;
 
-    NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
+    protected NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
                  PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
                  CircuitBreakerService circuitBreakerService) {
         super("nio", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
@@ -111,13 +110,13 @@ public class NioTransport extends TcpTransport {
                 NioTransport.NIO_WORKER_COUNT.get(settings), SocketEventHandler::new);
 
             ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
-            clientChannelFactory = new TcpChannelFactory(clientProfileSettings);
+            clientChannelFactory = channelFactory(clientProfileSettings, true);
 
             if (useNetworkServer) {
                 // loop through all profiles and start them up, special handling for default one
                 for (ProfileSettings profileSettings : profileSettings) {
                     String profileName = profileSettings.profileName;
-                    TcpChannelFactory factory = new TcpChannelFactory(profileSettings);
+                    TcpChannelFactory factory = channelFactory(profileSettings, false);
                     profileToChannelFactory.putIfAbsent(profileName, factory);
                     bindServer(profileSettings);
                 }
@@ -144,19 +143,30 @@ public class NioTransport extends TcpTransport {
         profileToChannelFactory.clear();
     }
 
-    private void exceptionCaught(NioSocketChannel channel, Exception exception) {
+    protected void exceptionCaught(NioSocketChannel channel, Exception exception) {
         onException((TcpChannel) channel, exception);
     }
 
-    private void acceptChannel(NioSocketChannel channel) {
+    protected void acceptChannel(NioSocketChannel channel) {
         serverAcceptedChannel((TcpNioSocketChannel) channel);
     }
 
-    private class TcpChannelFactory extends ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> {
+    protected TcpChannelFactory channelFactory(ProfileSettings settings, boolean isClient) {
+        return new TcpChannelFactoryImpl(settings);
+    }
+
+    protected abstract class TcpChannelFactory extends ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> {
+
+        protected TcpChannelFactory(RawChannelFactory rawChannelFactory) {
+            super(rawChannelFactory);
+        }
+    }
+
+    private class TcpChannelFactoryImpl extends TcpChannelFactory {
 
         private final String profileName;
 
-        TcpChannelFactory(TcpTransport.ProfileSettings profileSettings) {
+        private TcpChannelFactoryImpl(ProfileSettings profileSettings) {
             super(new RawChannelFactory(profileSettings.tcpNoDelay,
                 profileSettings.tcpKeepAlive,
                 profileSettings.reuseAddress,
@@ -172,10 +182,10 @@ public class NioTransport extends TcpTransport {
                 Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
                 return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
             };
-            ReadContext.ReadConsumer nioReadConsumer = channelBuffer ->
+            ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
                 consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
-            BytesReadContext readContext = new BytesReadContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
-            nioChannel.setContexts(readContext, new BytesWriteContext(nioChannel), NioTransport.this::exceptionCaught);
+            BytesChannelContext context = new BytesChannelContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
+            nioChannel.setContexts(context, NioTransport.this::exceptionCaught);
             return nioChannel;
         }
 

+ 6 - 1
plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioServerSocketChannel.java

@@ -38,7 +38,7 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements
 
     private final String profile;
 
-    TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel,
+    public TcpNioServerSocketChannel(String profile, ServerSocketChannel socketChannel,
                               ChannelFactory<TcpNioServerSocketChannel, TcpNioSocketChannel> channelFactory,
                               AcceptingSelector selector) throws IOException {
         super(socketChannel, channelFactory, selector);
@@ -60,6 +60,11 @@ public class TcpNioServerSocketChannel extends NioServerSocketChannel implements
         return null;
     }
 
+    @Override
+    public void close() {
+        getSelector().queueChannelClose(this);
+    }
+
     @Override
     public String getProfile() {
         return profile;

+ 7 - 2
plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/TcpNioSocketChannel.java

@@ -33,13 +33,13 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel
 
     private final String profile;
 
-    TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException {
+    public TcpNioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException {
         super(socketChannel, selector);
         this.profile = profile;
     }
 
     public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
-        getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
+        getContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
     }
 
     @Override
@@ -59,6 +59,11 @@ public class TcpNioSocketChannel extends NioSocketChannel implements TcpChannel
         addCloseListener(ActionListener.toBiConsumer(listener));
     }
 
+    @Override
+    public void close() {
+        getContext().closeChannel();
+    }
+
     @Override
     public String toString() {
         return "TcpNioSocketChannel{" +

+ 16 - 8
test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java

@@ -31,14 +31,13 @@ import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.nio.AcceptingSelector;
 import org.elasticsearch.nio.AcceptorEventHandler;
-import org.elasticsearch.nio.BytesReadContext;
-import org.elasticsearch.nio.BytesWriteContext;
+import org.elasticsearch.nio.BytesChannelContext;
+import org.elasticsearch.nio.ChannelContext;
 import org.elasticsearch.nio.ChannelFactory;
 import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.NioGroup;
 import org.elasticsearch.nio.NioServerSocketChannel;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.ReadContext;
 import org.elasticsearch.nio.SocketSelector;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TcpChannel;
@@ -162,11 +161,10 @@ public class MockNioTransport extends TcpTransport {
                 Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
                 return new InboundChannelBuffer.Page(ByteBuffer.wrap(bytes.v()), bytes::close);
             };
-            ReadContext.ReadConsumer nioReadConsumer = channelBuffer ->
+            ChannelContext.ReadConsumer nioReadConsumer = channelBuffer ->
                 consumeNetworkReads(nioChannel, BytesReference.fromByteBuffers(channelBuffer.sliceBuffersTo(channelBuffer.getIndex())));
-            BytesReadContext readContext = new BytesReadContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
-            BytesWriteContext writeContext = new BytesWriteContext(nioChannel);
-            nioChannel.setContexts(readContext, writeContext, MockNioTransport.this::exceptionCaught);
+            BytesChannelContext context = new BytesChannelContext(nioChannel, nioReadConsumer, new InboundChannelBuffer(pageSupplier));
+            nioChannel.setContexts(context, MockNioTransport.this::exceptionCaught);
             return nioChannel;
         }
 
@@ -188,6 +186,11 @@ public class MockNioTransport extends TcpTransport {
             this.profile = profile;
         }
 
+        @Override
+        public void close() {
+            getSelector().queueChannelClose(this);
+        }
+
         @Override
         public String getProfile() {
             return profile;
@@ -224,6 +227,11 @@ public class MockNioTransport extends TcpTransport {
             this.profile = profile;
         }
 
+        @Override
+        public void close() {
+            getContext().closeChannel();
+        }
+
         @Override
         public String getProfile() {
             return profile;
@@ -243,7 +251,7 @@ public class MockNioTransport extends TcpTransport {
 
         @Override
         public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
-            getWriteContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
+            getContext().sendMessage(BytesReference.toByteBuffers(reference), ActionListener.toBiConsumer(listener));
         }
     }
 }