瀏覽代碼

Add Cors integration tests (#44361)

This commit adds integration tests to ensure that the basic cors
functionality works for the netty and nio transports.
Tim Brooks 6 年之前
父節點
當前提交
c9607061ae

+ 63 - 1
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java

@@ -50,6 +50,7 @@ import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.http.BindHttpException;
+import org.elasticsearch.http.CorsHandler;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.HttpTransportSettings;
 import org.elasticsearch.http.NullDispatcher;
@@ -71,6 +72,8 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
+import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
 import static org.elasticsearch.rest.RestStatus.BAD_REQUEST;
 import static org.elasticsearch.rest.RestStatus.OK;
 import static org.hamcrest.Matchers.containsString;
@@ -193,7 +196,7 @@ public class Netty4HttpServerTransportTests extends ESTestCase {
             Settings settings = Settings.builder().put("http.port", remoteAddress.getPort()).build();
             try (Netty4HttpServerTransport otherTransport = new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool,
                     xContentRegistry(), new NullDispatcher())) {
-                BindHttpException bindHttpException = expectThrows(BindHttpException.class, () -> otherTransport.start());
+                BindHttpException bindHttpException = expectThrows(BindHttpException.class, otherTransport::start);
                 assertEquals("Failed to bind to [" + remoteAddress.getPort() + "]", bindHttpException.getMessage());
             }
         }
@@ -260,6 +263,65 @@ public class Netty4HttpServerTransportTests extends ESTestCase {
         assertThat(causeReference.get(), instanceOf(TooLongFrameException.class));
     }
 
+    public void testCorsRequest() throws InterruptedException {
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                throw new AssertionError();
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestRequest request,
+                                           final RestChannel channel,
+                                           final ThreadContext threadContext,
+                                           final Throwable cause) {
+                throw new AssertionError();
+            }
+
+        };
+
+        final Settings settings = Settings.builder()
+            .put(SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "elastic.co").build();
+
+        try (Netty4HttpServerTransport transport =
+                 new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) {
+            transport.start();
+            final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses());
+
+            // Test pre-flight request
+            try (Netty4HttpClient client = new Netty4HttpClient()) {
+                final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/");
+                request.headers().add(CorsHandler.ORIGIN, "elastic.co");
+                request.headers().add(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, "POST");
+
+                final FullHttpResponse response = client.post(remoteAddress.address(), request);
+                try {
+                    assertThat(response.status(), equalTo(HttpResponseStatus.OK));
+                    assertThat(response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("elastic.co"));
+                    assertThat(response.headers().get(CorsHandler.VARY), equalTo(CorsHandler.ORIGIN));
+                    assertTrue(response.headers().contains(CorsHandler.DATE));
+                } finally {
+                    response.release();
+                }
+            }
+
+            // Test short-circuited request
+            try (Netty4HttpClient client = new Netty4HttpClient()) {
+                final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
+                request.headers().add(CorsHandler.ORIGIN, "elastic2.co");
+
+                final FullHttpResponse response = client.post(remoteAddress.address(), request);
+                try {
+                    assertThat(response.status(), equalTo(HttpResponseStatus.FORBIDDEN));
+                } finally {
+                    response.release();
+                }
+            }
+        }
+    }
+
     public void testReadTimeout() throws Exception {
         final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
 

+ 9 - 3
plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java

@@ -110,7 +110,7 @@ class NioHttpClient implements Closeable {
         return sendRequests(remoteAddress, requests);
     }
 
-    public final FullHttpResponse post(InetSocketAddress remoteAddress, FullHttpRequest httpRequest) throws InterruptedException {
+    public final FullHttpResponse send(InetSocketAddress remoteAddress, FullHttpRequest httpRequest) throws InterruptedException {
         Collection<FullHttpResponse> responses = sendRequests(remoteAddress, Collections.singleton(httpRequest));
         assert responses.size() == 1 : "expected 1 and only 1 http response";
         return responses.iterator().next();
@@ -271,7 +271,7 @@ class NioHttpClient implements Closeable {
             int bytesConsumed = adaptor.read(channelBuffer.sliceAndRetainPagesTo(channelBuffer.getIndex()));
             Object message;
             while ((message = adaptor.pollInboundMessage()) != null) {
-                handleRequest(message);
+                handleResponse(message);
             }
 
             return bytesConsumed;
@@ -286,12 +286,18 @@ class NioHttpClient implements Closeable {
         public void close() throws IOException {
             try {
                 adaptor.close();
+                // After closing the pipeline, we must poll to see if any new messages are available. This
+                // is because HTTP supports a channel being closed as an end of content marker.
+                Object message;
+                while ((message = adaptor.pollInboundMessage()) != null) {
+                    handleResponse(message);
+                }
             } catch (Exception e) {
                 throw new IOException(e);
             }
         }
 
-        private void handleRequest(Object message) {
+        private void handleResponse(Object message) {
             final FullHttpResponse response = (FullHttpResponse) message;
             DefaultFullHttpResponse newResponse = new DefaultFullHttpResponse(response.protocolVersion(),
                 response.status(),

+ 65 - 3
plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java

@@ -43,6 +43,7 @@ import org.elasticsearch.common.util.MockBigArrays;
 import org.elasticsearch.common.util.MockPageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.http.BindHttpException;
+import org.elasticsearch.http.CorsHandler;
 import org.elasticsearch.http.HttpServerTransport;
 import org.elasticsearch.http.HttpTransportSettings;
 import org.elasticsearch.http.NullDispatcher;
@@ -66,6 +67,8 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
+import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
 import static org.elasticsearch.rest.RestStatus.BAD_REQUEST;
 import static org.elasticsearch.rest.RestStatus.OK;
 import static org.hamcrest.Matchers.containsString;
@@ -159,13 +162,13 @@ public class NioHttpServerTransportTests extends ESTestCase {
                 request.headers().set(HttpHeaderNames.EXPECT, expectation);
                 HttpUtil.setContentLength(request, contentLength);
 
-                final FullHttpResponse response = client.post(remoteAddress.address(), request);
+                final FullHttpResponse response = client.send(remoteAddress.address(), request);
                 try {
                     assertThat(response.status(), equalTo(expectedStatus));
                     if (expectedStatus.equals(HttpResponseStatus.CONTINUE)) {
                         final FullHttpRequest continuationRequest =
                             new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", Unpooled.EMPTY_BUFFER);
-                        final FullHttpResponse continuationResponse = client.post(remoteAddress.address(), continuationRequest);
+                        final FullHttpResponse continuationResponse = client.send(remoteAddress.address(), continuationRequest);
                         try {
                             assertThat(continuationResponse.status(), is(HttpResponseStatus.OK));
                             assertThat(
@@ -196,6 +199,65 @@ public class NioHttpServerTransportTests extends ESTestCase {
         }
     }
 
+    public void testCorsRequest() throws InterruptedException {
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+
+            @Override
+            public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
+                throw new AssertionError();
+            }
+
+            @Override
+            public void dispatchBadRequest(final RestRequest request,
+                                           final RestChannel channel,
+                                           final ThreadContext threadContext,
+                                           final Throwable cause) {
+                throw new AssertionError();
+            }
+
+        };
+
+        final Settings settings = Settings.builder()
+            .put(SETTING_CORS_ENABLED.getKey(), true)
+            .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "elastic.co").build();
+
+        try (NioHttpServerTransport transport = new NioHttpServerTransport(settings, networkService, bigArrays, pageRecycler,
+            threadPool, xContentRegistry(), dispatcher, new NioGroupFactory(settings, logger))) {
+            transport.start();
+            final TransportAddress remoteAddress = randomFrom(transport.boundAddress().boundAddresses());
+
+            // Test pre-flight request
+            try (NioHttpClient client = new NioHttpClient()) {
+                final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/");
+                request.headers().add(CorsHandler.ORIGIN, "elastic.co");
+                request.headers().add(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, "POST");
+
+                final FullHttpResponse response = client.send(remoteAddress.address(), request);
+                try {
+                    assertThat(response.status(), equalTo(HttpResponseStatus.OK));
+                    assertThat(response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("elastic.co"));
+                    assertThat(response.headers().get(CorsHandler.VARY), equalTo(CorsHandler.ORIGIN));
+                    assertTrue(response.headers().contains(CorsHandler.DATE));
+                } finally {
+                    response.release();
+                }
+            }
+
+            // Test short-circuited request
+            try (NioHttpClient client = new NioHttpClient()) {
+                final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
+                request.headers().add(CorsHandler.ORIGIN, "elastic2.co");
+
+                final FullHttpResponse response = client.send(remoteAddress.address(), request);
+                try {
+                    assertThat(response.status(), equalTo(HttpResponseStatus.FORBIDDEN));
+                } finally {
+                    response.release();
+                }
+            }
+        }
+    }
+
     public void testBadRequest() throws InterruptedException {
         final AtomicReference<Throwable> causeReference = new AtomicReference<>();
         final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@@ -241,7 +303,7 @@ public class NioHttpServerTransportTests extends ESTestCase {
                 final String url = "/" + new String(new byte[maxInitialLineLength], Charset.forName("UTF-8"));
                 final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, url);
 
-                final FullHttpResponse response = client.post(remoteAddress.address(), request);
+                final FullHttpResponse response = client.send(remoteAddress.address(), request);
                 try {
                     assertThat(response.status(), equalTo(HttpResponseStatus.BAD_REQUEST));
                     assertThat(

+ 5 - 0
server/src/main/java/org/elasticsearch/http/CorsHandler.java

@@ -68,6 +68,11 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE;
 public class CorsHandler {
 
     public static final String ANY_ORIGIN = "*";
+    public static final String ORIGIN = "origin";
+    public static final String DATE = "date";
+    public static final String VARY = "vary";
+    public static final String ACCESS_CONTROL_REQUEST_METHOD = "access-control-request-method";
+    public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin";
 
     private CorsHandler() {
     }