Pārlūkot izejas kodu

Read multiple TLS packets in one read call (#41725)

This is related to #27260. Currently we have a single read buffer that
is no larger than a single TLS packet. This prevents us from reading
multiple TLS packets in a single socket read call. This commit modifies
our TLS work to support reading similar to the plaintext case. The data
will be copied to a (potentially) recycled TLS packet-sized buffer for
interaction with the SSLEngine.
Tim Brooks 6 gadi atpakaļ
vecāks
revīzija
cb2bd0bb6b
15 mainītis faili ar 406 papildinājumiem un 375 dzēšanām
  1. 8 11
      libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java
  2. 3 52
      libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java
  3. 63 0
      libs/nio/src/main/java/org/elasticsearch/nio/utils/ByteBufferUtils.java
  4. 59 57
      libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java
  5. 7 65
      libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java
  6. 4 10
      plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java
  7. 3 9
      plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java
  8. 48 0
      plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/PageAllocator.java
  9. 9 5
      test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java
  10. 9 6
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java
  11. 53 55
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java
  12. 6 12
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java
  13. 6 12
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java
  14. 22 12
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java
  15. 106 69
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java

+ 8 - 11
libs/nio/src/main/java/org/elasticsearch/nio/InboundChannelBuffer.java

@@ -27,7 +27,7 @@ import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Supplier;
+import java.util.function.IntFunction;
 
 /**
  * This is a channel byte buffer composed internally of 16kb pages. When an entire message has been read
@@ -37,15 +37,14 @@ import java.util.function.Supplier;
  */
 public final class InboundChannelBuffer implements AutoCloseable {
 
-    private static final int PAGE_SIZE = 1 << 14;
+    public static final int PAGE_SIZE = 1 << 14;
     private static final int PAGE_MASK = PAGE_SIZE - 1;
     private static final int PAGE_SHIFT = Integer.numberOfTrailingZeros(PAGE_SIZE);
     private static final ByteBuffer[] EMPTY_BYTE_BUFFER_ARRAY = new ByteBuffer[0];
     private static final Page[] EMPTY_BYTE_PAGE_ARRAY = new Page[0];
 
-
-    private final ArrayDeque<Page> pages;
-    private final Supplier<Page> pageSupplier;
+    private final IntFunction<Page> pageAllocator;
+    private final ArrayDeque<Page> pages = new ArrayDeque<>();
     private final AtomicBoolean isClosed = new AtomicBoolean(false);
 
     private long capacity = 0;
@@ -53,14 +52,12 @@ public final class InboundChannelBuffer implements AutoCloseable {
     // The offset is an int as it is the offset of where the bytes begin in the first buffer
     private int offset = 0;
 
-    public InboundChannelBuffer(Supplier<Page> pageSupplier) {
-        this.pageSupplier = pageSupplier;
-        this.pages = new ArrayDeque<>();
-        this.capacity = PAGE_SIZE * pages.size();
+    public InboundChannelBuffer(IntFunction<Page> pageAllocator) {
+        this.pageAllocator = pageAllocator;
     }
 
     public static InboundChannelBuffer allocatingInstance() {
-        return new InboundChannelBuffer(() -> new Page(ByteBuffer.allocate(PAGE_SIZE), () -> {}));
+        return new InboundChannelBuffer((n) -> new Page(ByteBuffer.allocate(n), () -> {}));
     }
 
     @Override
@@ -87,7 +84,7 @@ public final class InboundChannelBuffer implements AutoCloseable {
             int numPages = numPages(requiredCapacity + offset);
             int pagesToAdd = numPages - pages.size();
             for (int i = 0; i < pagesToAdd; i++) {
-                Page page = pageSupplier.get();
+                Page page = pageAllocator.apply(PAGE_SIZE);
                 pages.addLast(page);
             }
             capacity += pagesToAdd * PAGE_SIZE;

+ 3 - 52
libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java

@@ -20,6 +20,7 @@
 package org.elasticsearch.nio;
 
 import org.elasticsearch.common.concurrent.CompletableContext;
+import org.elasticsearch.nio.utils.ByteBufferUtils;
 import org.elasticsearch.nio.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -249,26 +250,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
     // data that is copied to the buffer for a write, but not successfully flushed immediately, must be
     // copied again on the next call.
 
-    protected int readFromChannel(ByteBuffer buffer) throws IOException {
-        ByteBuffer ioBuffer = getSelector().getIoBuffer();
-        ioBuffer.limit(Math.min(buffer.remaining(), ioBuffer.limit()));
-        int bytesRead;
-        try {
-            bytesRead = rawChannel.read(ioBuffer);
-        } catch (IOException e) {
-            closeNow = true;
-            throw e;
-        }
-        if (bytesRead < 0) {
-            closeNow = true;
-            return 0;
-        } else {
-            ioBuffer.flip();
-            buffer.put(ioBuffer);
-            return bytesRead;
-        }
-    }
-
     protected int readFromChannel(InboundChannelBuffer channelBuffer) throws IOException {
         ByteBuffer ioBuffer = getSelector().getIoBuffer();
         int bytesRead;
@@ -288,7 +269,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
             int j = 0;
             while (j < buffers.length && ioBuffer.remaining() > 0) {
                 ByteBuffer buffer = buffers[j++];
-                copyBytes(ioBuffer, buffer);
+                ByteBufferUtils.copyBytes(ioBuffer, buffer);
             }
             channelBuffer.incrementIndex(bytesRead);
             return bytesRead;
@@ -299,24 +280,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
     // copying.
     private final int WRITE_LIMIT = 1 << 16;
 
-    protected int flushToChannel(ByteBuffer buffer) throws IOException {
-        int initialPosition = buffer.position();
-        ByteBuffer ioBuffer = getSelector().getIoBuffer();
-        ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
-        copyBytes(buffer, ioBuffer);
-        ioBuffer.flip();
-        int bytesWritten;
-        try {
-            bytesWritten = rawChannel.write(ioBuffer);
-        } catch (IOException e) {
-            closeNow = true;
-            buffer.position(initialPosition);
-            throw e;
-        }
-        buffer.position(initialPosition + bytesWritten);
-        return bytesWritten;
-    }
-
     protected int flushToChannel(FlushOperation flushOperation) throws IOException {
         ByteBuffer ioBuffer = getSelector().getIoBuffer();
 
@@ -325,12 +288,8 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
         while (continueFlush) {
             ioBuffer.clear();
             ioBuffer.limit(Math.min(WRITE_LIMIT, ioBuffer.limit()));
-            int j = 0;
             ByteBuffer[] buffers = flushOperation.getBuffersToWrite(WRITE_LIMIT);
-            while (j < buffers.length && ioBuffer.remaining() > 0) {
-                ByteBuffer buffer = buffers[j++];
-                copyBytes(buffer, ioBuffer);
-            }
+            ByteBufferUtils.copyBytes(buffers, ioBuffer);
             ioBuffer.flip();
             int bytesFlushed;
             try {
@@ -345,12 +304,4 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
         }
         return totalBytesFlushed;
     }
-
-    private void copyBytes(ByteBuffer from, ByteBuffer to) {
-        int nBytesToCopy = Math.min(to.remaining(), from.remaining());
-        int initialLimit = from.limit();
-        from.limit(from.position() + nBytesToCopy);
-        to.put(from);
-        from.limit(initialLimit);
-    }
 }

+ 63 - 0
libs/nio/src/main/java/org/elasticsearch/nio/utils/ByteBufferUtils.java

@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.nio.utils;
+
+import java.nio.ByteBuffer;
+
+public final class ByteBufferUtils {
+
+    private ByteBufferUtils() {}
+
+    /**
+     * Copies bytes from the array of byte buffers into the destination buffer. The number of bytes copied is
+     * limited by the bytes available to copy and the space remaining in the destination byte buffer.
+     *
+     * @param source byte buffers to copy from
+     * @param destination byte buffer to copy to
+     *
+     * @return number of bytes copied
+     */
+    public static long copyBytes(ByteBuffer[] source, ByteBuffer destination) {
+        long bytesCopied = 0;
+        for (int i = 0; i < source.length && destination.hasRemaining(); i++) {
+            ByteBuffer buffer = source[i];
+            bytesCopied += copyBytes(buffer, destination);
+        }
+        return bytesCopied;
+    }
+
+    /**
+     * Copies bytes from source byte buffer into the destination buffer. The number of bytes copied is
+     * limited by the bytes available to copy and the space remaining in the destination byte buffer.
+     *
+     * @param source byte buffer to copy from
+     * @param destination byte buffer to copy to
+     *
+     * @return number of bytes copied
+     */
+    public static int copyBytes(ByteBuffer source, ByteBuffer destination) {
+        int nBytesToCopy = Math.min(destination.remaining(), source.remaining());
+        int initialLimit = source.limit();
+        source.limit(source.position() + nBytesToCopy);
+        destination.put(source);
+        source.limit(initialLimit);
+        return nBytesToCopy;
+    }
+}

+ 59 - 57
libs/nio/src/test/java/org/elasticsearch/nio/InboundChannelBufferTests.java

@@ -19,23 +19,25 @@
 
 package org.elasticsearch.nio;
 
-import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.test.ESTestCase;
 
 import java.nio.ByteBuffer;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Supplier;
+import java.util.function.IntFunction;
 
 public class InboundChannelBufferTests extends ESTestCase {
 
-    private static final int PAGE_SIZE = PageCacheRecycler.PAGE_SIZE_IN_BYTES;
-    private final Supplier<Page> defaultPageSupplier = () ->
-        new Page(ByteBuffer.allocate(PageCacheRecycler.BYTE_PAGE_SIZE), () -> {
-        });
+    private IntFunction<Page> defaultPageAllocator;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        defaultPageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {});
+    }
 
     public void testNewBufferNoPages() {
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
 
         assertEquals(0, channelBuffer.getCapacity());
         assertEquals(0, channelBuffer.getRemaining());
@@ -43,107 +45,107 @@ public class InboundChannelBufferTests extends ESTestCase {
     }
 
     public void testExpandCapacity() {
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
         assertEquals(0, channelBuffer.getCapacity());
         assertEquals(0, channelBuffer.getRemaining());
 
-        channelBuffer.ensureCapacity(PAGE_SIZE);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
 
-        assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
-        assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
 
-        channelBuffer.ensureCapacity(PAGE_SIZE + 1);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1);
 
-        assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity());
-        assertEquals(PAGE_SIZE * 2, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getRemaining());
     }
 
     public void testExpandCapacityMultiplePages() {
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
-        channelBuffer.ensureCapacity(PAGE_SIZE);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
 
-        assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
 
         int multiple = randomInt(80);
-        channelBuffer.ensureCapacity(PAGE_SIZE + ((multiple * PAGE_SIZE) - randomInt(500)));
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + ((multiple * InboundChannelBuffer.PAGE_SIZE) - randomInt(500)));
 
-        assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity());
-        assertEquals(PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * (multiple + 1), channelBuffer.getRemaining());
     }
 
     public void testExpandCapacityRespectsOffset() {
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
-        channelBuffer.ensureCapacity(PAGE_SIZE);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
 
-        assertEquals(PAGE_SIZE, channelBuffer.getCapacity());
-        assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
 
         int offset = randomInt(300);
 
         channelBuffer.release(offset);
 
-        assertEquals(PAGE_SIZE - offset, channelBuffer.getCapacity());
-        assertEquals(PAGE_SIZE - offset, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE - offset, channelBuffer.getRemaining());
 
-        channelBuffer.ensureCapacity(PAGE_SIZE + 1);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE + 1);
 
-        assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getCapacity());
-        assertEquals(PAGE_SIZE * 2 - offset, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * 2 - offset, channelBuffer.getRemaining());
     }
 
     public void testIncrementIndex() {
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
-        channelBuffer.ensureCapacity(PAGE_SIZE);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
 
         assertEquals(0, channelBuffer.getIndex());
-        assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
 
         channelBuffer.incrementIndex(10);
 
         assertEquals(10, channelBuffer.getIndex());
-        assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining());
     }
 
     public void testIncrementIndexWithOffset() {
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
-        channelBuffer.ensureCapacity(PAGE_SIZE);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE);
 
         assertEquals(0, channelBuffer.getIndex());
-        assertEquals(PAGE_SIZE, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE, channelBuffer.getRemaining());
 
         channelBuffer.release(10);
-        assertEquals(PAGE_SIZE - 10, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE - 10, channelBuffer.getRemaining());
 
         channelBuffer.incrementIndex(10);
 
         assertEquals(10, channelBuffer.getIndex());
-        assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining());
 
         channelBuffer.release(2);
         assertEquals(8, channelBuffer.getIndex());
-        assertEquals(PAGE_SIZE - 20, channelBuffer.getRemaining());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE - 20, channelBuffer.getRemaining());
     }
 
     public void testReleaseClosesPages() {
         ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
-        Supplier<Page> supplier = () -> {
+        IntFunction<Page> allocator = (n) -> {
             AtomicBoolean atomicBoolean = new AtomicBoolean();
             queue.add(atomicBoolean);
-            return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
+            return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
         };
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
-        channelBuffer.ensureCapacity(PAGE_SIZE * 4);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
 
-        assertEquals(PAGE_SIZE * 4, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * 4, channelBuffer.getCapacity());
         assertEquals(4, queue.size());
 
         for (AtomicBoolean closedRef : queue) {
             assertFalse(closedRef.get());
         }
 
-        channelBuffer.release(2 * PAGE_SIZE);
+        channelBuffer.release(2 * InboundChannelBuffer.PAGE_SIZE);
 
-        assertEquals(PAGE_SIZE * 2, channelBuffer.getCapacity());
+        assertEquals(InboundChannelBuffer.PAGE_SIZE * 2, channelBuffer.getCapacity());
 
         assertTrue(queue.poll().get());
         assertTrue(queue.poll().get());
@@ -153,13 +155,13 @@ public class InboundChannelBufferTests extends ESTestCase {
 
     public void testClose() {
         ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
-        Supplier<Page> supplier = () -> {
+        IntFunction<Page> allocator = (n) -> {
             AtomicBoolean atomicBoolean = new AtomicBoolean();
             queue.add(atomicBoolean);
-            return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
+            return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
         };
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
-        channelBuffer.ensureCapacity(PAGE_SIZE * 4);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
 
         assertEquals(4, queue.size());
 
@@ -178,13 +180,13 @@ public class InboundChannelBufferTests extends ESTestCase {
 
     public void testCloseRetainedPages() {
         ConcurrentLinkedQueue<AtomicBoolean> queue = new ConcurrentLinkedQueue<>();
-        Supplier<Page> supplier = () -> {
+        IntFunction<Page> allocator = (n) -> {
             AtomicBoolean atomicBoolean = new AtomicBoolean();
             queue.add(atomicBoolean);
-            return new Page(ByteBuffer.allocate(PAGE_SIZE), () -> atomicBoolean.set(true));
+            return new Page(ByteBuffer.allocate(n), () -> atomicBoolean.set(true));
         };
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(supplier);
-        channelBuffer.ensureCapacity(PAGE_SIZE * 4);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(allocator);
+        channelBuffer.ensureCapacity(InboundChannelBuffer.PAGE_SIZE * 4);
 
         assertEquals(4, queue.size());
 
@@ -192,7 +194,7 @@ public class InboundChannelBufferTests extends ESTestCase {
             assertFalse(closedRef.get());
         }
 
-        Page[] pages = channelBuffer.sliceAndRetainPagesTo(PAGE_SIZE * 2);
+        Page[] pages = channelBuffer.sliceAndRetainPagesTo(InboundChannelBuffer.PAGE_SIZE * 2);
 
         pages[1].close();
 
@@ -220,10 +222,10 @@ public class InboundChannelBufferTests extends ESTestCase {
     }
 
     public void testAccessByteBuffers() {
-        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageSupplier);
+        InboundChannelBuffer channelBuffer = new InboundChannelBuffer(defaultPageAllocator);
 
         int pages = randomInt(50) + 5;
-        channelBuffer.ensureCapacity(pages * PAGE_SIZE);
+        channelBuffer.ensureCapacity(pages * InboundChannelBuffer.PAGE_SIZE);
 
         long capacity = channelBuffer.getCapacity();
 

+ 7 - 65
libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java

@@ -34,8 +34,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
+import java.util.function.IntFunction;
 import java.util.function.Predicate;
-import java.util.function.Supplier;
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
@@ -285,8 +285,8 @@ public class SocketChannelContextTests extends ESTestCase {
             when(channel.getRawChannel()).thenReturn(realChannel);
             when(channel.isOpen()).thenReturn(true);
             Runnable closer = mock(Runnable.class);
-            Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), closer);
-            InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
+            IntFunction<Page> pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), closer);
+            InboundChannelBuffer buffer = new InboundChannelBuffer(pageAllocator);
             buffer.ensureCapacity(1);
             TestSocketChannelContext context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, buffer);
             context.closeFromSelector();
@@ -294,29 +294,6 @@ public class SocketChannelContextTests extends ESTestCase {
         }
     }
 
-    public void testReadToBufferLimitsToPassedBuffer() throws IOException {
-        ByteBuffer buffer = ByteBuffer.allocate(10);
-        when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
-
-        int bytesRead = context.readFromChannel(buffer);
-        assertEquals(bytesRead, 10);
-        assertEquals(0, buffer.remaining());
-    }
-
-    public void testReadToBufferHandlesIOException() throws IOException {
-        when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
-
-        expectThrows(IOException.class, () -> context.readFromChannel(ByteBuffer.allocate(10)));
-        assertTrue(context.closeNow());
-    }
-
-    public void testReadToBufferHandlesEOF() throws IOException {
-        when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
-
-        context.readFromChannel(ByteBuffer.allocate(10));
-        assertTrue(context.closeNow());
-    }
-
     public void testReadToChannelBufferWillReadAsMuchAsIOBufferAllows() throws IOException {
         when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
 
@@ -344,33 +321,6 @@ public class SocketChannelContextTests extends ESTestCase {
         assertEquals(0, channelBuffer.getIndex());
     }
 
-    public void testFlushBufferHandlesPartialFlush() throws IOException {
-        int bytesToConsume = 3;
-        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
-
-        ByteBuffer buffer = ByteBuffer.allocate(10);
-        context.flushToChannel(buffer);
-        assertEquals(10 - bytesToConsume, buffer.remaining());
-    }
-
-    public void testFlushBufferHandlesFullFlush() throws IOException {
-        int bytesToConsume = 10;
-        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
-
-        ByteBuffer buffer = ByteBuffer.allocate(10);
-        context.flushToChannel(buffer);
-        assertEquals(0, buffer.remaining());
-    }
-
-    public void testFlushBufferHandlesIOException() throws IOException {
-        when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
-
-        ByteBuffer buffer = ByteBuffer.allocate(10);
-        expectThrows(IOException.class, () -> context.flushToChannel(buffer));
-        assertTrue(context.closeNow());
-        assertEquals(10, buffer.remaining());
-    }
-
     public void testFlushBuffersHandlesZeroFlush() throws IOException {
         when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(0));
 
@@ -456,22 +406,14 @@ public class SocketChannelContextTests extends ESTestCase {
 
         @Override
         public int read() throws IOException {
-            if (randomBoolean()) {
-                InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
-                return readFromChannel(channelBuffer);
-            } else {
-                return readFromChannel(ByteBuffer.allocate(10));
-            }
+            InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
+            return readFromChannel(channelBuffer);
         }
 
         @Override
         public void flushChannel() throws IOException {
-            if (randomBoolean()) {
-                ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
-                flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
-            } else {
-                flushToChannel(ByteBuffer.allocate(10));
-            }
+            ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
+            flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
         }
 
         @Override

+ 4 - 10
plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java

@@ -25,7 +25,6 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.network.NetworkService;
-import org.elasticsearch.common.recycler.Recycler;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsException;
 import org.elasticsearch.common.unit.ByteSizeValue;
@@ -43,16 +42,15 @@ import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.NioGroup;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.Page;
 import org.elasticsearch.nio.ServerChannelContext;
 import org.elasticsearch.nio.SocketChannelContext;
 import org.elasticsearch.rest.RestUtils;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.nio.NioGroupFactory;
+import org.elasticsearch.transport.nio.PageAllocator;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
 import java.util.Arrays;
@@ -80,8 +78,8 @@ import static org.elasticsearch.http.nio.cors.NioCorsHandler.ANY_ORIGIN;
 public class NioHttpServerTransport extends AbstractHttpServerTransport {
     private static final Logger logger = LogManager.getLogger(NioHttpServerTransport.class);
 
-    protected final PageCacheRecycler pageCacheRecycler;
     protected final NioCorsConfig corsConfig;
+    protected final PageAllocator pageAllocator;
     private final NioGroupFactory nioGroupFactory;
 
     protected final boolean tcpNoDelay;
@@ -97,7 +95,7 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
                                   PageCacheRecycler pageCacheRecycler, ThreadPool threadPool, NamedXContentRegistry xContentRegistry,
                                   Dispatcher dispatcher, NioGroupFactory nioGroupFactory) {
         super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher);
-        this.pageCacheRecycler = pageCacheRecycler;
+        this.pageAllocator = new PageAllocator(pageCacheRecycler);
         this.nioGroupFactory = nioGroupFactory;
 
         ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
@@ -206,15 +204,11 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
         @Override
         public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
             NioHttpChannel httpChannel = new NioHttpChannel(channel);
-            java.util.function.Supplier<Page> pageSupplier = () -> {
-                Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
-                return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
-            };
             HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
                 handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
             Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
             SocketChannelContext context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpReadWritePipeline,
-                new InboundChannelBuffer(pageSupplier));
+                new InboundChannelBuffer(pageAllocator));
             httpChannel.setContext(context);
             return httpChannel;
         }

+ 3 - 9
plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/NioTransport.java

@@ -26,7 +26,6 @@ import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkService;
-import org.elasticsearch.common.recycler.Recycler;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -36,20 +35,17 @@ import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.NioGroup;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.Page;
 import org.elasticsearch.nio.ServerChannelContext;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TcpTransport;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
 import java.util.concurrent.ConcurrentMap;
 import java.util.function.Consumer;
 import java.util.function.Function;
-import java.util.function.Supplier;
 
 import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
 
@@ -57,6 +53,7 @@ public class NioTransport extends TcpTransport {
 
     private static final Logger logger = LogManager.getLogger(NioTransport.class);
 
+    protected final PageAllocator pageAllocator;
     private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
     private final NioGroupFactory groupFactory;
     private volatile NioGroup nioGroup;
@@ -66,6 +63,7 @@ public class NioTransport extends TcpTransport {
                            PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
                            CircuitBreakerService circuitBreakerService, NioGroupFactory groupFactory) {
         super(settings, version, threadPool, pageCacheRecycler, circuitBreakerService, namedWriteableRegistry, networkService);
+        this.pageAllocator = new PageAllocator(pageCacheRecycler);
         this.groupFactory = groupFactory;
     }
 
@@ -158,14 +156,10 @@ public class NioTransport extends TcpTransport {
         @Override
         public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) {
             NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
-            Supplier<Page> pageSupplier = () -> {
-                Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
-                return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
-            };
             TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, NioTransport.this);
             Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
             BytesChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler,
-                new InboundChannelBuffer(pageSupplier));
+                new InboundChannelBuffer(pageAllocator));
             nioChannel.setContext(context);
             return nioChannel;
         }

+ 48 - 0
plugins/transport-nio/src/main/java/org/elasticsearch/transport/nio/PageAllocator.java

@@ -0,0 +1,48 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.transport.nio;
+
+import org.elasticsearch.common.recycler.Recycler;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.nio.Page;
+
+import java.nio.ByteBuffer;
+import java.util.function.IntFunction;
+
+public class PageAllocator implements IntFunction<Page> {
+
+    private static final int RECYCLE_LOWER_THRESHOLD = PageCacheRecycler.BYTE_PAGE_SIZE / 2;
+
+    private final PageCacheRecycler recycler;
+
+    public PageAllocator(PageCacheRecycler recycler) {
+        this.recycler = recycler;
+    }
+
+    @Override
+    public Page apply(int length) {
+        if (length >= RECYCLE_LOWER_THRESHOLD && length <= PageCacheRecycler.BYTE_PAGE_SIZE){
+            Recycler.V<byte[]> bytePage = recycler.bytePage(false);
+            return new Page(ByteBuffer.wrap(bytePage.v(), 0, length), bytePage::close);
+        } else {
+            return new Page(ByteBuffer.allocate(length), () -> {});
+        }
+    }
+}

+ 9 - 5
test/framework/src/main/java/org/elasticsearch/transport/nio/MockNioTransport.java

@@ -37,8 +37,8 @@ import org.elasticsearch.nio.BytesChannelContext;
 import org.elasticsearch.nio.BytesWriteHandler;
 import org.elasticsearch.nio.ChannelFactory;
 import org.elasticsearch.nio.InboundChannelBuffer;
-import org.elasticsearch.nio.NioSelectorGroup;
 import org.elasticsearch.nio.NioSelector;
+import org.elasticsearch.nio.NioSelectorGroup;
 import org.elasticsearch.nio.NioServerSocketChannel;
 import org.elasticsearch.nio.NioSocketChannel;
 import org.elasticsearch.nio.Page;
@@ -61,7 +61,7 @@ import java.util.HashSet;
 import java.util.Set;
 import java.util.concurrent.ConcurrentMap;
 import java.util.function.Consumer;
-import java.util.function.Supplier;
+import java.util.function.IntFunction;
 
 import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
 import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
@@ -192,9 +192,13 @@ public class MockNioTransport extends TcpTransport {
         @Override
         public MockSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
             MockSocketChannel nioChannel = new MockSocketChannel(isClient == false, profileName, channel);
-            Supplier<Page> pageSupplier = () -> {
-                Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
-                return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
+            IntFunction<Page> pageSupplier = (length) -> {
+                if (length > PageCacheRecycler.BYTE_PAGE_SIZE) {
+                    return new Page(ByteBuffer.allocate(length), () -> {});
+                } else {
+                    Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
+                    return new Page(ByteBuffer.wrap(bytes.v(), 0, length), bytes::close);
+                }
             };
             MockTcpReadWriteHandler readWriteHandler = new MockTcpReadWriteHandler(nioChannel, MockNioTransport.this);
             BytesChannelContext context = new BytesChannelContext(nioChannel, selector, (e) -> exceptionCaught(nioChannel, e),

+ 9 - 6
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java

@@ -36,19 +36,22 @@ public final class SSLChannelContext extends SocketChannelContext {
     private static final Runnable DEFAULT_TIMEOUT_CANCELLER = () -> {};
 
     private final SSLDriver sslDriver;
+    private final InboundChannelBuffer networkReadBuffer;
     private final LinkedList<FlushOperation> encryptedFlushes = new LinkedList<>();
     private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
 
     SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
-                      ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
-        this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
+                      ReadWriteHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
+        this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
+            applicationBuffer, ALWAYS_ALLOW_CHANNEL);
     }
 
     SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
-                      ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
+                      ReadWriteHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer,
                       Predicate<NioSocketChannel> allowChannelPredicate) {
         super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
         this.sslDriver = sslDriver;
+        this.networkReadBuffer = networkReadBuffer;
     }
 
     @Override
@@ -157,12 +160,12 @@ public final class SSLChannelContext extends SocketChannelContext {
         if (closeNow()) {
             return bytesRead;
         }
-        bytesRead = readFromChannel(sslDriver.getNetworkReadBuffer());
+        bytesRead = readFromChannel(networkReadBuffer);
         if (bytesRead == 0) {
             return bytesRead;
         }
 
-        sslDriver.read(channelBuffer);
+        sslDriver.read(networkReadBuffer, channelBuffer);
 
         handleReadBytes();
         // It is possible that a read call produced non-application bytes to flush
@@ -201,7 +204,7 @@ public final class SSLChannelContext extends SocketChannelContext {
                 getSelector().executeFailedListener(encryptedFlush.getListener(), new ClosedChannelException());
             }
             encryptedFlushes.clear();
-            IOUtils.close(super::closeFromSelector, sslDriver::close);
+            IOUtils.close(super::closeFromSelector, networkReadBuffer::close, sslDriver::close);
         }
     }
 

+ 53 - 55
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLDriver.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.security.transport.nio;
 import org.elasticsearch.nio.FlushOperation;
 import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.Page;
+import org.elasticsearch.nio.utils.ByteBufferUtils;
 import org.elasticsearch.nio.utils.ExceptionsHelper;
 
 import javax.net.ssl.SSLEngine;
@@ -16,6 +17,7 @@ import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLSession;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.function.IntFunction;
 
 /**
  * SSLDriver is a class that wraps the {@link SSLEngine} and attempts to simplify the API. The basic usage is
@@ -27,9 +29,9 @@ import java.util.ArrayList;
  * application to be written to the wire.
  *
  * Handling reads from a channel with this class is very simple. When data has been read, call
- * {@link #read(InboundChannelBuffer)}. If the data is application data, it will be decrypted and placed into
- * the buffer passed as an argument. Otherwise, it will be consumed internally and advance the SSL/TLS close
- * or handshake process.
+ * {@link #read(InboundChannelBuffer, InboundChannelBuffer)}. If the data is application data, it will be
+ * decrypted and placed into the application buffer passed as an argument. Otherwise, it will be consumed
+ * internally and advance the SSL/TLS close or handshake process.
  *
  * Producing writes for a channel is more complicated. The method {@link #needsNonApplicationWrite()} can be
  * called to determine if this driver needs to produce more data to advance the handshake or close process.
@@ -54,21 +56,22 @@ public class SSLDriver implements AutoCloseable {
     private static final FlushOperation EMPTY_FLUSH_OPERATION = new FlushOperation(EMPTY_BUFFERS, (r, t) -> {});
 
     private final SSLEngine engine;
-    // TODO: When the bytes are actually recycled, we need to test that they are released on driver close
-    private final SSLOutboundBuffer outboundBuffer = new SSLOutboundBuffer((n) -> new Page(ByteBuffer.allocate(n)));
+    private final IntFunction<Page> pageAllocator;
+    private final SSLOutboundBuffer outboundBuffer;
+    private Page networkReadPage;
     private final boolean isClientMode;
     // This should only be accessed by the network thread associated with this channel, so nothing needs to
     // be volatile.
     private Mode currentMode = new HandshakeMode();
-    private ByteBuffer networkReadBuffer;
     private int packetSize;
 
-    public SSLDriver(SSLEngine engine, boolean isClientMode) {
+    public SSLDriver(SSLEngine engine, IntFunction<Page> pageAllocator, boolean isClientMode) {
         this.engine = engine;
+        this.pageAllocator = pageAllocator;
+        this.outboundBuffer = new SSLOutboundBuffer(pageAllocator);
         this.isClientMode = isClientMode;
         SSLSession session = engine.getSession();
         packetSize = session.getPacketBufferSize();
-        this.networkReadBuffer = ByteBuffer.allocate(packetSize);
     }
 
     public void init() throws SSLException {
@@ -106,22 +109,25 @@ public class SSLDriver implements AutoCloseable {
         return currentMode.isHandshake();
     }
 
-    public ByteBuffer getNetworkReadBuffer() {
-        return networkReadBuffer;
-    }
-
     public SSLOutboundBuffer getOutboundBuffer() {
         return outboundBuffer;
     }
 
-    public void read(InboundChannelBuffer buffer) throws SSLException {
-        Mode modePriorToRead;
-        do {
-            modePriorToRead = currentMode;
-            currentMode.read(buffer);
-            // If we switched modes we want to read again as there might be unhandled bytes that need to be
-            // handled by the new mode.
-        } while (modePriorToRead != currentMode);
+    public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
+        networkReadPage = pageAllocator.apply(packetSize);
+        try {
+            Mode modePriorToRead;
+            do {
+                modePriorToRead = currentMode;
+                currentMode.read(encryptedBuffer, applicationBuffer);
+                // It is possible that we received multiple SSL packets from the network since the last read.
+                // If one of those packets causes us to change modes (such as finished handshaking), we need
+                // to call read in the new mode to handle the remaining packets.
+            } while (modePriorToRead != currentMode);
+        } finally {
+            networkReadPage.close();
+            networkReadPage = null;
+        }
     }
 
     public boolean readyForApplicationWrites() {
@@ -171,27 +177,34 @@ public class SSLDriver implements AutoCloseable {
         ExceptionsHelper.rethrowAndSuppress(closingExceptions);
     }
 
-    private SSLEngineResult unwrap(InboundChannelBuffer buffer) throws SSLException {
+    private SSLEngineResult unwrap(InboundChannelBuffer networkBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
         while (true) {
-            SSLEngineResult result = engine.unwrap(networkReadBuffer, buffer.sliceBuffersFrom(buffer.getIndex()));
-            buffer.incrementIndex(result.bytesProduced());
+            ensureApplicationBufferSize(applicationBuffer);
+            ByteBuffer networkReadBuffer = networkReadPage.byteBuffer();
+            networkReadBuffer.clear();
+            ByteBufferUtils.copyBytes(networkBuffer.sliceBuffersTo(Math.min(networkBuffer.getIndex(), packetSize)), networkReadBuffer);
+            networkReadBuffer.flip();
+            SSLEngineResult result = engine.unwrap(networkReadBuffer, applicationBuffer.sliceBuffersFrom(applicationBuffer.getIndex()));
+            networkBuffer.release(result.bytesConsumed());
+            applicationBuffer.incrementIndex(result.bytesProduced());
             switch (result.getStatus()) {
                 case OK:
-                    networkReadBuffer.compact();
                     return result;
                 case BUFFER_UNDERFLOW:
                     // There is not enough space in the network buffer for an entire SSL packet. Compact the
                     // current data and expand the buffer if necessary.
-                    int currentCapacity = networkReadBuffer.capacity();
-                    ensureNetworkReadBufferSize();
-                    if (currentCapacity == networkReadBuffer.capacity()) {
-                        networkReadBuffer.compact();
+                    packetSize = engine.getSession().getPacketBufferSize();
+                    if (networkReadPage.byteBuffer().capacity() < packetSize) {
+                        networkReadPage.close();
+                        networkReadPage = pageAllocator.apply(packetSize);
+                    } else {
+                        return result;
                     }
-                    return result;
+                    break;
                 case BUFFER_OVERFLOW:
                     // There is not enough space in the application buffer for the decrypted message. Expand
                     // the application buffer to ensure that it has enough space.
-                    ensureApplicationBufferSize(buffer);
+                    ensureApplicationBufferSize(applicationBuffer);
                     break;
                 case CLOSED:
                     assert engine.isInboundDone() : "We received close_notify so read should be done";
@@ -254,15 +267,6 @@ public class SSLDriver implements AutoCloseable {
         }
     }
 
-    private void ensureNetworkReadBufferSize() {
-        packetSize = engine.getSession().getPacketBufferSize();
-        if (networkReadBuffer.capacity() < packetSize) {
-            ByteBuffer newBuffer = ByteBuffer.allocate(packetSize);
-            networkReadBuffer.flip();
-            newBuffer.put(networkReadBuffer);
-        }
-    }
-
     // There are three potential modes for the driver to be in - HANDSHAKE, APPLICATION, or CLOSE. HANDSHAKE
     // is the initial mode. During this mode data that is read and written will be related to the TLS
     // handshake process. Application related data cannot be encrypted until the handshake is complete. From
@@ -282,7 +286,7 @@ public class SSLDriver implements AutoCloseable {
 
     private interface Mode {
 
-        void read(InboundChannelBuffer buffer) throws SSLException;
+        void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException;
 
         int write(FlushOperation applicationBytes) throws SSLException;
 
@@ -342,13 +346,11 @@ public class SSLDriver implements AutoCloseable {
         }
 
         @Override
-        public void read(InboundChannelBuffer buffer) throws SSLException {
-            ensureApplicationBufferSize(buffer);
+        public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
             boolean continueUnwrap = true;
-            while (continueUnwrap && networkReadBuffer.position() > 0) {
-                networkReadBuffer.flip();
+            while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
                 try {
-                    SSLEngineResult result = unwrap(buffer);
+                    SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
                     handshakeStatus = result.getHandshakeStatus();
                     handshake();
                     // If we are done handshaking we should exit the handshake read
@@ -430,12 +432,10 @@ public class SSLDriver implements AutoCloseable {
     private class ApplicationMode implements Mode {
 
         @Override
-        public void read(InboundChannelBuffer buffer) throws SSLException {
-            ensureApplicationBufferSize(buffer);
+        public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
             boolean continueUnwrap = true;
-            while (continueUnwrap && networkReadBuffer.position() > 0) {
-                networkReadBuffer.flip();
-                SSLEngineResult result = unwrap(buffer);
+            while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
+                SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
                 boolean renegotiationRequested = result.getStatus() != SSLEngineResult.Status.CLOSED
                     && maybeRenegotiation(result.getHandshakeStatus());
                 continueUnwrap = result.bytesConsumed() > 0 && renegotiationRequested == false;
@@ -515,7 +515,7 @@ public class SSLDriver implements AutoCloseable {
         }
 
         @Override
-        public void read(InboundChannelBuffer buffer) throws SSLException {
+        public void read(InboundChannelBuffer encryptedBuffer, InboundChannelBuffer applicationBuffer) throws SSLException {
             if (needToReceiveClose == false) {
                 // There is an issue where receiving handshake messages after initiating the close process
                 // can place the SSLEngine back into handshaking mode. In order to handle this, if we
@@ -524,11 +524,9 @@ public class SSLDriver implements AutoCloseable {
                 return;
             }
 
-            ensureApplicationBufferSize(buffer);
             boolean continueUnwrap = true;
-            while (continueUnwrap && networkReadBuffer.position() > 0) {
-                networkReadBuffer.flip();
-                SSLEngineResult result = unwrap(buffer);
+            while (continueUnwrap && encryptedBuffer.getIndex() > 0) {
+                SSLEngineResult result = unwrap(encryptedBuffer, applicationBuffer);
                 continueUnwrap = result.bytesProduced() > 0 || result.bytesConsumed() > 0;
             }
             if (engine.isInboundDone()) {

+ 6 - 12
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java

@@ -8,7 +8,6 @@ package org.elasticsearch.xpack.security.transport.nio;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.common.network.NetworkService;
-import org.elasticsearch.common.recycler.Recycler;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.PageCacheRecycler;
@@ -22,7 +21,6 @@ import org.elasticsearch.nio.ChannelFactory;
 import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.Page;
 import org.elasticsearch.nio.ServerChannelContext;
 import org.elasticsearch.nio.SocketChannelContext;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -35,11 +33,9 @@ import org.elasticsearch.xpack.security.transport.filter.IPFilter;
 import javax.net.ssl.SSLEngine;
 import java.io.IOException;
 import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
 import java.util.function.Consumer;
-import java.util.function.Supplier;
 
 import static org.elasticsearch.xpack.core.XPackSettings.HTTP_SSL_ENABLED;
 
@@ -93,13 +89,9 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
         @Override
         public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
             NioHttpChannel httpChannel = new NioHttpChannel(channel);
-            Supplier<Page> pageSupplier = () -> {
-                Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
-                return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
-            };
             HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
                 handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
-            InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
+            InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
             Consumer<Exception> exceptionHandler = (e) -> securityExceptionHandler.accept(httpChannel, e);
 
             SocketChannelContext context;
@@ -113,10 +105,12 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
                 } else {
                     sslEngine = sslService.createSSLEngine(sslConfiguration, null, -1);
                 }
-                SSLDriver sslDriver = new SSLDriver(sslEngine, false);
-                context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, buffer, nioIpFilter);
+                SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false);
+                InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
+                context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, networkBuffer,
+                    applicationBuffer, nioIpFilter);
             } else {
-                context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, buffer, nioIpFilter);
+                context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, networkBuffer, nioIpFilter);
             }
             httpChannel.setContext(context);
 

+ 6 - 12
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java

@@ -12,7 +12,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkService;
-import org.elasticsearch.common.recycler.Recycler;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -21,7 +20,6 @@ import org.elasticsearch.nio.ChannelFactory;
 import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.Page;
 import org.elasticsearch.nio.ServerChannelContext;
 import org.elasticsearch.nio.SocketChannelContext;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -45,14 +43,12 @@ import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLParameters;
 import java.io.IOException;
 import java.net.InetSocketAddress;
-import java.nio.ByteBuffer;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
 import java.util.Collections;
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Function;
-import java.util.function.Supplier;
 
 import static org.elasticsearch.xpack.core.security.SecurityField.setting;
 
@@ -156,20 +152,18 @@ public class SecurityNioTransport extends NioTransport {
         @Override
         public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
             NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
-            Supplier<Page> pageSupplier = () -> {
-                Recycler.V<byte[]> bytes = pageCacheRecycler.bytePage(false);
-                return new Page(ByteBuffer.wrap(bytes.v()), bytes::close);
-            };
             TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
-            InboundChannelBuffer buffer = new InboundChannelBuffer(pageSupplier);
+            InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
             Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
 
             SocketChannelContext context;
             if (sslEnabled) {
-                SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), isClient);
-                context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, buffer, ipFilter);
+                SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient);
+                InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
+                context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, networkBuffer,
+                    applicationBuffer, ipFilter);
             } else {
-                context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, buffer, ipFilter);
+                context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, networkBuffer, ipFilter);
             }
             nioChannel.setContext(context);
 

+ 22 - 12
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContextTests.java

@@ -52,7 +52,6 @@ public class SSLChannelContextTests extends ESTestCase {
     private BiConsumer<Void, Exception> listener;
     private Consumer exceptionHandler;
     private SSLDriver sslDriver;
-    private ByteBuffer readBuffer = ByteBuffer.allocate(1 << 14);
     private int messageLength;
 
     @Before
@@ -76,7 +75,6 @@ public class SSLChannelContextTests extends ESTestCase {
 
         when(selector.isOnCurrentThread()).thenReturn(true);
         when(selector.getTaskScheduler()).thenReturn(nioTimer);
-        when(sslDriver.getNetworkReadBuffer()).thenReturn(readBuffer);
         when(sslDriver.getOutboundBuffer()).thenReturn(outboundBuffer);
         ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
         when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
@@ -88,8 +86,12 @@ public class SSLChannelContextTests extends ESTestCase {
     public void testSuccessfulRead() throws IOException {
         byte[] bytes = createMessage(messageLength);
 
-        when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
-        doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
+        when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
+            ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
+            buffer.put(bytes);
+            return bytes.length;
+        });
+        doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
 
         when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, 0);
 
@@ -103,8 +105,12 @@ public class SSLChannelContextTests extends ESTestCase {
     public void testMultipleReadsConsumed() throws IOException {
         byte[] bytes = createMessage(messageLength * 2);
 
-        when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
-        doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
+        when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
+            ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
+            buffer.put(bytes);
+            return bytes.length;
+        });
+        doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
 
         when(readConsumer.apply(channelBuffer)).thenReturn(messageLength, messageLength, 0);
 
@@ -118,8 +124,12 @@ public class SSLChannelContextTests extends ESTestCase {
     public void testPartialRead() throws IOException {
         byte[] bytes = createMessage(messageLength);
 
-        when(rawChannel.read(any(ByteBuffer.class))).thenReturn(bytes.length);
-        doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(channelBuffer);
+        when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
+            ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
+            buffer.put(bytes);
+            return bytes.length;
+        });
+        doAnswer(getReadAnswerForBytes(bytes)).when(sslDriver).read(any(InboundChannelBuffer.class), eq(channelBuffer));
 
 
         when(readConsumer.apply(channelBuffer)).thenReturn(0);
@@ -424,12 +434,12 @@ public class SSLChannelContextTests extends ESTestCase {
 
     private Answer getReadAnswerForBytes(byte[] bytes) {
         return invocationOnMock -> {
-            InboundChannelBuffer buffer = (InboundChannelBuffer) invocationOnMock.getArguments()[0];
-            buffer.ensureCapacity(buffer.getIndex() + bytes.length);
-            ByteBuffer[] buffers = buffer.sliceBuffersFrom(buffer.getIndex());
+            InboundChannelBuffer appBuffer = (InboundChannelBuffer) invocationOnMock.getArguments()[1];
+            appBuffer.ensureCapacity(appBuffer.getIndex() + bytes.length);
+            ByteBuffer[] buffers = appBuffer.sliceBuffersFrom(appBuffer.getIndex());
             assert buffers[0].remaining() > bytes.length;
             buffers[0].put(bytes);
-            buffer.incrementIndex(bytes.length);
+            appBuffer.incrementIndex(bytes.length);
             return bytes.length;
         };
     }

+ 106 - 69
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/SSLDriverTests.java

@@ -26,14 +26,16 @@ import java.security.SecureRandom;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.Supplier;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntFunction;
 
 public class SSLDriverTests extends ESTestCase {
 
-    private final Supplier<Page> pageSupplier = () -> new Page(ByteBuffer.allocate(1 << 14), () -> {});
-    private InboundChannelBuffer serverBuffer = new InboundChannelBuffer(pageSupplier);
-    private InboundChannelBuffer clientBuffer = new InboundChannelBuffer(pageSupplier);
-    private InboundChannelBuffer genericBuffer = new InboundChannelBuffer(pageSupplier);
+    private final IntFunction<Page> pageAllocator = (n) -> new Page(ByteBuffer.allocate(n), () -> {});
+
+    private final InboundChannelBuffer networkReadBuffer = new InboundChannelBuffer(pageAllocator);
+    private final InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
+    private final AtomicInteger openPages = new AtomicInteger(0);
 
     public void testPingPongAndClose() throws Exception {
         SSLContext sslContext = getSSLContext();
@@ -44,19 +46,36 @@ public class SSLDriverTests extends ESTestCase {
         handshake(clientDriver, serverDriver);
 
         ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
-        sendAppData(clientDriver, serverDriver, buffers);
-        serverDriver.read(serverBuffer);
-        assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
+        sendAppData(clientDriver, buffers);
+        serverDriver.read(networkReadBuffer, applicationBuffer);
+        assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
+        applicationBuffer.release(4);
 
         ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
-        sendAppData(serverDriver, clientDriver, buffers2);
-        clientDriver.read(clientBuffer);
-        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
+        sendAppData(serverDriver, buffers2);
+        clientDriver.read(networkReadBuffer, applicationBuffer);
+        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
+        applicationBuffer.release(4);
 
         assertFalse(clientDriver.needsNonApplicationWrite());
         normalClose(clientDriver, serverDriver);
     }
 
+    public void testDataStoredInOutboundBufferIsClosed() throws Exception {
+        SSLContext sslContext = getSSLContext();
+
+        SSLDriver clientDriver = getDriver(sslContext.createSSLEngine(), true);
+        SSLDriver serverDriver = getDriver(sslContext.createSSLEngine(), false);
+
+        handshake(clientDriver, serverDriver);
+
+        ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
+        serverDriver.write(new FlushOperation(buffers, (v, e) -> {}));
+
+        expectThrows(SSLException.class, serverDriver::close);
+        assertEquals(0, openPages.get());
+    }
+
     public void testRenegotiate() throws Exception {
         SSLContext sslContext = getSSLContext();
 
@@ -73,9 +92,10 @@ public class SSLDriverTests extends ESTestCase {
         handshake(clientDriver, serverDriver);
 
         ByteBuffer[] buffers = {ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8))};
-        sendAppData(clientDriver, serverDriver, buffers);
-        serverDriver.read(serverBuffer);
-        assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
+        sendAppData(clientDriver, buffers);
+        serverDriver.read(networkReadBuffer, applicationBuffer);
+        assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
+        applicationBuffer.release(4);
 
         clientDriver.renegotiate();
         assertTrue(clientDriver.isHandshaking());
@@ -83,17 +103,20 @@ public class SSLDriverTests extends ESTestCase {
 
         // This tests that the client driver can still receive data based on the prior handshake
         ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
-        sendAppData(serverDriver, clientDriver, buffers2);
-        clientDriver.read(clientBuffer);
-        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
+        sendAppData(serverDriver, buffers2);
+        clientDriver.read(networkReadBuffer, applicationBuffer);
+        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
+        applicationBuffer.release(4);
 
         handshake(clientDriver, serverDriver, true);
-        sendAppData(clientDriver, serverDriver, buffers);
-        serverDriver.read(serverBuffer);
-        assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), serverBuffer.sliceBuffersTo(4)[0]);
-        sendAppData(serverDriver, clientDriver, buffers2);
-        clientDriver.read(clientBuffer);
-        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
+        sendAppData(clientDriver, buffers);
+        serverDriver.read(networkReadBuffer, applicationBuffer);
+        assertEquals(ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
+        applicationBuffer.release(4);
+        sendAppData(serverDriver, buffers2);
+        clientDriver.read(networkReadBuffer, applicationBuffer);
+        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
+        applicationBuffer.release(4);
 
         normalClose(clientDriver, serverDriver);
     }
@@ -108,18 +131,22 @@ public class SSLDriverTests extends ESTestCase {
 
         ByteBuffer buffer = ByteBuffer.allocate(1 << 15);
         for (int i = 0; i < (1 << 15); ++i) {
-            buffer.put((byte) i);
+            buffer.put((byte) (i % 127));
         }
+        buffer.flip();
         ByteBuffer[] buffers = {buffer};
-        sendAppData(clientDriver, serverDriver, buffers);
-        serverDriver.read(serverBuffer);
-        assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[0].limit());
-        assertEquals(16384, serverBuffer.sliceBuffersFrom(0)[1].limit());
+        sendAppData(clientDriver, buffers);
+        serverDriver.read(networkReadBuffer, applicationBuffer);
+        ByteBuffer[] buffers1 = applicationBuffer.sliceBuffersFrom(0);
+        assertEquals((byte) (16383 % 127), buffers1[0].get(16383));
+        assertEquals((byte) (32767 % 127), buffers1[1].get(16383));
+        applicationBuffer.release(1 << 15);
 
         ByteBuffer[] buffers2 = {ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8))};
-        sendAppData(serverDriver, clientDriver, buffers2);
-        clientDriver.read(clientBuffer);
-        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), clientBuffer.sliceBuffersTo(4)[0]);
+        sendAppData(serverDriver, buffers2);
+        clientDriver.read(networkReadBuffer, applicationBuffer);
+        assertEquals(ByteBuffer.wrap("pong".getBytes(StandardCharsets.UTF_8)), applicationBuffer.sliceBuffersTo(4)[0]);
+        applicationBuffer.release(4);
 
         assertFalse(clientDriver.needsNonApplicationWrite());
         normalClose(clientDriver, serverDriver);
@@ -193,16 +220,16 @@ public class SSLDriverTests extends ESTestCase {
         serverDriver.initiateClose();
         assertTrue(serverDriver.needsNonApplicationWrite());
         assertFalse(serverDriver.isClosed());
-        sendNonApplicationWrites(serverDriver, clientDriver);
+        sendNonApplicationWrites(serverDriver);
         // We are immediately fully closed due to SSLEngine inconsistency
         assertTrue(serverDriver.isClosed());
-        // This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
-        clientDriver.read(clientBuffer);
-        sendNonApplicationWrites(clientDriver, serverDriver);
-        clientDriver.read(clientBuffer);
-        sendNonApplicationWrites(clientDriver, serverDriver);
-        serverDriver.read(serverBuffer);
+
+        SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer));
+        assertEquals("Received close_notify during handshake", sslException.getMessage());
+        sendNonApplicationWrites(clientDriver);
         assertTrue(clientDriver.isClosed());
+
+        serverDriver.read(networkReadBuffer, applicationBuffer);
     }
 
     public void testCloseDuringHandshakePreJDK11() throws Exception {
@@ -226,17 +253,17 @@ public class SSLDriverTests extends ESTestCase {
         serverDriver.initiateClose();
         assertTrue(serverDriver.needsNonApplicationWrite());
         assertFalse(serverDriver.isClosed());
-        sendNonApplicationWrites(serverDriver, clientDriver);
+        sendNonApplicationWrites(serverDriver);
         // We are immediately fully closed due to SSLEngine inconsistency
         assertTrue(serverDriver.isClosed());
         // This should not throw exception yet as the SSLEngine will not UNWRAP data while attempting to WRAP
-        clientDriver.read(clientBuffer);
-        sendNonApplicationWrites(clientDriver, serverDriver);
-        SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(clientBuffer));
+        clientDriver.read(networkReadBuffer, applicationBuffer);
+        sendNonApplicationWrites(clientDriver);
+        SSLException sslException = expectThrows(SSLException.class, () -> clientDriver.read(networkReadBuffer, applicationBuffer));
         assertEquals("Received close_notify during handshake", sslException.getMessage());
         assertTrue(clientDriver.needsNonApplicationWrite());
-        sendNonApplicationWrites(clientDriver, serverDriver);
-        serverDriver.read(serverBuffer);
+        sendNonApplicationWrites(clientDriver);
+        serverDriver.read(networkReadBuffer, applicationBuffer);
         assertTrue(clientDriver.isClosed());
     }
 
@@ -244,11 +271,11 @@ public class SSLDriverTests extends ESTestCase {
         assertTrue(sendDriver.needsNonApplicationWrite());
         assertFalse(sendDriver.isClosed());
 
-        sendNonApplicationWrites(sendDriver, receiveDriver);
+        sendNonApplicationWrites(sendDriver);
         assertTrue(sendDriver.isClosed());
         sendDriver.close();
 
-        SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(genericBuffer));
+        SSLException sslException = expectThrows(SSLException.class, () -> receiveDriver.read(networkReadBuffer, applicationBuffer));
         assertTrue("Expected one of the following exception messages: " + messages + ". Found: " + sslException.getMessage(),
             messages.stream().anyMatch(m -> sslException.getMessage().equals(m)));
         if (receiveDriver.needsNonApplicationWrite() == false) {
@@ -277,29 +304,30 @@ public class SSLDriverTests extends ESTestCase {
         sendDriver.initiateClose();
         assertFalse(sendDriver.readyForApplicationWrites());
         assertTrue(sendDriver.needsNonApplicationWrite());
-        sendNonApplicationWrites(sendDriver, receiveDriver);
+        sendNonApplicationWrites(sendDriver);
         assertFalse(sendDriver.isClosed());
 
-        receiveDriver.read(genericBuffer);
+        receiveDriver.read(networkReadBuffer, applicationBuffer);
         assertFalse(receiveDriver.isClosed());
 
         assertFalse(receiveDriver.readyForApplicationWrites());
         assertTrue(receiveDriver.needsNonApplicationWrite());
-        sendNonApplicationWrites(receiveDriver, sendDriver);
+        sendNonApplicationWrites(receiveDriver);
         assertTrue(receiveDriver.isClosed());
 
-        sendDriver.read(genericBuffer);
+        sendDriver.read(networkReadBuffer, applicationBuffer);
         assertTrue(sendDriver.isClosed());
 
         sendDriver.close();
         receiveDriver.close();
+        assertEquals(0, openPages.get());
     }
 
-    private void sendNonApplicationWrites(SSLDriver sendDriver, SSLDriver receiveDriver) throws SSLException {
+    private void sendNonApplicationWrites(SSLDriver sendDriver) throws SSLException {
         SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
         while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
             if (outboundBuffer.hasEncryptedBytesToFlush()) {
-                sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
+                sendData(outboundBuffer.buildNetworkFlushOperation());
             } else {
                 sendDriver.nonApplicationWrite();
             }
@@ -345,8 +373,8 @@ public class SSLDriverTests extends ESTestCase {
 
         while (sendDriver.needsNonApplicationWrite() || outboundBuffer.hasEncryptedBytesToFlush()) {
             if (outboundBuffer.hasEncryptedBytesToFlush()) {
-                sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
-                receiveDriver.read(genericBuffer);
+                sendData(outboundBuffer.buildNetworkFlushOperation());
+                receiveDriver.read(networkReadBuffer, applicationBuffer);
             } else {
                 sendDriver.nonApplicationWrite();
             }
@@ -356,37 +384,46 @@ public class SSLDriverTests extends ESTestCase {
         }
     }
 
-    private void sendAppData(SSLDriver sendDriver, SSLDriver receiveDriver, ByteBuffer[] message) throws IOException {
+    private void sendAppData(SSLDriver sendDriver, ByteBuffer[] message) throws IOException {
         assertFalse(sendDriver.needsNonApplicationWrite());
 
-        int bytesToEncrypt = Arrays.stream(message).mapToInt(Buffer::remaining).sum();
-        SSLOutboundBuffer outboundBuffer = sendDriver.getOutboundBuffer();
         FlushOperation flushOperation = new FlushOperation(message, (r, l) -> {});
 
-        int bytesEncrypted = 0;
-        while (bytesToEncrypt > bytesEncrypted) {
-            bytesEncrypted += sendDriver.write(flushOperation);
-            sendData(outboundBuffer.buildNetworkFlushOperation(), receiveDriver);
+        while (flushOperation.isFullyFlushed() == false) {
+            sendDriver.write(flushOperation);
         }
+        sendData(sendDriver.getOutboundBuffer().buildNetworkFlushOperation());
     }
 
-    private void sendData(FlushOperation flushOperation, SSLDriver receiveDriver) {
-        ByteBuffer readBuffer = receiveDriver.getNetworkReadBuffer();
+    private void sendData(FlushOperation flushOperation) {
         ByteBuffer[] writeBuffers = flushOperation.getBuffersToWrite();
-        int bytesToEncrypt = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
-        assert bytesToEncrypt < readBuffer.capacity() : "Flush operation must be less that read buffer";
+        int bytesToCopy = Arrays.stream(writeBuffers).mapToInt(Buffer::remaining).sum();
+        networkReadBuffer.ensureCapacity(bytesToCopy + networkReadBuffer.getIndex());
+        ByteBuffer[] byteBuffers = networkReadBuffer.sliceBuffersFrom(0);
         assert  writeBuffers.length > 0 : "No write buffers";
 
-        for (ByteBuffer writeBuffer : writeBuffers) {
-            int written = writeBuffer.remaining();
+        int r = 0;
+        while (flushOperation.isFullyFlushed() == false) {
+            ByteBuffer readBuffer = byteBuffers[r];
+            ByteBuffer writeBuffer = flushOperation.getBuffersToWrite()[0];
+            int toWrite = Math.min(writeBuffer.remaining(), readBuffer.remaining());
+            writeBuffer.limit(writeBuffer.position() + toWrite);
             readBuffer.put(writeBuffer);
-            flushOperation.incrementIndex(written);
+            flushOperation.incrementIndex(toWrite);
+            if (readBuffer.remaining() == 0) {
+                r++;
+            }
         }
+        networkReadBuffer.incrementIndex(bytesToCopy);
 
         assertTrue(flushOperation.isFullyFlushed());
+        flushOperation.getListener().accept(null, null);
     }
 
     private SSLDriver getDriver(SSLEngine engine, boolean isClient) {
-        return new SSLDriver(engine, isClient);
+        return new SSLDriver(engine, (n) -> {
+            openPages.incrementAndGet();
+            return new Page(ByteBuffer.allocate(n), openPages::decrementAndGet);
+        }, isClient);
     }
 }