Browse Source

Move CorsHandler to server (#62007)

Currently we duplicate our specialized cors logic in all transport
plugins. This is unnecessary as it could be implemented in a single
place. This commit moves the logic to server. Additionally it fixes a
but where we are incorrectly closing http channels on early Cors
responses.
Tim Brooks 5 years ago
parent
commit
d5f9e4ecb0
18 changed files with 773 additions and 1094 deletions
  1. 1 5
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java
  2. 0 253
      modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java
  3. 0 149
      modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java
  4. 1 7
      plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java
  5. 1 1
      plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java
  6. 0 254
      plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java
  7. 3 158
      plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java
  8. 24 4
      server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java
  9. 181 6
      server/src/main/java/org/elasticsearch/http/CorsHandler.java
  10. 6 17
      server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java
  11. 17 0
      server/src/main/java/org/elasticsearch/http/HttpRequest.java
  12. 39 0
      server/src/main/java/org/elasticsearch/http/HttpUtils.java
  13. 1 1
      server/src/main/java/org/elasticsearch/rest/RestUtils.java
  14. 242 3
      server/src/test/java/org/elasticsearch/http/CorsHandlerTests.java
  15. 85 235
      server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java
  16. 103 0
      server/src/test/java/org/elasticsearch/http/TestHttpRequest.java
  17. 68 0
      server/src/test/java/org/elasticsearch/http/TestHttpResponse.java
  18. 1 1
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java

+ 1 - 5
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java

@@ -59,10 +59,9 @@ import org.elasticsearch.http.HttpChannel;
 import org.elasticsearch.http.HttpHandlingSettings;
 import org.elasticsearch.http.HttpReadTimeoutException;
 import org.elasticsearch.http.HttpServerChannel;
-import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.SharedGroupFactory;
 import org.elasticsearch.transport.NettyAllocator;
+import org.elasticsearch.transport.SharedGroupFactory;
 import org.elasticsearch.transport.netty4.Netty4Utils;
 
 import java.net.InetSocketAddress;
@@ -314,9 +313,6 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
                 ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
             }
             ch.pipeline().addLast("request_creator", requestCreator);
-            if (handlingSettings.isCorsEnabled()) {
-                ch.pipeline().addLast("cors", new Netty4CorsHandler(transport.corsConfig));
-            }
             ch.pipeline().addLast("pipelining", new Netty4HttpPipeliningHandler(logger, transport.pipeliningMaxEvents));
             ch.pipeline().addLast("handler", requestHandler);
             transport.serverAcceptedChannel(nettyHttpChannel);

+ 0 - 253
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java

@@ -1,253 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.http.netty4.cors;
-
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.http.DefaultFullHttpResponse;
-import io.netty.handler.codec.http.HttpHeaderNames;
-import io.netty.handler.codec.http.HttpHeaders;
-import io.netty.handler.codec.http.HttpMethod;
-import io.netty.handler.codec.http.HttpRequest;
-import io.netty.handler.codec.http.HttpResponse;
-import io.netty.handler.codec.http.HttpResponseStatus;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.http.CorsHandler;
-import org.elasticsearch.http.netty4.Netty4HttpRequest;
-import org.elasticsearch.http.netty4.Netty4HttpResponse;
-
-import java.util.Date;
-import java.util.regex.Pattern;
-import java.util.stream.Collectors;
-
-/**
- * Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
- * <p>
- * This handler can be configured using a {@link CorsHandler.Config}, please
- * refer to this class for details about the configuration options available.
- *
- */
-public class Netty4CorsHandler extends ChannelDuplexHandler {
-
-    public static final String ANY_ORIGIN = "*";
-    private static Pattern SCHEME_PATTERN = Pattern.compile("^https?://");
-
-    private final CorsHandler.Config config;
-    private Netty4HttpRequest request;
-
-    /**
-     * Creates a new instance with the specified {@link CorsHandler.Config}.
-     */
-    public Netty4CorsHandler(final CorsHandler.Config config) {
-        if (config == null) {
-            throw new NullPointerException();
-        }
-        this.config = config;
-    }
-
-    @Override
-    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
-        assert msg instanceof Netty4HttpRequest : "Invalid message type: " + msg.getClass();
-        if (config.isCorsSupportEnabled()) {
-            request = (Netty4HttpRequest) msg;
-            if (isPreflightRequest(request.nettyRequest())) {
-                try {
-                    handlePreflight(ctx, request.nettyRequest());
-                    return;
-                } finally {
-                    releaseRequest();
-                }
-            }
-            if (!validateOrigin()) {
-                try {
-                    forbidden(ctx, request.nettyRequest());
-                    return;
-                } finally {
-                    releaseRequest();
-                }
-            }
-        }
-        ctx.fireChannelRead(msg);
-    }
-
-    @Override
-    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
-        assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass();
-        Netty4HttpResponse response = (Netty4HttpResponse) msg;
-        setCorsResponseHeaders(response.requestHeaders(), response, config);
-        ctx.write(response, promise);
-    }
-
-    public static void setCorsResponseHeaders(HttpHeaders headers, HttpResponse resp, CorsHandler.Config config) {
-        if (!config.isCorsSupportEnabled()) {
-            return;
-        }
-        String originHeader = headers.get(HttpHeaderNames.ORIGIN);
-        if (!Strings.isNullOrEmpty(originHeader)) {
-            final String originHeaderVal;
-            if (config.isAnyOriginSupported()) {
-                originHeaderVal = ANY_ORIGIN;
-            } else if (config.isOriginAllowed(originHeader) || isSameOrigin(originHeader, headers.get(HttpHeaderNames.HOST))) {
-                originHeaderVal = originHeader;
-            } else {
-                originHeaderVal = null;
-            }
-            if (originHeaderVal != null) {
-                resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, originHeaderVal);
-            }
-        }
-        if (config.isCredentialsAllowed()) {
-            resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
-        }
-    }
-
-    private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
-        final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.OK, true, true);
-        if (setOrigin(response)) {
-            setAllowMethods(response);
-            setAllowHeaders(response);
-            setAllowCredentials(response);
-            setMaxAge(response);
-            setPreflightHeaders(response);
-            ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
-        } else {
-            forbidden(ctx, request);
-        }
-    }
-
-    private void releaseRequest() {
-        request.release();
-        request = null;
-    }
-
-    private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
-        ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.FORBIDDEN))
-            .addListener(ChannelFutureListener.CLOSE);
-    }
-
-    private static boolean isSameOrigin(final String origin, final String host) {
-        if (Strings.isNullOrEmpty(host) == false) {
-            // strip protocol from origin
-            final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst("");
-            if (host.equals(originDomain)) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    /**
-     * This is a non CORS specification feature which enables the setting of preflight
-     * response headers that might be required by intermediaries.
-     *
-     * @param response the HttpResponse to which the preflight response headers should be added.
-     */
-    private void setPreflightHeaders(final HttpResponse response) {
-        response.headers().add("date", new Date());
-        response.headers().add("content-length", "0");
-    }
-
-    private boolean setOrigin(final HttpResponse response) {
-        final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
-        if (!Strings.isNullOrEmpty(origin)) {
-            if (config.isAnyOriginSupported()) {
-                if (config.isCredentialsAllowed()) {
-                    echoRequestOrigin(response);
-                    setVaryHeader(response);
-                } else {
-                    setAnyOrigin(response);
-                }
-                return true;
-            }
-            if (config.isOriginAllowed(origin)) {
-                setOrigin(response, origin);
-                setVaryHeader(response);
-                return true;
-            }
-        }
-        return false;
-    }
-
-    private boolean validateOrigin() {
-        if (config.isAnyOriginSupported()) {
-            return true;
-        }
-
-        final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
-        if (Strings.isNullOrEmpty(origin)) {
-            // Not a CORS request so we cannot validate it. It may be a non CORS request.
-            return true;
-        }
-
-        // if the origin is the same as the host of the request, then allow
-        if (isSameOrigin(origin, request.nettyRequest().headers().get(HttpHeaderNames.HOST))) {
-            return true;
-        }
-
-        return config.isOriginAllowed(origin);
-    }
-
-    private void echoRequestOrigin(final HttpResponse response) {
-        setOrigin(response, request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN));
-    }
-
-    private static void setVaryHeader(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
-    }
-
-    private static void setAnyOrigin(final HttpResponse response) {
-        setOrigin(response, ANY_ORIGIN);
-    }
-
-    private static void setOrigin(final HttpResponse response, final String origin) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
-    }
-
-    private void setAllowCredentials(final HttpResponse response) {
-        if (config.isCredentialsAllowed()
-            && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
-            response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
-        }
-    }
-
-    private static boolean isPreflightRequest(final HttpRequest request) {
-        final HttpHeaders headers = request.headers();
-        return request.method().equals(HttpMethod.OPTIONS) &&
-            headers.contains(HttpHeaderNames.ORIGIN) &&
-            headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
-    }
-
-    private void setAllowMethods(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods().stream()
-            .map(m -> m.name().trim())
-            .collect(Collectors.toList()));
-    }
-
-    private void setAllowHeaders(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
-    }
-
-    private void setMaxAge(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
-    }
-
-}

+ 0 - 149
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java

@@ -1,149 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.http.netty4;
-
-import io.netty.channel.embedded.EmbeddedChannel;
-import io.netty.handler.codec.http.DefaultFullHttpRequest;
-import io.netty.handler.codec.http.FullHttpRequest;
-import io.netty.handler.codec.http.FullHttpResponse;
-import io.netty.handler.codec.http.HttpHeaderNames;
-import io.netty.handler.codec.http.HttpMethod;
-import io.netty.handler.codec.http.HttpResponse;
-import io.netty.handler.codec.http.HttpVersion;
-import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.http.CorsHandler;
-import org.elasticsearch.http.HttpTransportSettings;
-import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
-import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
-
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.notNullValue;
-import static org.hamcrest.Matchers.nullValue;
-
-public class Netty4CorsTests extends ESTestCase {
-
-    public void testCorsEnabledWithoutAllowOrigins() {
-        // Set up an HTTP transport with only the CORS enabled setting
-        Settings settings = Settings.builder()
-            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
-            .build();
-        HttpResponse response = executeRequest(settings, "remote-host", "request-host");
-        // inspect response and validate
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
-    }
-
-    public void testCorsEnabledWithAllowOrigins() {
-        final String originValue = "remote-host";
-        // create an HTTP transport with CORS enabled and allow origin configured
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-            .build();
-        HttpResponse response = executeRequest(settings, originValue, "request-host");
-        // inspect response and validate
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-        assertThat(allowedOrigins, is(originValue));
-    }
-
-    public void testCorsAllowOriginWithSameHost() {
-        String originValue = "remote-host";
-        String host = "remote-host";
-        // create an HTTP transport with CORS enabled
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .build();
-        HttpResponse response = executeRequest(settings, originValue, host);
-        // inspect response and validate
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-        assertThat(allowedOrigins, is(originValue));
-
-        originValue = "http://" + originValue;
-        response = executeRequest(settings, originValue, host);
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-        allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-        assertThat(allowedOrigins, is(originValue));
-
-        originValue = originValue + ":5555";
-        host = host + ":5555";
-        response = executeRequest(settings, originValue, host);
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-        allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-        assertThat(allowedOrigins, is(originValue));
-
-        originValue = originValue.replace("http", "https");
-        response = executeRequest(settings, originValue, host);
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-        allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-        assertThat(allowedOrigins, is(originValue));
-    }
-
-    public void testThatStringLiteralWorksOnMatch() {
-        final String originValue = "remote-host";
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-            .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
-            .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
-            .build();
-        HttpResponse response = executeRequest(settings, originValue, "request-host");
-        // inspect response and validate
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-        assertThat(allowedOrigins, is(originValue));
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
-    }
-
-    public void testThatAnyOriginWorks() {
-        final String originValue = Netty4CorsHandler.ANY_ORIGIN;
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-            .build();
-        HttpResponse response = executeRequest(settings, originValue, "request-host");
-        // inspect response and validate
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-        assertThat(allowedOrigins, is(originValue));
-        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
-    }
-
-    private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
-        // construct request and send it over the transport layer
-        final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
-        if (originValue != null) {
-            httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
-        }
-        httpRequest.headers().add(HttpHeaderNames.HOST, host);
-        EmbeddedChannel embeddedChannel = new EmbeddedChannel();
-        embeddedChannel.pipeline().addLast(new Netty4CorsHandler(CorsHandler.fromSettings(settings)));
-        Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest);
-        embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content")));
-        return embeddedChannel.readOutbound();
-    }
-}

+ 1 - 7
plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java

@@ -27,12 +27,10 @@ import io.netty.handler.codec.http.HttpObjectAggregator;
 import io.netty.handler.codec.http.HttpRequestDecoder;
 import io.netty.handler.codec.http.HttpResponseEncoder;
 import org.elasticsearch.common.unit.TimeValue;
-import org.elasticsearch.http.CorsHandler;
 import org.elasticsearch.http.HttpHandlingSettings;
 import org.elasticsearch.http.HttpPipelinedRequest;
 import org.elasticsearch.http.HttpPipelinedResponse;
 import org.elasticsearch.http.HttpReadTimeoutException;
-import org.elasticsearch.http.nio.cors.NioCorsHandler;
 import org.elasticsearch.nio.FlushOperation;
 import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.NioChannelHandler;
@@ -60,7 +58,7 @@ public class HttpReadWriteHandler implements NioChannelHandler {
     private int inFlightRequests = 0;
 
     public HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings,
-                                CorsHandler.Config corsConfig, TaskScheduler taskScheduler, LongSupplier nanoClock) {
+                                TaskScheduler taskScheduler, LongSupplier nanoClock) {
         this.nioHttpChannel = nioHttpChannel;
         this.transport = transport;
         this.taskScheduler = taskScheduler;
@@ -79,9 +77,6 @@ public class HttpReadWriteHandler implements NioChannelHandler {
             handlers.add(new HttpContentCompressor(settings.getCompressionLevel()));
         }
         handlers.add(new NioHttpRequestCreator());
-        if (settings.isCorsEnabled()) {
-            handlers.add(new NioCorsHandler(corsConfig));
-        }
         handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents()));
 
         adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0]));
@@ -150,7 +145,6 @@ public class HttpReadWriteHandler implements NioChannelHandler {
         }
     }
 
-    @SuppressWarnings("unchecked")
     private void handleRequest(Object msg) {
         final HttpPipelinedRequest pipelinedRequest = (HttpPipelinedRequest) msg;
         boolean success = false;

+ 1 - 1
plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java

@@ -169,7 +169,7 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
         public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
             NioHttpChannel httpChannel = new NioHttpChannel(channel);
             HttpReadWriteHandler handler = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
-                handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
+                handlingSettings, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
             Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
             SocketChannelContext context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler,
                 new InboundChannelBuffer(pageAllocator));

+ 0 - 254
plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java

@@ -1,254 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.http.nio.cors;
-
-import io.netty.channel.ChannelDuplexHandler;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelPromise;
-import io.netty.handler.codec.http.DefaultFullHttpResponse;
-import io.netty.handler.codec.http.HttpHeaderNames;
-import io.netty.handler.codec.http.HttpHeaders;
-import io.netty.handler.codec.http.HttpMethod;
-import io.netty.handler.codec.http.HttpRequest;
-import io.netty.handler.codec.http.HttpResponse;
-import io.netty.handler.codec.http.HttpResponseStatus;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.http.CorsHandler;
-import org.elasticsearch.http.nio.NioHttpRequest;
-import org.elasticsearch.http.nio.NioHttpResponse;
-
-import java.util.Date;
-import java.util.regex.Pattern;
-import java.util.stream.Collectors;
-
-/**
- * Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
- * <p>
- * This handler can be configured using a {@link CorsHandler.Config}, please
- * refer to this class for details about the configuration options available.
- *
- * This code was borrowed from Netty 4 and refactored to work for Elasticsearch's Netty 3 setup.
- */
-public class NioCorsHandler extends ChannelDuplexHandler {
-
-    public static final String ANY_ORIGIN = "*";
-    private static final Pattern SCHEME_PATTERN = Pattern.compile("^https?://");
-
-    private final CorsHandler.Config config;
-    private NioHttpRequest request;
-
-    /**
-     * Creates a new instance with the specified {@link CorsHandler.Config}.
-     */
-    public NioCorsHandler(final CorsHandler.Config config) {
-        if (config == null) {
-            throw new NullPointerException();
-        }
-        this.config = config;
-    }
-
-    @Override
-    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
-        assert msg instanceof NioHttpRequest : "Invalid message type: " + msg.getClass();
-        if (config.isCorsSupportEnabled()) {
-            request = (NioHttpRequest) msg;
-            if (isPreflightRequest(request.nettyRequest())) {
-                try {
-                    handlePreflight(ctx, request.nettyRequest());
-                    return;
-                } finally {
-                    releaseRequest();
-                }
-            }
-            if (!validateOrigin()) {
-                try {
-                    forbidden(ctx, request.nettyRequest());
-                    return;
-                } finally {
-                    releaseRequest();
-                }
-            }
-        }
-        ctx.fireChannelRead(msg);
-    }
-
-    @Override
-    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
-        assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass();
-        NioHttpResponse response = (NioHttpResponse) msg;
-        setCorsResponseHeaders(response.requestHeaders(), response, config);
-        ctx.write(response, promise);
-    }
-
-    public static void setCorsResponseHeaders(HttpHeaders headers, HttpResponse resp, CorsHandler.Config config) {
-        if (!config.isCorsSupportEnabled()) {
-            return;
-        }
-        String originHeader = headers.get(HttpHeaderNames.ORIGIN);
-        if (!Strings.isNullOrEmpty(originHeader)) {
-            final String originHeaderVal;
-            if (config.isAnyOriginSupported()) {
-                originHeaderVal = ANY_ORIGIN;
-            } else if (config.isOriginAllowed(originHeader) || isSameOrigin(originHeader, headers.get(HttpHeaderNames.HOST))) {
-                originHeaderVal = originHeader;
-            } else {
-                originHeaderVal = null;
-            }
-            if (originHeaderVal != null) {
-                resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, originHeaderVal);
-            }
-        }
-        if (config.isCredentialsAllowed()) {
-            resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
-        }
-    }
-
-    private void releaseRequest() {
-        request.release();
-        request = null;
-    }
-
-    private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
-        final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.OK, true, true);
-        if (setOrigin(response)) {
-            setAllowMethods(response);
-            setAllowHeaders(response);
-            setAllowCredentials(response);
-            setMaxAge(response);
-            setPreflightHeaders(response);
-            ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
-        } else {
-            forbidden(ctx, request);
-        }
-    }
-
-    private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
-        ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.FORBIDDEN))
-            .addListener(ChannelFutureListener.CLOSE);
-    }
-
-    private static boolean isSameOrigin(final String origin, final String host) {
-        if (Strings.isNullOrEmpty(host) == false) {
-            // strip protocol from origin
-            final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst("");
-            if (host.equals(originDomain)) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    /**
-     * This is a non CORS specification feature which enables the setting of preflight
-     * response headers that might be required by intermediaries.
-     *
-     * @param response the HttpResponse to which the preflight response headers should be added.
-     */
-    private void setPreflightHeaders(final HttpResponse response) {
-        response.headers().add("date", new Date());
-        response.headers().add("content-length", "0");
-    }
-
-    private boolean setOrigin(final HttpResponse response) {
-        final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
-        if (!Strings.isNullOrEmpty(origin)) {
-            if (config.isAnyOriginSupported()) {
-                if (config.isCredentialsAllowed()) {
-                    echoRequestOrigin(response);
-                    setVaryHeader(response);
-                } else {
-                    setAnyOrigin(response);
-                }
-                return true;
-            }
-            if (config.isOriginAllowed(origin)) {
-                setOrigin(response, origin);
-                setVaryHeader(response);
-                return true;
-            }
-        }
-        return false;
-    }
-
-    private boolean validateOrigin() {
-        if (config.isAnyOriginSupported()) {
-            return true;
-        }
-
-        final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
-        if (Strings.isNullOrEmpty(origin)) {
-            // Not a CORS request so we cannot validate it. It may be a non CORS request.
-            return true;
-        }
-
-        // if the origin is the same as the host of the request, then allow
-        if (isSameOrigin(origin, request.nettyRequest().headers().get(HttpHeaderNames.HOST))) {
-            return true;
-        }
-
-        return config.isOriginAllowed(origin);
-    }
-
-    private void echoRequestOrigin(final HttpResponse response) {
-        setOrigin(response, request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN));
-    }
-
-    private static void setVaryHeader(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
-    }
-
-    private static void setAnyOrigin(final HttpResponse response) {
-        setOrigin(response, ANY_ORIGIN);
-    }
-
-    private static void setOrigin(final HttpResponse response, final String origin) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
-    }
-
-    private void setAllowCredentials(final HttpResponse response) {
-        if (config.isCredentialsAllowed()
-            && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
-            response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
-        }
-    }
-
-    private static boolean isPreflightRequest(final HttpRequest request) {
-        final HttpHeaders headers = request.headers();
-        return request.method().equals(HttpMethod.OPTIONS) &&
-            headers.contains(HttpHeaderNames.ORIGIN) &&
-            headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
-    }
-
-    private void setAllowMethods(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods().stream()
-            .map(m -> m.name().trim())
-            .collect(Collectors.toList()));
-    }
-
-    private void setAllowHeaders(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
-    }
-
-    private void setMaxAge(final HttpResponse response) {
-        response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
-    }
-
-}

+ 3 - 158
plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java

@@ -44,9 +44,6 @@ import org.elasticsearch.http.HttpPipelinedRequest;
 import org.elasticsearch.http.HttpPipelinedResponse;
 import org.elasticsearch.http.HttpReadTimeoutException;
 import org.elasticsearch.http.HttpRequest;
-import org.elasticsearch.http.HttpResponse;
-import org.elasticsearch.http.HttpTransportSettings;
-import org.elasticsearch.http.nio.cors.NioCorsHandler;
 import org.elasticsearch.nio.FlushOperation;
 import org.elasticsearch.nio.InboundChannelBuffer;
 import org.elasticsearch.nio.SocketChannelContext;
@@ -64,16 +61,8 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.function.BiConsumer;
 
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
-import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
 import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH;
 import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT;
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.notNullValue;
-import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.atLeastOnce;
@@ -104,8 +93,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
         channel = mock(NioHttpChannel.class);
         taskScheduler = mock(TaskScheduler.class);
 
-        CorsHandler.Config corsConfig = CorsHandler.disabled();
-        handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, System::nanoTime);
+        handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, taskScheduler, System::nanoTime);
         handler.channelActive();
     }
 
@@ -211,135 +199,17 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
         }
     }
 
-    public void testCorsEnabledWithoutAllowOrigins() throws IOException {
-        // Set up an HTTP transport with only the CORS enabled setting
-        Settings settings = Settings.builder()
-            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
-            .build();
-        FullHttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
-        try {
-            // inspect response and validate
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
-        } finally {
-            response.release();
-        }
-    }
-
-    public void testCorsEnabledWithAllowOrigins() throws IOException {
-        final String originValue = "remote-host";
-        // create an HTTP transport with CORS enabled and allow origin configured
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-            .build();
-        FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
-        try {
-            // inspect response and validate
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-            String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-            assertThat(allowedOrigins, is(originValue));
-        } finally {
-            response.release();
-        }
-    }
-
-    public void testCorsAllowOriginWithSameHost() throws IOException {
-        String originValue = "remote-host";
-        String host = "remote-host";
-        // create an HTTP transport with CORS enabled
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .build();
-        FullHttpResponse response = executeCorsRequest(settings, originValue, host);
-        String allowedOrigins;
-        try {
-            // inspect response and validate
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-            allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-            assertThat(allowedOrigins, is(originValue));
-        } finally {
-            response.release();
-        }
-        originValue = "http://" + originValue;
-        response = executeCorsRequest(settings, originValue, host);
-        try {
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-            allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-            assertThat(allowedOrigins, is(originValue));
-        } finally {
-            response.release();
-        }
-
-        originValue = originValue + ":5555";
-        host = host + ":5555";
-        response = executeCorsRequest(settings, originValue, host);
-        try {
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-            allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-            assertThat(allowedOrigins, is(originValue));
-        } finally {
-            response.release();
-        }
-        originValue = originValue.replace("http", "https");
-        response = executeCorsRequest(settings, originValue, host);
-        try {
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-            allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-            assertThat(allowedOrigins, is(originValue));
-        } finally {
-            response.release();
-        }
-    }
-
-    public void testThatStringLiteralWorksOnMatch() throws IOException {
-        final String originValue = "remote-host";
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-            .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
-            .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
-            .build();
-        FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
-        try {
-            // inspect response and validate
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-            String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-            assertThat(allowedOrigins, is(originValue));
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
-        } finally {
-            response.release();
-        }
-    }
-
-    public void testThatAnyOriginWorks() throws IOException {
-        final String originValue = NioCorsHandler.ANY_ORIGIN;
-        Settings settings = Settings.builder()
-            .put(SETTING_CORS_ENABLED.getKey(), true)
-            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-            .build();
-        FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
-        try {
-            // inspect response and validate
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-            String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-            assertThat(allowedOrigins, is(originValue));
-            assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
-        } finally {
-            response.release();
-        }
-    }
-
     @SuppressWarnings("unchecked")
     public void testReadTimeout() throws IOException {
         TimeValue timeValue = TimeValue.timeValueMillis(500);
         Settings settings = Settings.builder().put(SETTING_HTTP_READ_TIMEOUT.getKey(), timeValue).build();
         HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
 
-        CorsHandler.Config corsConfig = CorsHandler.disabled();
+        CorsHandler corsHandler = CorsHandler.disabled();
         TaskScheduler taskScheduler = new TaskScheduler();
 
         Iterator<Integer> timeValues = Arrays.asList(0, 2, 4, 6, 8).iterator();
-        handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, timeValues::next);
+        handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, taskScheduler, timeValues::next);
         handler.channelActive();
 
         prepareHandlerForResponse(handler);
@@ -382,31 +252,6 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
         return httpResponse;
     }
 
-    private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
-        HttpHandlingSettings httpSettings = HttpHandlingSettings.fromSettings(settings);
-        CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings);
-        HttpReadWriteHandler handler = new HttpReadWriteHandler(channel, transport, httpSettings, corsConfig, taskScheduler,
-            System::nanoTime);
-        handler.channelActive();
-        prepareHandlerForResponse(handler);
-        DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
-        if (originValue != null) {
-            httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
-        }
-        httpRequest.headers().add(HttpHeaderNames.HOST, host);
-        HttpPipelinedRequest pipelinedRequest = new HttpPipelinedRequest(0, new NioHttpRequest(httpRequest));
-        BytesArray content = new BytesArray("content");
-        HttpResponse response = pipelinedRequest.createResponse(RestStatus.OK, content);
-        response.addHeader("Content-Length", Integer.toString(content.length()));
-
-        SocketChannelContext context = mock(SocketChannelContext.class);
-        List<FlushOperation> flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {}));
-        handler.close();
-        FlushOperation flushOperation = flushOperations.get(0);
-        ((ChannelPromise) flushOperation.getListener()).setSuccess();
-        return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
-    }
-
 
 
     private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException {

+ 24 - 4
server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java

@@ -67,6 +67,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_
 
 public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport {
     private static final Logger logger = LogManager.getLogger(AbstractHttpServerTransport.class);
+    private static final ActionListener<Void> NO_OP = ActionListener.wrap(() -> {});
 
     protected final Settings settings;
     public final HttpHandlingSettings handlingSettings;
@@ -74,7 +75,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
     protected final BigArrays bigArrays;
     protected final ThreadPool threadPool;
     protected final Dispatcher dispatcher;
-    protected final CorsHandler.Config corsConfig;
+    protected final CorsHandler corsHandler;
     private final NamedXContentRegistry xContentRegistry;
 
     protected final PortsRange port;
@@ -98,7 +99,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
         this.xContentRegistry = xContentRegistry;
         this.dispatcher = dispatcher;
         this.handlingSettings = HttpHandlingSettings.fromSettings(settings);
-        this.corsConfig = CorsHandler.fromSettings(settings);
+        this.corsHandler = CorsHandler.fromSettings(settings);
 
         // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here
         List<String> httpBindHost = SETTING_HTTP_BIND_HOST.get(settings);
@@ -321,6 +322,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
     }
 
     private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
+        if (exception == null) {
+            HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest);
+            if (earlyResponse != null) {
+                httpChannel.sendResponse(earlyResponse, earlyResponseListener(httpRequest, httpChannel));
+                httpRequest.release();
+                return;
+            }
+        }
+
         Exception badRequestCause = exception;
 
         /*
@@ -359,12 +369,14 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
             ThreadContext threadContext = threadPool.getThreadContext();
             try {
                 innerChannel =
-                    new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext, trace);
+                    new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext, corsHandler,
+                        trace);
             } catch (final IllegalArgumentException e) {
                 badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
                 final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel);
                 innerChannel =
-                    new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext, trace);
+                    new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext, corsHandler,
+                        trace);
             }
             channel = innerChannel;
         }
@@ -381,4 +393,12 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
             return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel);
         }
     }
+
+    private static ActionListener<Void> earlyResponseListener(HttpRequest request, HttpChannel httpChannel) {
+        if (HttpUtils.shouldCloseConnection(request)) {
+            return ActionListener.wrap(() -> CloseableChannel.closeChannel(httpChannel));
+        } else {
+            return NO_OP;
+        }
+    }
 }

+ 181 - 6
server/src/main/java/org/elasticsearch/http/CorsHandler.java

@@ -35,16 +35,23 @@
 package org.elasticsearch.http;
 
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsException;
 import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.rest.RestUtils;
 
+import java.time.ZoneOffset;
+import java.time.ZonedDateTime;
+import java.time.format.DateTimeFormatter;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
+import java.util.List;
 import java.util.Locale;
+import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
 import java.util.regex.Pattern;
@@ -62,7 +69,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE;
  * files: io.netty.handler.codec.http.cors.CorsHandler, io.netty.handler.codec.http.cors.CorsConfig, and
  * io.netty.handler.codec.http.cors.CorsConfigBuilder.
  *
- * It modifies the original netty code to operation on Elasticsearch http request/response abstractions.
+ * It modifies the original netty code to operate on Elasticsearch http request/response abstractions.
  * Additionally, it removes CORS features that are not used by Elasticsearch.
  */
 public class CorsHandler {
@@ -71,10 +78,172 @@ public class CorsHandler {
     public static final String ORIGIN = "origin";
     public static final String DATE = "date";
     public static final String VARY = "vary";
+    public static final String HOST = "host";
     public static final String ACCESS_CONTROL_REQUEST_METHOD = "access-control-request-method";
+    public static final String ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers";
+    public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials";
+    public static final String ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods";
     public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin";
+    public static final String ACCESS_CONTROL_MAX_AGE = "access-control-max-age";
 
-    private CorsHandler() {
+    private static final Pattern SCHEME_PATTERN = Pattern.compile("^https?://");
+    private static final DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss O", Locale.ENGLISH);
+    private final Config config;
+
+    public CorsHandler(Config config) {
+        this.config = config;
+    }
+
+    public HttpResponse handleInbound(HttpRequest request) {
+        if (config.isCorsSupportEnabled()) {
+            if (isPreflightRequest(request)) {
+                return handlePreflight(request);
+            }
+
+            if (validateOrigin(request) == false) {
+                return forbidden(request);
+            }
+        }
+        return null;
+    }
+
+    public void setCorsResponseHeaders(final HttpRequest httpRequest, final HttpResponse httpResponse) {
+        if (!config.isCorsSupportEnabled()) {
+            return;
+        }
+        if (setOrigin(httpRequest, httpResponse)) {
+            setAllowCredentials(httpResponse);
+        }
+    }
+
+    private HttpResponse handlePreflight(final HttpRequest request) {
+        final HttpResponse response = request.createResponse(RestStatus.OK, BytesArray.EMPTY);
+        if (setOrigin(request, response)) {
+            setAllowMethods(response);
+            setAllowHeaders(response);
+            setAllowCredentials(response);
+            setMaxAge(response);
+            setPreflightHeaders(response);
+            return response;
+        } else {
+            return forbidden(request);
+        }
+    }
+
+    private static HttpResponse forbidden(final HttpRequest request) {
+        HttpResponse response = request.createResponse(RestStatus.FORBIDDEN, BytesArray.EMPTY);
+        response.addHeader("content-length", "0");
+        return response;
+    }
+
+    private static boolean isSameOrigin(final String origin, final String host) {
+        if (Strings.isNullOrEmpty(host) == false) {
+            // strip protocol from origin
+            final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst("");
+            if (host.equals(originDomain)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private void setPreflightHeaders(final HttpResponse response) {
+        response.addHeader(CorsHandler.DATE, dateTimeFormatter.format(ZonedDateTime.now(ZoneOffset.UTC)));
+        response.addHeader("content-length", "0");
+    }
+
+    private boolean setOrigin(final HttpRequest request, final HttpResponse response) {
+        String origin = getOrigin(request);
+        if (!Strings.isNullOrEmpty(origin)) {
+            if (config.isAnyOriginSupported()) {
+                if (config.isCredentialsAllowed()) {
+                    setAllowOrigin(response, origin);
+                    setVaryHeader(response);
+                } else {
+                    setAllowOrigin(response, ANY_ORIGIN);
+                }
+                return true;
+            } else if (config.isOriginAllowed(origin) || isSameOrigin(origin, getHost(request))) {
+                setAllowOrigin(response, origin);
+                setVaryHeader(response);
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private boolean validateOrigin(final HttpRequest request) {
+        if (config.isAnyOriginSupported()) {
+            return true;
+        }
+
+        final String origin = getOrigin(request);
+        if (Strings.isNullOrEmpty(origin)) {
+            // Not a CORS request so we cannot validate it. It may be a non CORS request.
+            return true;
+        }
+
+        // if the origin is the same as the host of the request, then allow
+        if (isSameOrigin(origin, getHost(request))) {
+            return true;
+        }
+
+        return config.isOriginAllowed(origin);
+    }
+
+    private static String getOrigin(HttpRequest request) {
+        List<String> headers = request.getHeaders().get(ORIGIN);
+        if (headers == null || headers.isEmpty()) {
+            return null;
+        } else {
+            return headers.get(0);
+        }
+    }
+
+    private static String getHost(HttpRequest request) {
+        List<String> headers = request.getHeaders().get(HOST);
+        if (headers == null || headers.isEmpty()) {
+            return null;
+        } else {
+            return headers.get(0);
+        }
+    }
+
+    private static boolean isPreflightRequest(final HttpRequest request) {
+        final Map<String, List<String>> headers = request.getHeaders();
+        return request.method().equals(RestRequest.Method.OPTIONS) &&
+            headers.containsKey(ORIGIN) &&
+            headers.containsKey(ACCESS_CONTROL_REQUEST_METHOD);
+    }
+
+    private static void setVaryHeader(final HttpResponse response) {
+        response.addHeader(VARY, ORIGIN);
+    }
+
+    private static void setAllowOrigin(final HttpResponse response, final String origin) {
+        response.addHeader(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
+    }
+
+    private void setAllowMethods(final HttpResponse response) {
+        for (RestRequest.Method method : config.allowedRequestMethods()) {
+            response.addHeader(ACCESS_CONTROL_ALLOW_METHODS, method.name().trim());
+        }
+    }
+
+    private void setAllowHeaders(final HttpResponse response) {
+        for (String header : config.allowedRequestHeaders) {
+            response.addHeader(ACCESS_CONTROL_ALLOW_HEADERS, header);
+        }
+    }
+
+    private void setAllowCredentials(final HttpResponse response) {
+        if (config.isCredentialsAllowed()) {
+            response.addHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
+        }
+    }
+
+    private void setMaxAge(final HttpResponse response) {
+        response.addHeader(ACCESS_CONTROL_MAX_AGE, Long.toString(config.maxAge));
     }
 
     public static class Config {
@@ -218,15 +387,17 @@ public class CorsHandler {
         }
     }
 
-    public static Config disabled() {
+    public static CorsHandler disabled() {
         Config.Builder builder = new Config.Builder();
         builder.enabled = false;
-        return new Config(builder);
+        return new CorsHandler(new Config(builder));
     }
 
-    public static Config fromSettings(Settings settings) {
+    public static Config buildConfig(Settings settings) {
         if (SETTING_CORS_ENABLED.get(settings) == false) {
-            return disabled();
+            Config.Builder builder = new Config.Builder();
+            builder.enabled = false;
+            return new Config(builder);
         }
         String origin = SETTING_CORS_ALLOW_ORIGIN.get(settings);
         final CorsHandler.Config.Builder builder;
@@ -260,4 +431,8 @@ public class CorsHandler {
             .build();
         return config;
     }
+
+    public static CorsHandler fromSettings(Settings settings) {
+        return new CorsHandler(buildConfig(settings));
+    }
 }

+ 6 - 17
server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java

@@ -60,19 +60,21 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
     private final HttpHandlingSettings settings;
     private final ThreadContext threadContext;
     private final HttpChannel httpChannel;
+    private final CorsHandler corsHandler;
 
     @Nullable
     private final HttpTracer tracerLog;
 
     DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays,
-                       HttpHandlingSettings settings, ThreadContext threadContext, @Nullable HttpTracer tracerLog) {
+                       HttpHandlingSettings settings, ThreadContext threadContext, CorsHandler corsHandler,
+                       @Nullable HttpTracer tracerLog) {
         super(request, settings.getDetailedErrorsEnabled());
         this.httpChannel = httpChannel;
-        // TODO: Fix
         this.httpRequest = httpRequest;
         this.bigArrays = bigArrays;
         this.settings = settings;
         this.threadContext = threadContext;
+        this.corsHandler = corsHandler;
         this.tracerLog = tracerLog;
     }
 
@@ -87,7 +89,7 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
         Releasables.closeWhileHandlingException(httpRequest::release);
 
         final ArrayList<Releasable> toClose = new ArrayList<>(3);
-        if (isCloseConnection()) {
+        if (HttpUtils.shouldCloseConnection(httpRequest)) {
             toClose.add(() -> CloseableChannel.closeChannel(httpChannel));
         }
 
@@ -112,8 +114,7 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
 
             final HttpResponse httpResponse = httpRequest.createResponse(restResponse.status(), finalContent);
 
-            // TODO: Ideally we should move the setting of Cors headers into :server
-            // NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig);
+            corsHandler.setCorsResponseHeaders(httpRequest, httpResponse);
 
             opaque = request.header(X_OPAQUE_ID);
             if (opaque != null) {
@@ -180,16 +181,4 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
             }
         }
     }
-
-    // Determine if the request connection should be closed on completion.
-    private boolean isCloseConnection() {
-        try {
-            final boolean http10 = request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0;
-            return CLOSE.equalsIgnoreCase(request.header(CONNECTION))
-                || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION)));
-        } catch (Exception e) {
-            // In case we fail to parse the http protocol version out of the request we always close the connection
-            return true;
-        }
-    }
 }

+ 17 - 0
server/src/main/java/org/elasticsearch/http/HttpRequest.java

@@ -24,6 +24,7 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.RestStatus;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -58,6 +59,22 @@ public interface HttpRequest {
      */
     Map<String, List<String>> getHeaders();
 
+    default String header(String name) {
+        List<String> values = getHeaders().get(name);
+        if (values != null && values.isEmpty() == false) {
+            return values.get(0);
+        }
+        return null;
+    }
+
+    default List<String> allHeaders(String name) {
+        List<String> values = getHeaders().get(name);
+        if (values != null) {
+            return Collections.unmodifiableList(values);
+        }
+        return null;
+    }
+
     List<String> strictCookies();
 
     HttpVersion protocolVersion();

+ 39 - 0
server/src/main/java/org/elasticsearch/http/HttpUtils.java

@@ -0,0 +1,39 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.http;
+
+public class HttpUtils {
+
+    static final String CLOSE = "close";
+    static final String CONNECTION = "connection";
+    static final String KEEP_ALIVE = "keep-alive";
+
+    // Determine if the request connection should be closed on completion.
+    public static boolean shouldCloseConnection(HttpRequest httpRequest) {
+        try {
+            final boolean http10 = httpRequest.protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0;
+            return CLOSE.equalsIgnoreCase(httpRequest.header(CONNECTION))
+                || (http10 && !KEEP_ALIVE.equalsIgnoreCase(httpRequest.header(CONNECTION)));
+        } catch (Exception e) {
+            // In case we fail to parse the http protocol version out of the request we always close the connection
+            return true;
+        }
+    }
+}

+ 1 - 1
server/src/main/java/org/elasticsearch/rest/RestUtils.java

@@ -233,7 +233,7 @@ public class RestUtils {
             return null;
         }
         int len = corsSetting.length();
-        boolean isRegex = len > 2 &&  corsSetting.startsWith("/") && corsSetting.endsWith("/");
+        boolean isRegex = len > 2 && corsSetting.startsWith("/") && corsSetting.endsWith("/");
 
         if (isRegex) {
             return Pattern.compile(corsSetting.substring(1, corsSetting.length()-1));

+ 242 - 3
server/src/test/java/org/elasticsearch/http/CorsHandlerTests.java

@@ -20,15 +20,19 @@
 package org.elasticsearch.http;
 
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsException;
 import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
 
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Locale;
+import java.util.Map;
 import java.util.Set;
 import java.util.regex.PatternSyntaxException;
 import java.util.stream.Collectors;
@@ -40,8 +44,11 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ME
 import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
 import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
 import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.nullValue;
 
 public class CorsHandlerTests extends ESTestCase {
 
@@ -51,7 +58,7 @@ public class CorsHandlerTests extends ESTestCase {
             .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "/[*/")
             .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
             .build();
-        SettingsException e = expectThrows(SettingsException.class, () -> CorsHandler.fromSettings(settings));
+        SettingsException e = expectThrows(SettingsException.class, () -> CorsHandler.buildConfig(settings));
         assertThat(e.getMessage(), containsString("Bad regex in [http.cors.allow-origin]: [/[*/]"));
         assertThat(e.getCause(), instanceOf(PatternSyntaxException.class));
     }
@@ -67,7 +74,7 @@ public class CorsHandlerTests extends ESTestCase {
             .put(SETTING_CORS_ALLOW_HEADERS.getKey(), collectionToDelimitedString(headers, ",", prefix, ""))
             .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
             .build();
-        final CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings);
+        final CorsHandler.Config corsConfig = CorsHandler.buildConfig(settings);
         assertTrue(corsConfig.isAnyOriginSupported());
         assertEquals(headers, corsConfig.allowedRequestHeaders());
         assertEquals(methods.stream().map(s -> s.toUpperCase(Locale.ENGLISH)).collect(Collectors.toSet()),
@@ -79,7 +86,7 @@ public class CorsHandlerTests extends ESTestCase {
         final Set<String> headers = Strings.commaDelimitedListToSet(SETTING_CORS_ALLOW_HEADERS.getDefault(Settings.EMPTY));
         final long maxAge = SETTING_CORS_MAX_AGE.getDefault(Settings.EMPTY);
         final Settings settings = Settings.builder().put(SETTING_CORS_ENABLED.getKey(), true).build();
-        final CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings);
+        final CorsHandler.Config corsConfig = CorsHandler.buildConfig(settings);
         assertFalse(corsConfig.isAnyOriginSupported());
         assertEquals(Collections.emptySet(), corsConfig.origins().get());
         assertEquals(headers, corsConfig.allowedRequestHeaders());
@@ -87,4 +94,236 @@ public class CorsHandlerTests extends ESTestCase {
         assertEquals(maxAge, corsConfig.maxAge());
         assertFalse(corsConfig.isCredentialsAllowed());
     }
+
+    public void testHandleInboundNonCorsRequest() {
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        HttpResponse httpResponse = corsHandler.handleInbound(request);
+        // Since this is not a Cors request, there is not an early response
+        assertThat(httpResponse, nullValue());
+    }
+
+    public void testHandleInboundValidCorsRequest() {
+        final String validOriginLiteral = "valid-origin";
+        final String originSetting;
+        if (randomBoolean()) {
+            originSetting = validOriginLiteral;
+        } else {
+            if (randomBoolean()) {
+                originSetting = "/valid-.+/";
+            } else {
+                originSetting = "*";
+            }
+        }
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList(validOriginLiteral));
+        HttpResponse httpResponse = corsHandler.handleInbound(request);
+        // Since is a Cors enabled request. However, it is not forbidden because the origin is allowed.
+        assertThat(httpResponse, nullValue());
+    }
+
+    public void testHandleInboundForbidden() {
+        final String validOriginLiteral = "valid-origin";
+        final String originSetting;
+        if (randomBoolean()) {
+            originSetting = validOriginLiteral;
+        } else {
+            originSetting = "/valid-.+/";
+        }
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("invalid-origin"));
+        TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
+        // Forbidden
+        assertThat(httpResponse.status(), equalTo(RestStatus.FORBIDDEN));
+    }
+
+    public void testHandleInboundAllowsSameOrigin() {
+        final String validOriginLiteral = "valid-origin";
+        final String originSetting;
+        if (randomBoolean()) {
+            originSetting = validOriginLiteral;
+        } else {
+            originSetting = "/valid-.+/";
+        }
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("https://same-host"));
+        request.getHeaders().put(CorsHandler.HOST, Collections.singletonList("same-host"));
+        TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
+        // Since is a Cors enabled request. However, it is not forbidden because the origin is the same as the host.
+        assertThat(httpResponse, nullValue());
+    }
+
+    public void testHandleInboundPreflightWithWildcardNoCredentials() {
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
+            .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE")
+            .put(SETTING_CORS_ALLOW_HEADERS.getKey(), "Content-Type,Content-Length")
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
+        request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST"));
+        TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
+
+        assertThat(httpResponse.status(), equalTo(RestStatus.OK));
+        Map<String, List<String>> headers = httpResponse.headers();
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("*"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS),
+            containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS),
+            containsInAnyOrder("Content-Type", "Content-Length"));
+        assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000"));
+        assertNotNull(headers.get(CorsHandler.DATE));
+    }
+
+    public void testHandleInboundPreflightWithWildcardAllowCredentials() {
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
+            .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE,POST")
+            .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
+        request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST"));
+        TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
+
+        assertThat(httpResponse.status(), equalTo(RestStatus.OK));
+        Map<String, List<String>> headers = httpResponse.headers();
+        // Since credentials are allowed, we echo the origin
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
+        assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS),
+            containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE", "POST"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS),
+            containsInAnyOrder("X-Requested-With", "Content-Type", "Content-Length"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000"));
+        assertNotNull(headers.get(CorsHandler.DATE));
+    }
+
+    public void testHandleInboundPreflightWithValidOriginAllowCredentials() {
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "valid-origin")
+            .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE,POST")
+            .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
+        request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST"));
+        TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
+
+        assertThat(httpResponse.status(), equalTo(RestStatus.OK));
+        Map<String, List<String>> headers = httpResponse.headers();
+        // Since credentials are allowed, we echo the origin
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
+        assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS),
+            containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE", "POST"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS),
+            containsInAnyOrder("X-Requested-With", "Content-Type", "Content-Length"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000"));
+        assertNotNull(headers.get(CorsHandler.DATE));
+    }
+
+    public void testSetResponseNonCorsRequest() {
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
+            .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE")
+            .put(SETTING_CORS_ALLOW_HEADERS.getKey(), "Content-Type,Content-Length")
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
+        corsHandler.setCorsResponseHeaders(request, response);
+
+        Map<String, List<String>> headers = response.headers();
+        assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN));
+    }
+
+    public void testSetResponseHeadersWithWildcardOrigin() {
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
+        TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
+        corsHandler.setCorsResponseHeaders(request, response);
+
+        Map<String, List<String>> headers = response.headers();
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("*"));
+        assertNull(headers.get(CorsHandler.VARY));
+    }
+
+    public void testSetResponseHeadersWithCredentialsWithWildcard() {
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
+            .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
+        TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
+        corsHandler.setCorsResponseHeaders(request, response);
+
+        Map<String, List<String>> headers = response.headers();
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
+        assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
+    }
+
+    public void testSetResponseHeadersWithNonWildcardOrigin() {
+        boolean allowCredentials = randomBoolean();
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "valid-origin")
+            .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), allowCredentials)
+            .build();
+        CorsHandler corsHandler = CorsHandler.fromSettings(settings);
+
+        TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
+        TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
+        corsHandler.setCorsResponseHeaders(request, response);
+
+        Map<String, List<String>> headers = response.headers();
+        assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
+        assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
+        if (allowCredentials) {
+            assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
+        } else {
+            assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+        }
+    }
 }

+ 85 - 235
server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java

@@ -50,19 +50,17 @@ import org.mockito.ArgumentCaptor;
 import java.io.IOException;
 import java.nio.channels.ClosedChannelException;
 import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.function.Supplier;
 
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
@@ -90,109 +88,72 @@ public class DefaultRestChannelTests extends ESTestCase {
     }
 
     public void testResponse() {
-        final TestResponse response = executeRequest(Settings.EMPTY, "request-host");
+        final TestHttpResponse response = executeRequest(Settings.EMPTY, "request-host");
         assertThat(response.content(), equalTo(new TestRestResponse().content()));
     }
 
-    // TODO: Enable these Cors tests when the Cors logic lives in :server
-
-//    public void testCorsEnabledWithoutAllowOrigins() {
-//        // Set up an HTTP transport with only the CORS enabled setting
-//        Settings settings = Settings.builder()
-//            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
-//            .build();
-//        HttpResponse response = executeRequest(settings, "remote-host", "request-host");
-//        // inspect response and validate
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
-//    }
-//
-//    public void testCorsEnabledWithAllowOrigins() {
-//        final String originValue = "remote-host";
-//        // create an HTTP transport with CORS enabled and allow origin configured
-//        Settings settings = Settings.builder()
-//            .put(SETTING_CORS_ENABLED.getKey(), true)
-//            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-//            .build();
-//        HttpResponse response = executeRequest(settings, originValue, "request-host");
-//        // inspect response and validate
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-//        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-//        assertThat(allowedOrigins, is(originValue));
-//    }
-//
-//    public void testCorsAllowOriginWithSameHost() {
-//        String originValue = "remote-host";
-//        String host = "remote-host";
-//        // create an HTTP transport with CORS enabled
-//        Settings settings = Settings.builder()
-//            .put(SETTING_CORS_ENABLED.getKey(), true)
-//            .build();
-//        HttpResponse response = executeRequest(settings, originValue, host);
-//        // inspect response and validate
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-//        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-//        assertThat(allowedOrigins, is(originValue));
-//
-//        originValue = "http://" + originValue;
-//        response = executeRequest(settings, originValue, host);
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-//        allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-//        assertThat(allowedOrigins, is(originValue));
-//
-//        originValue = originValue + ":5555";
-//        host = host + ":5555";
-//        response = executeRequest(settings, originValue, host);
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-//        allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-//        assertThat(allowedOrigins, is(originValue));
-//
-//        originValue = originValue.replace("http", "https");
-//        response = executeRequest(settings, originValue, host);
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-//        allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-//        assertThat(allowedOrigins, is(originValue));
-//    }
-//
-//    public void testThatStringLiteralWorksOnMatch() {
-//        final String originValue = "remote-host";
-//        Settings settings = Settings.builder()
-//            .put(SETTING_CORS_ENABLED.getKey(), true)
-//            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-//            .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
-//            .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
-//            .build();
-//        HttpResponse response = executeRequest(settings, originValue, "request-host");
-//        // inspect response and validate
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-//        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-//        assertThat(allowedOrigins, is(originValue));
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
-//    }
-//
-//    public void testThatAnyOriginWorks() {
-//        final String originValue = NioCorsHandler.ANY_ORIGIN;
-//        Settings settings = Settings.builder()
-//            .put(SETTING_CORS_ENABLED.getKey(), true)
-//            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
-//            .build();
-//        HttpResponse response = executeRequest(settings, originValue, "request-host");
-//        // inspect response and validate
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
-//        String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
-//        assertThat(allowedOrigins, is(originValue));
-//        assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
-//    }
+    public void testCorsEnabledWithoutAllowOrigins() {
+        // Set up an HTTP transport with only the CORS enabled setting
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .build();
+        TestHttpResponse response = executeRequest(settings, "request-host");
+        assertThat(response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
+    }
+
+    public void testCorsEnabledWithAllowOrigins() {
+        final String originValue = "remote-host";
+        final String pattern;
+        if (randomBoolean()) {
+            pattern = originValue;
+        } else {
+            pattern = "/remote-hos.+/";
+        }
+        // create an HTTP transport with CORS enabled and allow origin configured
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), pattern)
+            .build();
+        TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1");
+        assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0));
+        assertThat(response.headers().get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
+    }
+
+    public void testCorsEnabledWithAllowOriginsAndAllowCredentials() {
+        final String originValue = "remote-host";
+        // create an HTTP transport with CORS enabled and allow origin configured
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), CorsHandler.ANY_ORIGIN)
+            .put(HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
+            .build();
+        TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1");
+        assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0));
+        assertEquals(CorsHandler.ORIGIN, response.headers().get(CorsHandler.VARY).get(0));
+        assertEquals("true", response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS).get(0));
+    }
+
+    public void testThatAnyOriginWorks() {
+        final String originValue = CorsHandler.ANY_ORIGIN;
+        Settings settings = Settings.builder()
+            .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
+            .put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
+            .build();
+        TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1");
+        assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0));
+        assertNull(response.headers().get(CorsHandler.VARY));
+    }
 
     public void testHeadersSet() {
         Settings settings = Settings.builder().build();
-        final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
         httpRequest.getHeaders().put(Task.X_OPAQUE_ID, Collections.singletonList("abc"));
         final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
         HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
 
         // send a response
         DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
-            threadPool.getThreadContext(), null);
+            threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
         TestRestResponse resp = new TestRestResponse();
         final String customHeader = "custom-header";
         final String customHeaderValue = "xyz";
@@ -200,10 +161,10 @@ public class DefaultRestChannelTests extends ESTestCase {
         channel.sendResponse(resp);
 
         // inspect what was written
-        ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
+        ArgumentCaptor<TestHttpResponse> responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class);
         verify(httpChannel).sendResponse(responseCaptor.capture(), any());
-        TestResponse httpResponse = responseCaptor.getValue();
-        Map<String, List<String>> headers = httpResponse.headers;
+        TestHttpResponse httpResponse = responseCaptor.getValue();
+        Map<String, List<String>> headers = httpResponse.headers();
         assertNull(headers.get("non-existent-header"));
         assertEquals(customHeaderValue, headers.get(customHeader).get(0));
         assertEquals("abc", headers.get(Task.X_OPAQUE_ID).get(0));
@@ -213,21 +174,21 @@ public class DefaultRestChannelTests extends ESTestCase {
 
     public void testCookiesSet() {
         Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build();
-        final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
         httpRequest.getHeaders().put(Task.X_OPAQUE_ID, Collections.singletonList("abc"));
         final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
         HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
 
         // send a response
         DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
-            threadPool.getThreadContext(), null);
+            threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
         channel.sendResponse(new TestRestResponse());
 
         // inspect what was written
-        ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
+        ArgumentCaptor<TestHttpResponse> responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class);
         verify(httpChannel).sendResponse(responseCaptor.capture(), any());
-        TestResponse nioResponse = responseCaptor.getValue();
-        Map<String, List<String>> headers = nioResponse.headers;
+        TestHttpResponse nioResponse = responseCaptor.getValue();
+        Map<String, List<String>> headers = nioResponse.headers();
         assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie"));
         assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2"));
     }
@@ -235,12 +196,12 @@ public class DefaultRestChannelTests extends ESTestCase {
     @SuppressWarnings("unchecked")
     public void testReleaseInListener() throws IOException {
         final Settings settings = Settings.builder().build();
-        final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
         final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
         HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
 
         DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
-            threadPool.getThreadContext(), null);
+            threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
         final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
             JsonXContent.contentBuilder().startObject().endObject());
         assertThat(response.content(), not(instanceOf(Releasable.class)));
@@ -276,16 +237,16 @@ public class DefaultRestChannelTests extends ESTestCase {
         final boolean brokenRequest = randomBoolean();
         final boolean close = brokenRequest || randomBoolean();
         if (brokenRequest) {
-            httpRequest = new TestRequest(() -> {
+            httpRequest = new TestHttpRequest(() -> {
                 throw new IllegalArgumentException("Can't parse HTTP version");
             }, RestRequest.Method.GET, "/");
         } else if (randomBoolean()) {
-            httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+            httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
             if (close) {
                 httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE));
             }
         } else {
-            httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/");
+            httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/");
             if (!close) {
                 httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE));
             }
@@ -295,7 +256,7 @@ public class DefaultRestChannelTests extends ESTestCase {
         HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
 
         DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
-            threadPool.getThreadContext(), null);
+            threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
         channel.sendResponse(new TestRestResponse());
         Class<ActionListener<Void>> listenerClass = (Class<ActionListener<Void>>) (Class) ActionListener.class;
         ArgumentCaptor<ActionListener<Void>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
@@ -317,7 +278,7 @@ public class DefaultRestChannelTests extends ESTestCase {
         final boolean close = randomBoolean();
         final HttpRequest.HttpVersion httpVersion = close ? HttpRequest.HttpVersion.HTTP_1_0 : HttpRequest.HttpVersion.HTTP_1_1;
         final String httpConnectionHeaderValue = close ? DefaultRestChannel.CLOSE : DefaultRestChannel.KEEP_ALIVE;
-        final RestRequest request = RestRequest.request(xContentRegistry(), new TestRequest(httpVersion, null, "/") {
+        final RestRequest request = RestRequest.request(xContentRegistry(), new TestHttpRequest(httpVersion, null, "/") {
             @Override
             public RestRequest.Method method() {
                 throw new IllegalArgumentException("test");
@@ -326,7 +287,8 @@ public class DefaultRestChannelTests extends ESTestCase {
         request.getHttpRequest().getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(httpConnectionHeaderValue));
 
         DefaultRestChannel channel = new DefaultRestChannel(httpChannel, request.getHttpRequest(), request, bigArrays,
-            HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), null);
+            HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), CorsHandler.fromSettings(Settings.EMPTY),
+            null);
 
         // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
         final BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
@@ -354,7 +316,7 @@ public class DefaultRestChannelTests extends ESTestCase {
         final boolean close = randomBoolean();
         final HttpRequest.HttpVersion httpVersion = close ? HttpRequest.HttpVersion.HTTP_1_0 : HttpRequest.HttpVersion.HTTP_1_1;
         final String httpConnectionHeaderValue = close ? DefaultRestChannel.CLOSE : DefaultRestChannel.KEEP_ALIVE;
-        final RestRequest request = RestRequest.request(xContentRegistry(), new TestRequest(httpVersion, null, "/") {
+        final RestRequest request = RestRequest.request(xContentRegistry(), new TestHttpRequest(httpVersion, null, "/") {
             @Override
             public HttpResponse createResponse(RestStatus status, BytesReference content) {
                 throw new IllegalArgumentException("test");
@@ -363,7 +325,8 @@ public class DefaultRestChannelTests extends ESTestCase {
         request.getHttpRequest().getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(httpConnectionHeaderValue));
 
         DefaultRestChannel channel = new DefaultRestChannel(httpChannel, request.getHttpRequest(), request, bigArrays,
-            HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), null);
+            HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), CorsHandler.fromSettings(Settings.EMPTY),
+            null);
 
         // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
         final BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
@@ -379,142 +342,29 @@ public class DefaultRestChannelTests extends ESTestCase {
         }
     }
 
-    private TestResponse executeRequest(final Settings settings, final String host) {
+    private TestHttpResponse executeRequest(final Settings settings, final String host) {
         return executeRequest(settings, null, host);
     }
 
-    private TestResponse executeRequest(final Settings settings, final String originValue, final String host) {
-        HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
-        // TODO: These exist for the Cors tests
-//        if (originValue != null) {
-//            httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
-//        }
-//        httpRequest.headers().add(HttpHeaderNames.HOST, host);
+    private TestHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
+        HttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
+        if (originValue != null) {
+            httpRequest.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList(originValue));
+        }
+        httpRequest.getHeaders().put(CorsHandler.HOST, Collections.singletonList(host));
         final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
 
         HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
         RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings,
-            threadPool.getThreadContext(), null);
+            threadPool.getThreadContext(), new CorsHandler(CorsHandler.buildConfig(settings)), null);
         channel.sendResponse(new TestRestResponse());
 
         // get the response
-        ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
+        ArgumentCaptor<TestHttpResponse> responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class);
         verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any());
         return responseCaptor.getValue();
     }
 
-    private static class TestRequest implements HttpRequest {
-
-        private final Supplier<HttpVersion> version;
-        private final RestRequest.Method method;
-        private final String uri;
-        private HashMap<String, List<String>> headers = new HashMap<>();
-
-        private TestRequest(Supplier<HttpVersion> versionSupplier, RestRequest.Method method, String uri) {
-            this.version = versionSupplier;
-            this.method = method;
-            this.uri = uri;
-        }
-
-        private TestRequest(HttpVersion version, RestRequest.Method method, String uri) {
-            this(() -> version, method, uri);
-        }
-
-        @Override
-        public RestRequest.Method method() {
-            return method;
-        }
-
-        @Override
-        public String uri() {
-            return uri;
-        }
-
-        @Override
-        public BytesReference content() {
-            return BytesArray.EMPTY;
-        }
-
-        @Override
-        public Map<String, List<String>> getHeaders() {
-            return headers;
-        }
-
-        @Override
-        public List<String> strictCookies() {
-            return Arrays.asList("cookie", "cookie2");
-        }
-
-        @Override
-        public HttpVersion protocolVersion() {
-            return version.get();
-        }
-
-        @Override
-        public HttpRequest removeHeader(String header) {
-            throw new UnsupportedOperationException("Do not support removing header on test request.");
-        }
-
-        @Override
-        public HttpResponse createResponse(RestStatus status, BytesReference content) {
-            return new TestResponse(status, content);
-        }
-
-        @Override
-        public void release() {
-        }
-
-        @Override
-        public HttpRequest releaseAndCopy() {
-            return this;
-        }
-
-        @Override
-        public Exception getInboundException() {
-            return null;
-        }
-    }
-
-    private static class TestResponse implements HttpResponse {
-
-        private final RestStatus status;
-        private final BytesReference content;
-        private final Map<String, List<String>> headers = new HashMap<>();
-
-        TestResponse(RestStatus status, BytesReference content) {
-            this.status = status;
-            this.content = content;
-        }
-
-        public String contentType() {
-            return "text";
-        }
-
-        public BytesReference content() {
-            return content;
-        }
-
-        public RestStatus status() {
-            return status;
-        }
-
-        @Override
-        public void addHeader(String name, String value) {
-            if (headers.containsKey(name) == false) {
-                ArrayList<String> values = new ArrayList<>();
-                values.add(value);
-                headers.put(name, values);
-            } else {
-                headers.get(name).add(value);
-            }
-        }
-
-        @Override
-        public boolean containsHeader(String name) {
-            return headers.containsKey(name);
-        }
-    }
-
     private static class TestRestResponse extends RestResponse {
 
         private final RestStatus status;

+ 103 - 0
server/src/test/java/org/elasticsearch/http/TestHttpRequest.java

@@ -0,0 +1,103 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.http;
+
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.RestStatus;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+class TestHttpRequest implements HttpRequest {
+
+    private final Supplier<HttpVersion> version;
+    private final RestRequest.Method method;
+    private final String uri;
+    private final HashMap<String, List<String>> headers = new HashMap<>();
+
+    TestHttpRequest(Supplier<HttpVersion> versionSupplier, RestRequest.Method method, String uri) {
+        this.version = versionSupplier;
+        this.method = method;
+        this.uri = uri;
+    }
+
+    TestHttpRequest(HttpVersion version, RestRequest.Method method, String uri) {
+        this(() -> version, method, uri);
+    }
+
+    @Override
+    public RestRequest.Method method() {
+        return method;
+    }
+
+    @Override
+    public String uri() {
+        return uri;
+    }
+
+    @Override
+    public BytesReference content() {
+        return BytesArray.EMPTY;
+    }
+
+    @Override
+    public Map<String, List<String>> getHeaders() {
+        return headers;
+    }
+
+    @Override
+    public List<String> strictCookies() {
+        return Arrays.asList("cookie", "cookie2");
+    }
+
+    @Override
+    public HttpVersion protocolVersion() {
+        return version.get();
+    }
+
+    @Override
+    public HttpRequest removeHeader(String header) {
+        throw new UnsupportedOperationException("Do not support removing header on test request.");
+    }
+
+    @Override
+    public HttpResponse createResponse(RestStatus status, BytesReference content) {
+        return new TestHttpResponse(status, content);
+    }
+
+    @Override
+    public void release() {
+    }
+
+    @Override
+    public HttpRequest releaseAndCopy() {
+        return this;
+    }
+
+    @Override
+    public Exception getInboundException() {
+        return null;
+    }
+}

+ 68 - 0
server/src/test/java/org/elasticsearch/http/TestHttpResponse.java

@@ -0,0 +1,68 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.http;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.rest.RestStatus;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+class TestHttpResponse implements HttpResponse {
+
+    private final RestStatus status;
+    private final BytesReference content;
+    private final Map<String, List<String>> headers = new HashMap<>();
+
+    TestHttpResponse(RestStatus status, BytesReference content) {
+        this.status = status;
+        this.content = content;
+    }
+
+    public BytesReference content() {
+        return content;
+    }
+
+    public RestStatus status() {
+        return status;
+    }
+
+    public Map<String, List<String>> headers() {
+        return headers;
+    }
+
+    @Override
+    public void addHeader(String name, String value) {
+        if (headers.containsKey(name) == false) {
+            ArrayList<String> values = new ArrayList<>();
+            values.add(value);
+            headers.put(name, values);
+        } else {
+            headers.get(name).add(value);
+        }
+    }
+
+    @Override
+    public boolean containsHeader(String name) {
+        return headers.containsKey(name);
+    }
+}

+ 1 - 1
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java

@@ -94,7 +94,7 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
         public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
             NioHttpChannel httpChannel = new NioHttpChannel(channel);
             HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
-                handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
+                handlingSettings, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
             final NioChannelHandler handler;
             if (ipFilter != null) {
                 handler = new NioIPFilter(httpHandler, socketConfig.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);