Explorar el Código

Move nio ip filter rule to be a channel handler (#43507)

Currently nio implements ip filtering at the channel context level. This
is kind of a hack as the application logic should be implemented at the
handler level. This commit moves the ip filtering into a channel
handler. This requires adding an indicator to the channel handler to
show when a channel should be closed.
Tim Brooks hace 6 años
padre
commit
893785a758

+ 2 - 9
libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java

@@ -21,19 +21,12 @@ package org.elasticsearch.nio;
 
 import java.io.IOException;
 import java.util.function.Consumer;
-import java.util.function.Predicate;
 
 public class BytesChannelContext extends SocketChannelContext {
 
     public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
-                               ReadWriteHandler handler, InboundChannelBuffer channelBuffer) {
-        this(channel, selector, exceptionHandler, handler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
-    }
-
-    public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
-                               ReadWriteHandler handler, InboundChannelBuffer channelBuffer,
-                               Predicate<NioSocketChannel> allowChannelPredicate) {
-        super(channel, selector, exceptionHandler, handler, channelBuffer, allowChannelPredicate);
+                               NioChannelHandler handler, InboundChannelBuffer channelBuffer) {
+        super(channel, selector, exceptionHandler, handler, channelBuffer);
     }
 
     @Override

+ 6 - 1
libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java

@@ -24,7 +24,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.function.BiConsumer;
 
-public abstract class BytesWriteHandler implements ReadWriteHandler {
+public abstract class BytesWriteHandler implements NioChannelHandler {
 
     private static final List<FlushOperation> EMPTY_LIST = Collections.emptyList();
 
@@ -48,6 +48,11 @@ public abstract class BytesWriteHandler implements ReadWriteHandler {
         return EMPTY_LIST;
     }
 
+    @Override
+    public boolean closeNow() {
+        return false;
+    }
+
     @Override
     public void close() {}
 }

+ 68 - 0
libs/nio/src/main/java/org/elasticsearch/nio/DelegatingHandler.java

@@ -0,0 +1,68 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.nio;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.function.BiConsumer;
+
+public abstract class DelegatingHandler implements NioChannelHandler {
+
+    private NioChannelHandler delegate;
+
+    public DelegatingHandler(NioChannelHandler delegate) {
+        this.delegate = delegate;
+    }
+
+    @Override
+    public void channelRegistered() {
+        this.delegate.channelRegistered();
+    }
+
+    @Override
+    public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer<Void, Exception> listener) {
+        return delegate.createWriteOperation(context, message, listener);
+    }
+
+    @Override
+    public List<FlushOperation> writeToBytes(WriteOperation writeOperation) {
+        return delegate.writeToBytes(writeOperation);
+    }
+
+    @Override
+    public List<FlushOperation> pollFlushOperations() {
+        return delegate.pollFlushOperations();
+    }
+
+    @Override
+    public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException {
+        return delegate.consumeReads(channelBuffer);
+    }
+
+    @Override
+    public boolean closeNow() {
+        return delegate.closeNow();
+    }
+
+    @Override
+    public void close() throws IOException {
+        delegate.close();
+    }
+}

+ 9 - 2
libs/nio/src/main/java/org/elasticsearch/nio/ReadWriteHandler.java → libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java

@@ -24,9 +24,9 @@ import java.util.List;
 import java.util.function.BiConsumer;
 
 /**
- * Implements the application specific logic for handling inbound and outbound messages for a channel.
+ * Implements the application specific logic for handling channel operations.
  */
-public interface ReadWriteHandler {
+public interface NioChannelHandler {
 
     /**
      * This method is called when the channel is registered with its selector.
@@ -72,5 +72,12 @@ public interface ReadWriteHandler {
      */
     int consumeReads(InboundChannelBuffer channelBuffer) throws IOException;
 
+    /**
+     * This method indicates if the underlying channel should be closed.
+     *
+     * @return if the channel should be closed
+     */
+    boolean closeNow();
+
     void close() throws IOException;
 }

+ 3 - 12
libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java

@@ -32,7 +32,6 @@ import java.util.LinkedList;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
-import java.util.function.Predicate;
 
 /**
  * This context should implement the specific logic for a channel. When a channel receives a notification
@@ -45,13 +44,10 @@ import java.util.function.Predicate;
  */
 public abstract class SocketChannelContext extends ChannelContext<SocketChannel> {
 
-    protected static final Predicate<NioSocketChannel> ALWAYS_ALLOW_CHANNEL = (c) -> true;
-
     protected final NioSocketChannel channel;
     protected final InboundChannelBuffer channelBuffer;
     protected final AtomicBoolean isClosing = new AtomicBoolean(false);
-    private final ReadWriteHandler readWriteHandler;
-    private final Predicate<NioSocketChannel> allowChannelPredicate;
+    private final NioChannelHandler readWriteHandler;
     private final NioSelector selector;
     private final CompletableContext<Void> connectContext = new CompletableContext<>();
     private final LinkedList<FlushOperation> pendingFlushes = new LinkedList<>();
@@ -59,14 +55,12 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
     private Exception connectException;
 
     protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
-                                   ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
-                                   Predicate<NioSocketChannel> allowChannelPredicate) {
+                                   NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
         super(channel.getRawChannel(), exceptionHandler);
         this.selector = selector;
         this.channel = channel;
         this.readWriteHandler = readWriteHandler;
         this.channelBuffer = channelBuffer;
-        this.allowChannelPredicate = allowChannelPredicate;
     }
 
     @Override
@@ -171,9 +165,6 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
     protected void register() throws IOException {
         super.register();
         readWriteHandler.channelRegistered();
-        if (allowChannelPredicate.test(channel) == false) {
-            closeNow = true;
-        }
     }
 
     @Override
@@ -233,7 +224,7 @@ public abstract class SocketChannelContext extends ChannelContext<SocketChannel>
     public abstract boolean selectorShouldClose();
 
     protected boolean closeNow() {
-        return closeNow;
+        return closeNow || readWriteHandler.closeNow();
     }
 
     protected void setCloseNow() {

+ 1 - 1
libs/nio/src/main/java/org/elasticsearch/nio/WriteOperation.java

@@ -23,7 +23,7 @@ import java.util.function.BiConsumer;
 /**
  * This is a basic write operation that can be queued with a channel. The only requirements of a write
  * operation is that is has a listener and a reference to its channel. The actual conversion of the write
- * operation implementation to bytes will be performed by the {@link ReadWriteHandler}.
+ * operation implementation to bytes will be performed by the {@link NioChannelHandler}.
  */
 public interface WriteOperation {
 

+ 3 - 3
libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java

@@ -44,7 +44,7 @@ public class EventHandlerTests extends ESTestCase {
     private Consumer<Exception> channelExceptionHandler;
     private Consumer<Exception> genericExceptionHandler;
 
-    private ReadWriteHandler readWriteHandler;
+    private NioChannelHandler readWriteHandler;
     private EventHandler handler;
     private DoNotRegisterSocketContext context;
     private DoNotRegisterServerContext serverContext;
@@ -56,7 +56,7 @@ public class EventHandlerTests extends ESTestCase {
     public void setUpHandler() throws IOException {
         channelExceptionHandler = mock(Consumer.class);
         genericExceptionHandler = mock(Consumer.class);
-        readWriteHandler = mock(ReadWriteHandler.class);
+        readWriteHandler = mock(NioChannelHandler.class);
         channelFactory = mock(ChannelFactory.class);
         NioSelector selector = mock(NioSelector.class);
         ArrayList<NioSelector> selectors = new ArrayList<>();
@@ -260,7 +260,7 @@ public class EventHandlerTests extends ESTestCase {
 
 
         DoNotRegisterSocketContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
-                                   ReadWriteHandler handler) {
+                                   NioChannelHandler handler) {
             super(channel, selector, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance());
         }
 

+ 4 - 27
libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java

@@ -35,7 +35,6 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.IntFunction;
-import java.util.function.Predicate;
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
@@ -54,7 +53,7 @@ public class SocketChannelContextTests extends ESTestCase {
     private NioSocketChannel channel;
     private BiConsumer<Void, Exception> listener;
     private NioSelector selector;
-    private ReadWriteHandler readWriteHandler;
+    private NioChannelHandler readWriteHandler;
     private ByteBuffer ioBuffer = ByteBuffer.allocate(1024);
 
     @SuppressWarnings("unchecked")
@@ -68,7 +67,7 @@ public class SocketChannelContextTests extends ESTestCase {
         when(channel.getRawChannel()).thenReturn(rawChannel);
         exceptionHandler = mock(Consumer.class);
         selector = mock(NioSelector.class);
-        readWriteHandler = mock(ReadWriteHandler.class);
+        readWriteHandler = mock(NioChannelHandler.class);
         InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
         context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
 
@@ -102,22 +101,6 @@ public class SocketChannelContextTests extends ESTestCase {
         assertTrue(context.closeNow());
     }
 
-    public void testValidateInRegisterCanSucceed() throws IOException {
-        InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
-        context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> true);
-        assertFalse(context.closeNow());
-        context.register();
-        assertFalse(context.closeNow());
-    }
-
-    public void testValidateInRegisterCanFail() throws IOException {
-        InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance();
-        context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> false);
-        assertFalse(context.closeNow());
-        context.register();
-        assertTrue(context.closeNow());
-    }
-
     public void testConnectSucceeds() throws IOException {
         AtomicBoolean listenerCalled = new AtomicBoolean(false);
         when(rawChannel.finishConnect()).thenReturn(false, true);
@@ -394,14 +377,8 @@ public class SocketChannelContextTests extends ESTestCase {
     private static class TestSocketChannelContext extends SocketChannelContext {
 
         private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
-                                         ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
-            this(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL);
-        }
-
-        private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler,
-                                         ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer,
-                                         Predicate<NioSocketChannel> allowChannelPredicate) {
-            super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
+                                         NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) {
+            super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
         }
 
         @Override

+ 7 - 2
plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java

@@ -38,7 +38,7 @@ import org.elasticsearch.http.nio.cors.NioCorsConfig;
 import org.elasticsearch.http.nio.cors.NioCorsHandler;
 import org.elasticsearch.nio.FlushOperation;
 import org.elasticsearch.nio.InboundChannelBuffer;
-import org.elasticsearch.nio.ReadWriteHandler;
+import org.elasticsearch.nio.NioChannelHandler;
 import org.elasticsearch.nio.SocketChannelContext;
 import org.elasticsearch.nio.TaskScheduler;
 import org.elasticsearch.nio.WriteOperation;
@@ -50,7 +50,7 @@ import java.util.concurrent.TimeUnit;
 import java.util.function.BiConsumer;
 import java.util.function.LongSupplier;
 
-public class HttpReadWriteHandler implements ReadWriteHandler {
+public class HttpReadWriteHandler implements NioChannelHandler {
 
     private final NettyAdaptor adaptor;
     private final NioHttpChannel nioHttpChannel;
@@ -140,6 +140,11 @@ public class HttpReadWriteHandler implements ReadWriteHandler {
         return copiedOperations;
     }
 
+    @Override
+    public boolean closeNow() {
+        return false;
+    }
+
     @Override
     public void close() throws IOException {
         try {

+ 7 - 2
plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java

@@ -49,7 +49,7 @@ import org.elasticsearch.nio.NioSelectorGroup;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioServerSocketChannel;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.ReadWriteHandler;
+import org.elasticsearch.nio.NioChannelHandler;
 import org.elasticsearch.nio.SocketChannelContext;
 import org.elasticsearch.nio.WriteOperation;
 import org.elasticsearch.tasks.Task;
@@ -207,7 +207,7 @@ class NioHttpClient implements Closeable {
         }
     }
 
-    private static class HttpClientHandler implements ReadWriteHandler {
+    private static class HttpClientHandler implements NioChannelHandler {
 
         private final NettyAdaptor adaptor;
         private final CountDownLatch latch;
@@ -277,6 +277,11 @@ class NioHttpClient implements Closeable {
             return bytesConsumed;
         }
 
+        @Override
+        public boolean closeNow() {
+            return false;
+        }
+
         @Override
         public void close() throws IOException {
             try {

+ 30 - 9
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java

@@ -5,28 +5,49 @@
  */
 package org.elasticsearch.xpack.security.transport.nio;
 
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.nio.NioSocketChannel;
+import org.elasticsearch.nio.DelegatingHandler;
+import org.elasticsearch.nio.InboundChannelBuffer;
+import org.elasticsearch.nio.NioChannelHandler;
 import org.elasticsearch.xpack.security.transport.filter.IPFilter;
 
-import java.util.function.Predicate;
+import java.io.IOException;
+import java.net.InetSocketAddress;
 
-public final class NioIPFilter implements Predicate<NioSocketChannel> {
+public final class NioIPFilter extends DelegatingHandler {
 
+    private final InetSocketAddress remoteAddress;
     private final IPFilter filter;
     private final String profile;
+    private boolean denied = false;
 
-    NioIPFilter(@Nullable IPFilter filter, String profile) {
+    NioIPFilter(NioChannelHandler delegate, InetSocketAddress remoteAddress, IPFilter filter, String profile) {
+        super(delegate);
+        this.remoteAddress = remoteAddress;
         this.filter = filter;
         this.profile = profile;
     }
 
     @Override
-    public boolean test(NioSocketChannel nioChannel) {
-        if (filter != null) {
-            return filter.accept(profile, nioChannel.getRemoteAddress());
+    public void channelRegistered() {
+        if (filter.accept(profile, remoteAddress)) {
+            super.channelRegistered();
         } else {
-            return true;
+            denied = true;
         }
     }
+
+    @Override
+    public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException {
+        if (denied) {
+            // Do not consume any reads if channel is disallowed
+            return 0;
+        } else {
+            return super.consumeReads(channelBuffer);
+        }
+    }
+
+    @Override
+    public boolean closeNow() {
+        return denied;
+    }
 }

+ 6 - 8
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java

@@ -9,9 +9,9 @@ import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.nio.FlushOperation;
 import org.elasticsearch.nio.InboundChannelBuffer;
+import org.elasticsearch.nio.NioChannelHandler;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioSocketChannel;
-import org.elasticsearch.nio.ReadWriteHandler;
 import org.elasticsearch.nio.SocketChannelContext;
 import org.elasticsearch.nio.WriteOperation;
 
@@ -23,12 +23,11 @@ import java.util.LinkedList;
 import java.util.concurrent.TimeUnit;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
-import java.util.function.Predicate;
 
 /**
  * Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake
  * with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted
- * before being passed to the {@link ReadWriteHandler}. Outbound data will be encrypted before being flushed
+ * before being passed to the {@link NioChannelHandler}. Outbound data will be encrypted before being flushed
  * to the channel.
  */
 public final class SSLChannelContext extends SocketChannelContext {
@@ -43,15 +42,14 @@ public final class SSLChannelContext extends SocketChannelContext {
     private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER;
 
     SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
-                      ReadWriteHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
+                      NioChannelHandler readWriteHandler, InboundChannelBuffer applicationBuffer) {
         this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(),
-            applicationBuffer, ALWAYS_ALLOW_CHANNEL);
+            applicationBuffer);
     }
 
     SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer<Exception> exceptionHandler, SSLDriver sslDriver,
-                      ReadWriteHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer,
-                      Predicate<NioSocketChannel> allowChannelPredicate) {
-        super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate);
+                      NioChannelHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) {
+        super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer);
         this.sslDriver = sslDriver;
         this.networkReadBuffer = networkReadBuffer;
     }

+ 11 - 5
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java

@@ -19,6 +19,7 @@ import org.elasticsearch.http.nio.NioHttpServerTransport;
 import org.elasticsearch.nio.BytesChannelContext;
 import org.elasticsearch.nio.ChannelFactory;
 import org.elasticsearch.nio.InboundChannelBuffer;
+import org.elasticsearch.nio.NioChannelHandler;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioSocketChannel;
 import org.elasticsearch.nio.ServerChannelContext;
@@ -44,7 +45,6 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
 
     private final SecurityHttpExceptionHandler securityExceptionHandler;
     private final IPFilter ipFilter;
-    private final NioIPFilter nioIpFilter;
     private final SSLService sslService;
     private final SSLConfiguration sslConfiguration;
     private final boolean sslEnabled;
@@ -56,7 +56,6 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
         super(settings, networkService, bigArrays, pageCacheRecycler, threadPool, xContentRegistry, dispatcher, nioGroupFactory);
         this.securityExceptionHandler = new SecurityHttpExceptionHandler(logger, lifecycle, (c, e) -> super.onException(c, e));
         this.ipFilter = ipFilter;
-        this.nioIpFilter = new NioIPFilter(ipFilter, IPFilter.HTTP_PROFILE_NAME);
         this.sslEnabled = HTTP_SSL_ENABLED.get(settings);
         this.sslService = sslService;
         if (sslEnabled) {
@@ -91,6 +90,13 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
             NioHttpChannel httpChannel = new NioHttpChannel(channel);
             HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
                 handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
+            final NioChannelHandler handler;
+            if (ipFilter != null) {
+                handler = new NioIPFilter(httpHandler, httpChannel.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);
+            } else {
+                handler = httpHandler;
+            }
+
             InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
             Consumer<Exception> exceptionHandler = (e) -> securityExceptionHandler.accept(httpChannel, e);
 
@@ -107,10 +113,10 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
                 }
                 SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false);
                 InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
-                context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, networkBuffer,
-                    applicationBuffer, nioIpFilter);
+                context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer,
+                    applicationBuffer);
             } else {
-                context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, networkBuffer, nioIpFilter);
+                context = new BytesChannelContext(httpChannel, selector, exceptionHandler, handler, networkBuffer);
             }
             httpChannel.setContext(context);
 

+ 15 - 10
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java

@@ -18,6 +18,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.nio.BytesChannelContext;
 import org.elasticsearch.nio.ChannelFactory;
 import org.elasticsearch.nio.InboundChannelBuffer;
+import org.elasticsearch.nio.NioChannelHandler;
 import org.elasticsearch.nio.NioSelector;
 import org.elasticsearch.nio.NioSocketChannel;
 import org.elasticsearch.nio.ServerChannelContext;
@@ -65,19 +66,19 @@ public class SecurityNioTransport extends NioTransport {
     private static final Logger logger = LogManager.getLogger(SecurityNioTransport.class);
 
     private final SecurityTransportExceptionHandler exceptionHandler;
-    private final IPFilter authenticator;
+    private final IPFilter ipFilter;
     private final SSLService sslService;
     private final Map<String, SSLConfiguration> profileConfiguration;
     private final boolean sslEnabled;
 
     public SecurityNioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService,
                                 PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
-                                CircuitBreakerService circuitBreakerService, @Nullable final IPFilter authenticator,
+                                CircuitBreakerService circuitBreakerService, @Nullable final IPFilter ipFilter,
                                 SSLService sslService, NioGroupFactory groupFactory) {
         super(settings, version, threadPool, networkService, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService,
             groupFactory);
         this.exceptionHandler = new SecurityTransportExceptionHandler(logger, lifecycle, (c, e) -> super.onException(c, e));
-        this.authenticator = authenticator;
+        this.ipFilter = ipFilter;
         this.sslService = sslService;
         this.sslEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings);
         if (sslEnabled) {
@@ -92,8 +93,8 @@ public class SecurityNioTransport extends NioTransport {
     @Override
     protected void doStart() {
         super.doStart();
-        if (authenticator != null) {
-            authenticator.setBoundTransportAddress(boundAddress(), profileBoundAddresses());
+        if (ipFilter != null) {
+            ipFilter.setBoundTransportAddress(boundAddress(), profileBoundAddresses());
         }
     }
 
@@ -132,7 +133,6 @@ public class SecurityNioTransport extends NioTransport {
 
         private final String profileName;
         private final boolean isClient;
-        private final NioIPFilter ipFilter;
 
         private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) {
             this(new RawChannelFactory(profileSettings.tcpNoDelay,
@@ -146,13 +146,18 @@ public class SecurityNioTransport extends NioTransport {
             super(rawChannelFactory);
             this.profileName = profileName;
             this.isClient = isClient;
-            this.ipFilter = new NioIPFilter(authenticator, profileName);
         }
 
         @Override
         public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
             NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel);
             TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this);
+            final NioChannelHandler handler;
+            if (ipFilter != null) {
+                handler = new NioIPFilter(readWriteHandler, nioChannel.getRemoteAddress(), ipFilter, profileName);
+            } else {
+                handler = readWriteHandler;
+            }
             InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator);
             Consumer<Exception> exceptionHandler = (e) -> onException(nioChannel, e);
 
@@ -160,10 +165,10 @@ public class SecurityNioTransport extends NioTransport {
             if (sslEnabled) {
                 SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient);
                 InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator);
-                context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, networkBuffer,
-                    applicationBuffer, ipFilter);
+                context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer,
+                    applicationBuffer);
             } else {
-                context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, networkBuffer, ipFilter);
+                context = new BytesChannelContext(nioChannel, selector, exceptionHandler, handler, networkBuffer);
             }
             nioChannel.setContext(context);
 

+ 22 - 14
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java

@@ -13,7 +13,7 @@ import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.license.XPackLicenseState;
-import org.elasticsearch.nio.NioSocketChannel;
+import org.elasticsearch.nio.NioChannelHandler;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.xpack.security.audit.AuditTrailService;
@@ -26,13 +26,15 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 
-import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class NioIPFilterTests extends ESTestCase {
 
-    private NioIPFilter nioIPFilter;
+    private IPFilter ipFilter;
+    private String profile;
 
     @Before
     public void init() throws Exception {
@@ -59,7 +61,7 @@ public class NioIPFilterTests extends ESTestCase {
         XPackLicenseState licenseState = mock(XPackLicenseState.class);
         when(licenseState.isIpFilteringAllowed()).thenReturn(true);
         AuditTrailService auditTrailService = new AuditTrailService(Collections.emptyList(), licenseState);
-        IPFilter ipFilter = new IPFilter(settings, auditTrailService, clusterSettings, licenseState);
+        ipFilter = new IPFilter(settings, auditTrailService, clusterSettings, licenseState);
         ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses());
         if (isHttpEnabled) {
             HttpServerTransport httpTransport = mock(HttpServerTransport.class);
@@ -70,21 +72,27 @@ public class NioIPFilterTests extends ESTestCase {
         }
 
         if (isHttpEnabled) {
-            nioIPFilter = new NioIPFilter(ipFilter, IPFilter.HTTP_PROFILE_NAME);
+            profile = IPFilter.HTTP_PROFILE_NAME;
         } else {
-            nioIPFilter = new NioIPFilter(ipFilter, "default");
+            profile = "default";
         }
     }
 
-    public void testThatFilteringWorksByIp() throws Exception {
+    public void testThatFilterCanPass() throws Exception {
         InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("127.0.0.1"), 12345);
-        NioSocketChannel channel1 = mock(NioSocketChannel.class);
-        when(channel1.getRemoteAddress()).thenReturn(localhostAddr);
-        assertThat(nioIPFilter.test(channel1), is(true));
+        NioChannelHandler delegate = mock(NioChannelHandler.class);
+        NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile);
+        nioIPFilter.channelRegistered();
+        verify(delegate).channelRegistered();
+        assertFalse(nioIPFilter.closeNow());
+    }
 
-        InetSocketAddress remoteAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345);
-        NioSocketChannel channel2 = mock(NioSocketChannel.class);
-        when(channel2.getRemoteAddress()).thenReturn(remoteAddr);
-        assertThat(nioIPFilter.test(channel2), is(false));
+    public void testThatFilterCanFail() throws Exception {
+        InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345);
+        NioChannelHandler delegate = mock(NioChannelHandler.class);
+        NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile);
+        nioIPFilter.channelRegistered();
+        verify(delegate, times(0)).channelRegistered();
+        assertTrue(nioIPFilter.closeNow());
     }
 }