|  | @@ -22,6 +22,7 @@ package org.elasticsearch.nio;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.ESTestCase;
 | 
	
		
			
				|  |  |  import org.junit.Before;
 | 
	
		
			
				|  |  |  import org.mockito.ArgumentCaptor;
 | 
	
		
			
				|  |  | +import org.mockito.stubbing.Answer;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.io.IOException;
 | 
	
		
			
				|  |  |  import java.nio.ByteBuffer;
 | 
	
	
		
			
				|  | @@ -54,6 +55,7 @@ public class SocketChannelContextTests extends ESTestCase {
 | 
	
		
			
				|  |  |      private BiConsumer<Void, Exception> listener;
 | 
	
		
			
				|  |  |      private NioSelector selector;
 | 
	
		
			
				|  |  |      private ReadWriteHandler readWriteHandler;
 | 
	
		
			
				|  |  | +    private ByteBuffer ioBuffer = ByteBuffer.allocate(1024);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @SuppressWarnings("unchecked")
 | 
	
		
			
				|  |  |      @Before
 | 
	
	
		
			
				|  | @@ -71,6 +73,10 @@ public class SocketChannelContextTests extends ESTestCase {
 | 
	
		
			
				|  |  |          context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          when(selector.isOnCurrentThread()).thenReturn(true);
 | 
	
		
			
				|  |  | +        when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
 | 
	
		
			
				|  |  | +            ioBuffer.clear();
 | 
	
		
			
				|  |  | +            return ioBuffer;
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testIOExceptionSetIfEncountered() throws IOException {
 | 
	
	
		
			
				|  | @@ -90,7 +96,6 @@ public class SocketChannelContextTests extends ESTestCase {
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void testSignalWhenPeerClosed() throws IOException {
 | 
	
		
			
				|  |  | -        when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L);
 | 
	
		
			
				|  |  |          when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
 | 
	
		
			
				|  |  |          assertFalse(context.closeNow());
 | 
	
		
			
				|  |  |          context.read();
 | 
	
	
		
			
				|  | @@ -289,6 +294,153 @@ public class SocketChannelContextTests extends ESTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    public void testReadToBufferLimitsToPassedBuffer() throws IOException {
 | 
	
		
			
				|  |  | +        ByteBuffer buffer = ByteBuffer.allocate(10);
 | 
	
		
			
				|  |  | +        when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        int bytesRead = context.readFromChannel(buffer);
 | 
	
		
			
				|  |  | +        assertEquals(bytesRead, 10);
 | 
	
		
			
				|  |  | +        assertEquals(0, buffer.remaining());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testReadToBufferHandlesIOException() throws IOException {
 | 
	
		
			
				|  |  | +        when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        expectThrows(IOException.class, () -> context.readFromChannel(ByteBuffer.allocate(10)));
 | 
	
		
			
				|  |  | +        assertTrue(context.closeNow());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testReadToBufferHandlesEOF() throws IOException {
 | 
	
		
			
				|  |  | +        when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        context.readFromChannel(ByteBuffer.allocate(10));
 | 
	
		
			
				|  |  | +        assertTrue(context.closeNow());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testReadToChannelBufferWillReadAsMuchAsIOBufferAllows() throws IOException {
 | 
	
		
			
				|  |  | +        when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(completelyFillBufferAnswer());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
 | 
	
		
			
				|  |  | +        int bytesRead = context.readFromChannel(channelBuffer);
 | 
	
		
			
				|  |  | +        assertEquals(ioBuffer.capacity(), bytesRead);
 | 
	
		
			
				|  |  | +        assertEquals(ioBuffer.capacity(), channelBuffer.getIndex());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testReadToChannelBufferHandlesIOException() throws IOException  {
 | 
	
		
			
				|  |  | +        when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
 | 
	
		
			
				|  |  | +        expectThrows(IOException.class, () -> context.readFromChannel(channelBuffer));
 | 
	
		
			
				|  |  | +        assertTrue(context.closeNow());
 | 
	
		
			
				|  |  | +        assertEquals(0, channelBuffer.getIndex());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testReadToChannelBufferHandlesEOF() throws IOException {
 | 
	
		
			
				|  |  | +        when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
 | 
	
		
			
				|  |  | +        context.readFromChannel(channelBuffer);
 | 
	
		
			
				|  |  | +        assertTrue(context.closeNow());
 | 
	
		
			
				|  |  | +        assertEquals(0, channelBuffer.getIndex());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBufferHandlesPartialFlush() throws IOException {
 | 
	
		
			
				|  |  | +        int bytesToConsume = 3;
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer buffer = ByteBuffer.allocate(10);
 | 
	
		
			
				|  |  | +        context.flushToChannel(buffer);
 | 
	
		
			
				|  |  | +        assertEquals(10 - bytesToConsume, buffer.remaining());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBufferHandlesFullFlush() throws IOException {
 | 
	
		
			
				|  |  | +        int bytesToConsume = 10;
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(bytesToConsume));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer buffer = ByteBuffer.allocate(10);
 | 
	
		
			
				|  |  | +        context.flushToChannel(buffer);
 | 
	
		
			
				|  |  | +        assertEquals(0, buffer.remaining());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBufferHandlesIOException() throws IOException {
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer buffer = ByteBuffer.allocate(10);
 | 
	
		
			
				|  |  | +        expectThrows(IOException.class, () -> context.flushToChannel(buffer));
 | 
	
		
			
				|  |  | +        assertTrue(context.closeNow());
 | 
	
		
			
				|  |  | +        assertEquals(10, buffer.remaining());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBuffersHandlesZeroFlush() throws IOException {
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(consumeBufferAnswer(0));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
 | 
	
		
			
				|  |  | +        FlushOperation flushOperation = new FlushOperation(buffers, listener);
 | 
	
		
			
				|  |  | +        context.flushToChannel(flushOperation);
 | 
	
		
			
				|  |  | +        assertEquals(2, flushOperation.getBuffersToWrite().length);
 | 
	
		
			
				|  |  | +        assertEquals(0, flushOperation.getBuffersToWrite()[0].position());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBuffersHandlesPartialFlush() throws IOException {
 | 
	
		
			
				|  |  | +        AtomicBoolean first = new AtomicBoolean(true);
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
 | 
	
		
			
				|  |  | +            if (first.compareAndSet(true, false)) {
 | 
	
		
			
				|  |  | +                return consumeBufferAnswer(1024).answer(invocationOnMock);
 | 
	
		
			
				|  |  | +            } else {
 | 
	
		
			
				|  |  | +                return consumeBufferAnswer(3).answer(invocationOnMock);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
 | 
	
		
			
				|  |  | +        FlushOperation flushOperation = new FlushOperation(buffers, listener);
 | 
	
		
			
				|  |  | +        context.flushToChannel(flushOperation);
 | 
	
		
			
				|  |  | +        assertEquals(1, flushOperation.getBuffersToWrite().length);
 | 
	
		
			
				|  |  | +        assertEquals(4, flushOperation.getBuffersToWrite()[0].position());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBuffersHandlesFullFlush() throws IOException {
 | 
	
		
			
				|  |  | +        AtomicBoolean first = new AtomicBoolean(true);
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
 | 
	
		
			
				|  |  | +            if (first.compareAndSet(true, false)) {
 | 
	
		
			
				|  |  | +                return consumeBufferAnswer(1024).answer(invocationOnMock);
 | 
	
		
			
				|  |  | +            } else {
 | 
	
		
			
				|  |  | +                return consumeBufferAnswer(1022).answer(invocationOnMock);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
 | 
	
		
			
				|  |  | +        FlushOperation flushOperation = new FlushOperation(buffers, listener);
 | 
	
		
			
				|  |  | +        context.flushToChannel(flushOperation);
 | 
	
		
			
				|  |  | +        assertTrue(flushOperation.isFullyFlushed());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBuffersHandlesIOException() throws IOException {
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenThrow(new IOException());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer[] buffers = {ByteBuffer.allocate(10), ByteBuffer.allocate(10)};
 | 
	
		
			
				|  |  | +        FlushOperation flushOperation = new FlushOperation(buffers, listener);
 | 
	
		
			
				|  |  | +        expectThrows(IOException.class, () -> context.flushToChannel(flushOperation));
 | 
	
		
			
				|  |  | +        assertTrue(context.closeNow());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public void testFlushBuffersHandlesIOExceptionSecondTimeThroughLoop() throws IOException {
 | 
	
		
			
				|  |  | +        AtomicBoolean first = new AtomicBoolean(true);
 | 
	
		
			
				|  |  | +        when(rawChannel.write(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
 | 
	
		
			
				|  |  | +            if (first.compareAndSet(true, false)) {
 | 
	
		
			
				|  |  | +                return consumeBufferAnswer(1024).answer(invocationOnMock);
 | 
	
		
			
				|  |  | +            } else {
 | 
	
		
			
				|  |  | +                throw new IOException();
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ByteBuffer[] buffers = {ByteBuffer.allocate(1023), ByteBuffer.allocate(1023)};
 | 
	
		
			
				|  |  | +        FlushOperation flushOperation = new FlushOperation(buffers, listener);
 | 
	
		
			
				|  |  | +        expectThrows(IOException.class, () -> context.flushToChannel(flushOperation));
 | 
	
		
			
				|  |  | +        assertTrue(context.closeNow());
 | 
	
		
			
				|  |  | +        assertEquals(1, flushOperation.getBuffersToWrite().length);
 | 
	
		
			
				|  |  | +        assertEquals(1, flushOperation.getBuffersToWrite()[0].position());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      private static class TestSocketChannelContext extends SocketChannelContext {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
 | 
	
	
		
			
				|  | @@ -305,8 +457,8 @@ public class SocketChannelContextTests extends ESTestCase {
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  |          public int read() throws IOException {
 | 
	
		
			
				|  |  |              if (randomBoolean()) {
 | 
	
		
			
				|  |  | -                ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
 | 
	
		
			
				|  |  | -                return readFromChannel(byteBuffers);
 | 
	
		
			
				|  |  | +                InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
 | 
	
		
			
				|  |  | +                return readFromChannel(channelBuffer);
 | 
	
		
			
				|  |  |              } else {
 | 
	
		
			
				|  |  |                  return readFromChannel(ByteBuffer.allocate(10));
 | 
	
		
			
				|  |  |              }
 | 
	
	
		
			
				|  | @@ -316,7 +468,7 @@ public class SocketChannelContextTests extends ESTestCase {
 | 
	
		
			
				|  |  |          public void flushChannel() throws IOException {
 | 
	
		
			
				|  |  |              if (randomBoolean()) {
 | 
	
		
			
				|  |  |                  ByteBuffer[] byteBuffers = {ByteBuffer.allocate(10)};
 | 
	
		
			
				|  |  | -                flushToChannel(byteBuffers);
 | 
	
		
			
				|  |  | +                flushToChannel(new FlushOperation(byteBuffers, (v, e) -> {}));
 | 
	
		
			
				|  |  |              } else {
 | 
	
		
			
				|  |  |                  flushToChannel(ByteBuffer.allocate(10));
 | 
	
		
			
				|  |  |              }
 | 
	
	
		
			
				|  | @@ -345,4 +497,23 @@ public class SocketChannelContextTests extends ESTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          return bytes;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private Answer<Integer> completelyFillBufferAnswer() {
 | 
	
		
			
				|  |  | +        return invocationOnMock -> {
 | 
	
		
			
				|  |  | +            ByteBuffer b = (ByteBuffer) invocationOnMock.getArguments()[0];
 | 
	
		
			
				|  |  | +            int bytesRead = b.remaining();
 | 
	
		
			
				|  |  | +            while (b.hasRemaining()) {
 | 
	
		
			
				|  |  | +                b.put((byte) 1);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            return bytesRead;
 | 
	
		
			
				|  |  | +        };
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    private Answer<Object> consumeBufferAnswer(int bytesToConsume) {
 | 
	
		
			
				|  |  | +        return invocationOnMock -> {
 | 
	
		
			
				|  |  | +            ByteBuffer b = (ByteBuffer) invocationOnMock.getArguments()[0];
 | 
	
		
			
				|  |  | +            b.position(b.position() + bytesToConsume);
 | 
	
		
			
				|  |  | +            return bytesToConsume;
 | 
	
		
			
				|  |  | +        };
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  |  }
 |