Browse Source

Fix handling of bad requests (#29249)

Today we have a few problems with how we handle bad requests:
 - handling requests with bad encoding
 - handling requests with invalid value for filter_path/pretty/human
 - handling requests with a garbage Content-Type header

There are two problems:
 - in every case, we give an empty response to the client
 - in most cases, we leak the byte buffer backing the request!

These problems are caused by a broader problem: poor handling preparing
the request for handling, or the channel to write to when the response
is ready. This commit addresses these issues by taking a unified
approach to all of them that ensures that:
 - we respond to the client with the exception that blew us up
 - we do not leak the byte buffer backing the request
Jason Tedor 7 years ago
parent
commit
4ef3de40bc

+ 37 - 1
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java

@@ -23,7 +23,6 @@ import io.netty.channel.Channel;
 import io.netty.handler.codec.http.FullHttpRequest;
 import io.netty.handler.codec.http.HttpHeaders;
 import io.netty.handler.codec.http.HttpMethod;
-
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -45,6 +44,15 @@ public class Netty4HttpRequest extends RestRequest {
     private final Channel channel;
     private final BytesReference content;
 
+    /**
+     * Construct a new request.
+     *
+     * @param xContentRegistry the content registry
+     * @param request          the underlying request
+     * @param channel          the channel for the request
+     * @throws BadParameterException      if the parameters can not be decoded
+     * @throws ContentTypeHeaderException if the Content-Type header can not be parsed
+     */
     Netty4HttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request, Channel channel) {
         super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers()));
         this.request = request;
@@ -56,6 +64,34 @@ public class Netty4HttpRequest extends RestRequest {
         }
     }
 
+    /**
+     * Construct a new request. In contrast to
+     * {@link Netty4HttpRequest#Netty4HttpRequest(NamedXContentRegistry, Map, String, FullHttpRequest, Channel)}, the URI is not decoded so
+     * this constructor will not throw a {@link BadParameterException}.
+     *
+     * @param xContentRegistry the content registry
+     * @param params           the parameters for the request
+     * @param uri              the path for the request
+     * @param request          the underlying request
+     * @param channel          the channel for the request
+     * @throws ContentTypeHeaderException if the Content-Type header can not be parsed
+     */
+    Netty4HttpRequest(
+            final NamedXContentRegistry xContentRegistry,
+            final Map<String, String> params,
+            final String uri,
+            final FullHttpRequest request,
+            final Channel channel) {
+        super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers()));
+        this.request = request;
+        this.channel = channel;
+        if (request.content().isReadable()) {
+            this.content = Netty4Utils.toBytesReference(request.content());
+        } else {
+            this.content = BytesArray.EMPTY;
+        }
+    }
+
     public FullHttpRequest request() {
         return this.request;
     }

+ 106 - 19
modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java

@@ -20,15 +20,21 @@
 package org.elasticsearch.http.netty4;
 
 import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
 import io.netty.channel.ChannelHandler;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.SimpleChannelInboundHandler;
 import io.netty.handler.codec.http.DefaultFullHttpRequest;
+import io.netty.handler.codec.http.DefaultHttpHeaders;
 import io.netty.handler.codec.http.FullHttpRequest;
+import io.netty.handler.codec.http.HttpHeaders;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.http.netty4.pipelining.HttpPipelinedRequest;
+import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.transport.netty4.Netty4Utils;
 
+import java.util.Collections;
+
 @ChannelHandler.Sharable
 class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<Object> {
 
@@ -56,32 +62,113 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<Object> {
             request = (FullHttpRequest) msg;
         }
 
-        final FullHttpRequest copy =
+        boolean success = false;
+        try {
+
+            final FullHttpRequest copy =
+                    new DefaultFullHttpRequest(
+                            request.protocolVersion(),
+                            request.method(),
+                            request.uri(),
+                            Unpooled.copiedBuffer(request.content()),
+                            request.headers(),
+                            request.trailingHeaders());
+
+            Exception badRequestCause = null;
+
+            /*
+             * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
+             * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
+             * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
+             * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
+             * underlying exception that caused us to treat the request as bad.
+             */
+            final Netty4HttpRequest httpRequest;
+            {
+                Netty4HttpRequest innerHttpRequest;
+                try {
+                    innerHttpRequest = new Netty4HttpRequest(serverTransport.xContentRegistry, copy, ctx.channel());
+                } catch (final RestRequest.ContentTypeHeaderException e) {
+                    badRequestCause = e;
+                    innerHttpRequest = requestWithoutContentTypeHeader(copy, ctx.channel(), badRequestCause);
+                } catch (final RestRequest.BadParameterException e) {
+                    badRequestCause = e;
+                    innerHttpRequest = requestWithoutParameters(copy, ctx.channel());
+                }
+                httpRequest = innerHttpRequest;
+            }
+
+            /*
+             * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid
+             * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an
+             * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these
+             * parameter values.
+             */
+            final Netty4HttpChannel channel;
+            {
+                Netty4HttpChannel innerChannel;
+                try {
+                    innerChannel =
+                            new Netty4HttpChannel(serverTransport, httpRequest, pipelinedRequest, detailedErrorsEnabled, threadContext);
+                } catch (final IllegalArgumentException e) {
+                    if (badRequestCause == null) {
+                        badRequestCause = e;
+                    } else {
+                        badRequestCause.addSuppressed(e);
+                    }
+                    final Netty4HttpRequest innerRequest =
+                            new Netty4HttpRequest(
+                                    serverTransport.xContentRegistry,
+                                    Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters
+                                    copy.uri(),
+                                    copy,
+                                    ctx.channel());
+                    innerChannel =
+                            new Netty4HttpChannel(serverTransport, innerRequest, pipelinedRequest, detailedErrorsEnabled, threadContext);
+                }
+                channel = innerChannel;
+            }
+
+            if (request.decoderResult().isFailure()) {
+                serverTransport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause());
+            } else if (badRequestCause != null) {
+                serverTransport.dispatchBadRequest(httpRequest, channel, badRequestCause);
+            } else {
+                serverTransport.dispatchRequest(httpRequest, channel);
+            }
+            success = true;
+        } finally {
+            // the request is otherwise released in case of dispatch
+            if (success == false && pipelinedRequest != null) {
+                pipelinedRequest.release();
+            }
+        }
+    }
+
+    private Netty4HttpRequest requestWithoutContentTypeHeader(
+            final FullHttpRequest request, final Channel channel, final Exception badRequestCause) {
+        final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
+        headersWithoutContentTypeHeader.add(request.headers());
+        headersWithoutContentTypeHeader.remove("Content-Type");
+        final FullHttpRequest requestWithoutContentTypeHeader =
                 new DefaultFullHttpRequest(
                         request.protocolVersion(),
                         request.method(),
                         request.uri(),
-                        Unpooled.copiedBuffer(request.content()),
-                        request.headers(),
-                        request.trailingHeaders());
-        final Netty4HttpRequest httpRequest;
+                        request.content(),
+                        headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again
+                        request.trailingHeaders()); // Content-Type can not be a trailing header
         try {
-            httpRequest = new Netty4HttpRequest(serverTransport.xContentRegistry, copy, ctx.channel());
-        } catch (Exception ex) {
-            if (pipelinedRequest != null) {
-                pipelinedRequest.release();
-            }
-            throw ex;
+            return new Netty4HttpRequest(serverTransport.xContentRegistry, requestWithoutContentTypeHeader, channel);
+        } catch (final RestRequest.BadParameterException e) {
+            badRequestCause.addSuppressed(e);
+            return requestWithoutParameters(requestWithoutContentTypeHeader, channel);
         }
-        final Netty4HttpChannel channel =
-                new Netty4HttpChannel(serverTransport, httpRequest, pipelinedRequest, detailedErrorsEnabled, threadContext);
+    }
 
-        if (request.decoderResult().isSuccess()) {
-            serverTransport.dispatchRequest(httpRequest, channel);
-        } else {
-            assert request.decoderResult().isFailure();
-            serverTransport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause());
-        }
+    private Netty4HttpRequest requestWithoutParameters(final FullHttpRequest request, final Channel channel) {
+        // remove all parameters as at least one is incorrectly encoded
+        return new Netty4HttpRequest(serverTransport.xContentRegistry, Collections.emptyMap(), request.uri(), request, channel);
     }
 
     @Override

+ 108 - 0
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4BadRequestTests.java

@@ -0,0 +1,108 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.http.netty4;
+
+import io.netty.handler.codec.http.FullHttpResponse;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.common.network.NetworkService;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.MockPageCacheRecycler;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.http.HttpServerTransport;
+import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
+import org.elasticsearch.rest.BytesRestResponse;
+import org.elasticsearch.rest.RestChannel;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Collection;
+import java.util.Collections;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class Netty4BadRequestTests extends ESTestCase {
+
+    private NetworkService networkService;
+    private MockBigArrays bigArrays;
+    private ThreadPool threadPool;
+
+    @Before
+    public void setup() throws Exception {
+        networkService = new NetworkService(Collections.emptyList());
+        bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
+        threadPool = new TestThreadPool("test");
+    }
+
+    @After
+    public void shutdown() throws Exception {
+        terminate(threadPool);
+    }
+
+    public void testBadParameterEncoding() throws Exception {
+        final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
+            @Override
+            public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
+                fail();
+            }
+
+            @Override
+            public void dispatchBadRequest(RestRequest request, RestChannel channel, ThreadContext threadContext, Throwable cause) {
+                try {
+                    final Exception e = cause instanceof Exception ? (Exception) cause : new ElasticsearchException(cause);
+                    channel.sendResponse(new BytesRestResponse(channel, RestStatus.BAD_REQUEST, e));
+                } catch (final IOException e) {
+                    throw new UncheckedIOException(e);
+                }
+            }
+        };
+
+        try (HttpServerTransport httpServerTransport =
+                     new Netty4HttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) {
+            httpServerTransport.start();
+            final TransportAddress transportAddress = randomFrom(httpServerTransport.boundAddress().boundAddresses());
+
+            try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) {
+                final Collection<FullHttpResponse> responses =
+                        nettyHttpClient.get(transportAddress.address(), "/_cluster/settings?pretty=%");
+                assertThat(responses, hasSize(1));
+                assertThat(responses.iterator().next().status().code(), equalTo(400));
+                final Collection<String> responseBodies = Netty4HttpClient.returnHttpResponseBodies(responses);
+                assertThat(responseBodies, hasSize(1));
+                assertThat(responseBodies.iterator().next(), containsString("\"type\":\"bad_parameter_exception\""));
+                assertThat(
+                        responseBodies.iterator().next(),
+                        containsString(
+                                "\"reason\":\"java.lang.IllegalArgumentException: unterminated escape sequence at end of string: %\""));
+            }
+        }
+    }
+
+}

+ 2 - 1
modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java

@@ -330,7 +330,8 @@ public class Netty4HttpChannelTests extends ESTestCase {
             }
             httpRequest.headers().add(HttpHeaderNames.HOST, host);
             final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel();
-            final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel);
+            final Netty4HttpRequest request =
+                    new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel);
 
             Netty4HttpChannel channel =
                     new Netty4HttpChannel(httpServerTransport, request, null, randomBoolean(), threadPool.getThreadContext());

+ 26 - 0
modules/transport-netty4/src/test/java/org/elasticsearch/rest/Netty4BadRequestIT.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.rest;
 
+import org.apache.http.message.BasicHeader;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.common.settings.Setting;
@@ -74,4 +75,29 @@ public class Netty4BadRequestIT extends ESRestTestCase {
         assertThat(e, hasToString(containsString("too_long_frame_exception")));
         assertThat(e, hasToString(matches("An HTTP line is larger than \\d+ bytes")));
     }
+
+    public void testInvalidParameterValue() throws IOException {
+        final ResponseException e = expectThrows(
+                ResponseException.class,
+                () -> client().performRequest("GET", "/_cluster/settings", Collections.singletonMap("pretty", "neither-true-nor-false")));
+        final Response response = e.getResponse();
+        assertThat(response.getStatusLine().getStatusCode(), equalTo(400));
+        final ObjectPath objectPath = ObjectPath.createFromResponse(response);
+        final Map<String, Object> map = objectPath.evaluate("error");
+        assertThat(map.get("type"), equalTo("illegal_argument_exception"));
+        assertThat(map.get("reason"), equalTo("Failed to parse value [neither-true-nor-false] as only [true] or [false] are allowed."));
+    }
+
+    public void testInvalidHeaderValue() throws IOException {
+        final BasicHeader header = new BasicHeader("Content-Type", "\t");
+        final ResponseException e =
+                expectThrows(ResponseException.class, () -> client().performRequest("GET", "/_cluster/settings", header));
+        final Response response = e.getResponse();
+        assertThat(response.getStatusLine().getStatusCode(), equalTo(400));
+        final ObjectPath objectPath = ObjectPath.createFromResponse(response);
+        final Map<String, Object> map = objectPath.evaluate("error");
+        assertThat(map.get("type"), equalTo("content_type_header_exception"));
+        assertThat(map.get("reason"), equalTo("java.lang.IllegalArgumentException: invalid Content-Type header []"));
+    }
+
 }

+ 7 - 0
server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java

@@ -48,6 +48,13 @@ public abstract class AbstractRestChannel implements RestChannel {
 
     private BytesStreamOutput bytesOut;
 
+    /**
+     * Construct a channel for handling the request.
+     *
+     * @param request               the request
+     * @param detailedErrorsEnabled if detailed errors should be reported to the channel
+     * @throws IllegalArgumentException if parsing the pretty or human parameters fails
+     */
     protected AbstractRestChannel(RestRequest request, boolean detailedErrorsEnabled) {
         this.request = request;
         this.detailedErrorsEnabled = detailedErrorsEnabled;

+ 68 - 32
server/src/main/java/org/elasticsearch/rest/RestRequest.java

@@ -64,49 +64,69 @@ public abstract class RestRequest implements ToXContent.Params {
     private final SetOnce<XContentType> xContentType = new SetOnce<>();
 
     /**
-     * Creates a new RestRequest
-     * @param xContentRegistry the xContentRegistry to use when parsing XContent
-     * @param uri the URI of the request that potentially contains request parameters
-     * @param headers a map of the headers. This map should implement a Case-Insensitive hashing for keys as HTTP header names are case
-     *                insensitive
+     * Creates a new REST request.
+     *
+     * @param xContentRegistry the content registry
+     * @param uri              the raw URI that will be parsed into the path and the parameters
+     * @param headers          a map of the header; this map should implement a case-insensitive lookup
+     * @throws BadParameterException      if the parameters can not be decoded
+     * @throws ContentTypeHeaderException if the Content-Type header can not be parsed
      */
-    public RestRequest(NamedXContentRegistry xContentRegistry, String uri, Map<String, List<String>> headers) {
-        this.xContentRegistry = xContentRegistry;
+    public RestRequest(final NamedXContentRegistry xContentRegistry, final String uri, final Map<String, List<String>> headers) {
+        this(xContentRegistry, params(uri), path(uri), headers);
+    }
+
+    private static Map<String, String> params(final String uri) {
         final Map<String, String> params = new HashMap<>();
-        int pathEndPos = uri.indexOf('?');
-        if (pathEndPos < 0) {
-            this.rawPath = uri;
-        } else {
-            this.rawPath = uri.substring(0, pathEndPos);
-            RestUtils.decodeQueryString(uri, pathEndPos + 1, params);
+        int index = uri.indexOf('?');
+        if (index >= 0) {
+            try {
+                RestUtils.decodeQueryString(uri, index + 1, params);
+            } catch (final IllegalArgumentException e) {
+                throw new BadParameterException(e);
+            }
         }
-        this.params = params;
-        this.headers = Collections.unmodifiableMap(headers);
-        final List<String> contentType = getAllHeaderValues("Content-Type");
-        final XContentType xContentType = parseContentType(contentType);
-        if (xContentType != null) {
-            this.xContentType.set(xContentType);
+        return params;
+    }
+
+    private static String path(final String uri) {
+        final int index = uri.indexOf('?');
+        if (index >= 0) {
+            return uri.substring(0, index);
+        } else {
+            return uri;
         }
     }
 
     /**
-     * Creates a new RestRequest
-     * @param xContentRegistry the xContentRegistry to use when parsing XContent
-     * @param params the parameters of the request
-     * @param path the path of the request. This should not contain request parameters
-     * @param headers a map of the headers. This map should implement a Case-Insensitive hashing for keys as HTTP header names are case
-     *                insensitive
+     * Creates a new REST request. In contrast to
+     * {@link RestRequest#RestRequest(NamedXContentRegistry, Map, String, Map)}, the path is not decoded so this constructor will not throw
+     * a {@link BadParameterException}.
+     *
+     * @param xContentRegistry the content registry
+     * @param params           the request parameters
+     * @param path             the raw path (which is not parsed)
+     * @param headers          a map of the header; this map should implement a case-insensitive lookup
+     * @throws ContentTypeHeaderException if the Content-Type header can not be parsed
      */
-    public RestRequest(NamedXContentRegistry xContentRegistry, Map<String, String> params, String path, Map<String, List<String>> headers) {
+    public RestRequest(
+            final NamedXContentRegistry xContentRegistry,
+            final Map<String, String> params,
+            final String path,
+            final Map<String, List<String>> headers) {
+        final XContentType xContentType;
+        try {
+            xContentType = parseContentType(headers.get("Content-Type"));
+        } catch (final IllegalArgumentException e) {
+            throw new ContentTypeHeaderException(e);
+        }
+        if (xContentType != null) {
+            this.xContentType.set(xContentType);
+        }
         this.xContentRegistry = xContentRegistry;
         this.params = params;
         this.rawPath = path;
         this.headers = Collections.unmodifiableMap(headers);
-        final List<String> contentType = getAllHeaderValues("Content-Type");
-        final XContentType xContentType = parseContentType(contentType);
-        if (xContentType != null) {
-            this.xContentType.set(xContentType);
-        }
     }
 
     public enum Method {
@@ -423,7 +443,7 @@ public abstract class RestRequest implements ToXContent.Params {
      * Parses the given content type string for the media type. This method currently ignores parameters.
      */
     // TODO stop ignoring parameters such as charset...
-    private static XContentType parseContentType(List<String> header) {
+    public static XContentType parseContentType(List<String> header) {
         if (header == null || header.isEmpty()) {
             return null;
         } else if (header.size() > 1) {
@@ -444,4 +464,20 @@ public abstract class RestRequest implements ToXContent.Params {
         throw new IllegalArgumentException("empty Content-Type header");
     }
 
+    public static class ContentTypeHeaderException extends RuntimeException {
+
+        ContentTypeHeaderException(final IllegalArgumentException cause) {
+            super(cause);
+        }
+
+    }
+
+    public static class BadParameterException extends RuntimeException {
+
+        BadParameterException(final IllegalArgumentException cause) {
+            super(cause);
+        }
+
+    }
+
 }

+ 22 - 21
server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java

@@ -165,27 +165,28 @@ public class BytesRestResponseTests extends ESTestCase {
 
     public void testResponseWhenPathContainsEncodingError() throws IOException {
         final String path = "%a";
-        final RestRequest request = new RestRequest(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, Collections.emptyMap()) {
-            @Override
-            public Method method() {
-                return null;
-            }
-
-            @Override
-            public String uri() {
-                return null;
-            }
-
-            @Override
-            public boolean hasContent() {
-                return false;
-            }
-
-            @Override
-            public BytesReference content() {
-                return null;
-            }
-        };
+        final RestRequest request =
+                new RestRequest(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, Collections.emptyMap()) {
+                    @Override
+                    public Method method() {
+                        return null;
+                    }
+
+                    @Override
+                    public String uri() {
+                        return null;
+                    }
+
+                    @Override
+                    public boolean hasContent() {
+                        return false;
+                    }
+
+                    @Override
+                    public BytesReference content() {
+                        return null;
+                    }
+                };
         final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestUtils.decodeComponent(request.rawPath()));
         final RestChannel channel = new DetailedExceptionRestChannel(request);
         // if we try to decode the path, this will throw an IllegalArgumentException again

+ 3 - 2
server/src/test/java/org/elasticsearch/rest/RestControllerTests.java

@@ -367,9 +367,10 @@ public class RestControllerTests extends ESTestCase {
     public void testDispatchWithContentStream() {
         final String mimeType = randomFrom("application/json", "application/smile");
         String content = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt());
+        final List<String> contentTypeHeader = Collections.singletonList(mimeType);
         FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
-            .withContent(new BytesArray(content), null).withPath("/foo")
-            .withHeaders(Collections.singletonMap("Content-Type", Collections.singletonList(mimeType))).build();
+            .withContent(new BytesArray(content), RestRequest.parseContentType(contentTypeHeader)).withPath("/foo")
+            .withHeaders(Collections.singletonMap("Content-Type", contentTypeHeader)).build();
         AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.OK);
         restController.registerHandler(RestRequest.Method.GET, "/foo", new RestHandler() {
             @Override

+ 17 - 6
server/src/test/java/org/elasticsearch/rest/RestRequestTests.java

@@ -38,6 +38,8 @@ import java.util.concurrent.atomic.AtomicReference;
 
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonMap;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
 
 public class RestRequestTests extends ESTestCase {
     public void testContentParser() throws IOException {
@@ -130,9 +132,15 @@ public class RestRequestTests extends ESTestCase {
 
     public void testMalformedContentTypeHeader() {
         final String type = randomFrom("text", "text/:ain; charset=utf-8", "text/plain\";charset=utf-8", ":", "/", "t:/plain");
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new ContentRestRequest("", Collections.emptyMap(),
-            Collections.singletonMap("Content-Type", Collections.singletonList(type))));
-        assertEquals("invalid Content-Type header [" + type + "]", e.getMessage());
+        final RestRequest.ContentTypeHeaderException e = expectThrows(
+                RestRequest.ContentTypeHeaderException.class,
+                () -> {
+                    final Map<String, List<String>> headers = Collections.singletonMap("Content-Type", Collections.singletonList(type));
+                    new ContentRestRequest("", Collections.emptyMap(), headers);
+                });
+        assertNotNull(e.getCause());
+        assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
+        assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: invalid Content-Type header [" + type + "]"));
     }
 
     public void testNoContentTypeHeader() {
@@ -142,9 +150,12 @@ public class RestRequestTests extends ESTestCase {
 
     public void testMultipleContentTypeHeaders() {
         List<String> headers = new ArrayList<>(randomUnique(() -> randomAlphaOfLengthBetween(1, 16), randomIntBetween(2, 10)));
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new ContentRestRequest("", Collections.emptyMap(),
-            Collections.singletonMap("Content-Type", headers)));
-        assertEquals("only one Content-Type header should be provided", e.getMessage());
+        final RestRequest.ContentTypeHeaderException e = expectThrows(
+                RestRequest.ContentTypeHeaderException.class,
+                () -> new ContentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers)));
+        assertNotNull(e.getCause());
+        assertThat(e.getCause(), instanceOf((IllegalArgumentException.class)));
+        assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: only one Content-Type header should be provided"));
     }
 
     public void testRequiredContent() {

+ 2 - 2
test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java

@@ -40,8 +40,8 @@ public class FakeRestRequest extends RestRequest {
         this(NamedXContentRegistry.EMPTY, new HashMap<>(), new HashMap<>(), null, Method.GET, "/", null);
     }
 
-    private FakeRestRequest(NamedXContentRegistry xContentRegistry, Map<String, List<String>> headers, Map<String, String> params,
-                            BytesReference content, Method method, String path, SocketAddress remoteAddress) {
+    private FakeRestRequest(NamedXContentRegistry xContentRegistry, Map<String, List<String>> headers,
+                            Map<String, String> params, BytesReference content, Method method, String path, SocketAddress remoteAddress) {
         super(xContentRegistry, params, path, headers);
         this.content = content;
         this.method = method;