|  | @@ -19,6 +19,8 @@ import org.elasticsearch.common.io.stream.StreamInput;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.network.HandlingTimeTracker;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.util.concurrent.ThreadContext;
 | 
	
		
			
				|  |  | +import org.elasticsearch.core.Releasable;
 | 
	
		
			
				|  |  | +import org.elasticsearch.core.Releasables;
 | 
	
		
			
				|  |  |  import org.elasticsearch.core.TimeValue;
 | 
	
		
			
				|  |  |  import org.elasticsearch.threadpool.ThreadPool;
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -133,34 +135,17 @@ public class InboundHandler {
 | 
	
		
			
				|  |  |                  }
 | 
	
		
			
				|  |  |                  // ignore if its null, the service logs it
 | 
	
		
			
				|  |  |                  if (responseHandler != null) {
 | 
	
		
			
				|  |  | -                    final StreamInput streamInput;
 | 
	
		
			
				|  |  |                      if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) {
 | 
	
		
			
				|  |  | -                        streamInput = namedWriteableStream(message.openOrGetStreamInput());
 | 
	
		
			
				|  |  | +                        final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput());
 | 
	
		
			
				|  |  |                          assertRemoteVersion(streamInput, header.getVersion());
 | 
	
		
			
				|  |  |                          if (header.isError()) {
 | 
	
		
			
				|  |  | -                            handlerResponseError(streamInput, responseHandler);
 | 
	
		
			
				|  |  | +                            handlerResponseError(streamInput, message, responseHandler);
 | 
	
		
			
				|  |  |                          } else {
 | 
	
		
			
				|  |  | -                            handleResponse(remoteAddress, streamInput, responseHandler);
 | 
	
		
			
				|  |  | -                        }
 | 
	
		
			
				|  |  | -                        // Check the entire message has been read
 | 
	
		
			
				|  |  | -                        final int nextByte = streamInput.read();
 | 
	
		
			
				|  |  | -                        // calling read() is useful to make sure the message is fully read, even if there is an EOS marker
 | 
	
		
			
				|  |  | -                        if (nextByte != -1) {
 | 
	
		
			
				|  |  | -                            final IllegalStateException exception = new IllegalStateException(
 | 
	
		
			
				|  |  | -                                "Message not fully read (response) for requestId ["
 | 
	
		
			
				|  |  | -                                    + requestId
 | 
	
		
			
				|  |  | -                                    + "], handler ["
 | 
	
		
			
				|  |  | -                                    + responseHandler
 | 
	
		
			
				|  |  | -                                    + "], error ["
 | 
	
		
			
				|  |  | -                                    + header.isError()
 | 
	
		
			
				|  |  | -                                    + "]; resetting"
 | 
	
		
			
				|  |  | -                            );
 | 
	
		
			
				|  |  | -                            assert ignoreDeserializationErrors : exception;
 | 
	
		
			
				|  |  | -                            throw exception;
 | 
	
		
			
				|  |  | +                            handleResponse(remoteAddress, streamInput, responseHandler, message);
 | 
	
		
			
				|  |  |                          }
 | 
	
		
			
				|  |  |                      } else {
 | 
	
		
			
				|  |  |                          assert header.isError() == false;
 | 
	
		
			
				|  |  | -                        handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler);
 | 
	
		
			
				|  |  | +                        handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, message);
 | 
	
		
			
				|  |  |                      }
 | 
	
		
			
				|  |  |                  }
 | 
	
		
			
				|  |  |              }
 | 
	
	
		
			
				|  | @@ -189,6 +174,26 @@ public class InboundHandler {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    private void verifyResponseReadFully(Header header, TransportResponseHandler<?> responseHandler, StreamInput streamInput)
 | 
	
		
			
				|  |  | +        throws IOException {
 | 
	
		
			
				|  |  | +        // Check the entire message has been read
 | 
	
		
			
				|  |  | +        final int nextByte = streamInput.read();
 | 
	
		
			
				|  |  | +        // calling read() is useful to make sure the message is fully read, even if there is an EOS marker
 | 
	
		
			
				|  |  | +        if (nextByte != -1) {
 | 
	
		
			
				|  |  | +            final IllegalStateException exception = new IllegalStateException(
 | 
	
		
			
				|  |  | +                "Message not fully read (response) for requestId ["
 | 
	
		
			
				|  |  | +                    + header.getRequestId()
 | 
	
		
			
				|  |  | +                    + "], handler ["
 | 
	
		
			
				|  |  | +                    + responseHandler
 | 
	
		
			
				|  |  | +                    + "], error ["
 | 
	
		
			
				|  |  | +                    + header.isError()
 | 
	
		
			
				|  |  | +                    + "]; resetting"
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +            assert ignoreDeserializationErrors : exception;
 | 
	
		
			
				|  |  | +            throw exception;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      private <T extends TransportRequest> void handleRequest(TcpChannel channel, Header header, InboundMessage message) throws IOException {
 | 
	
		
			
				|  |  |          final String action = header.getActionName();
 | 
	
		
			
				|  |  |          final long requestId = header.getRequestId();
 | 
	
	
		
			
				|  | @@ -335,10 +340,49 @@ public class InboundHandler {
 | 
	
		
			
				|  |  |      private <T extends TransportResponse> void handleResponse(
 | 
	
		
			
				|  |  |          InetSocketAddress remoteAddress,
 | 
	
		
			
				|  |  |          final StreamInput stream,
 | 
	
		
			
				|  |  | -        final TransportResponseHandler<T> handler
 | 
	
		
			
				|  |  | +        final TransportResponseHandler<T> handler,
 | 
	
		
			
				|  |  | +        final InboundMessage inboundMessage
 | 
	
		
			
				|  |  | +    ) {
 | 
	
		
			
				|  |  | +        final String executor = handler.executor();
 | 
	
		
			
				|  |  | +        if (ThreadPool.Names.SAME.equals(executor)) {
 | 
	
		
			
				|  |  | +            // no need to provide a buffer release here, we never escape the buffer when handling directly
 | 
	
		
			
				|  |  | +            doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
 | 
	
		
			
				|  |  | +        } else {
 | 
	
		
			
				|  |  | +            inboundMessage.incRef();
 | 
	
		
			
				|  |  | +            // release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
 | 
	
		
			
				|  |  | +            final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
 | 
	
		
			
				|  |  | +            threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
 | 
	
		
			
				|  |  | +                @Override
 | 
	
		
			
				|  |  | +                protected void doRun() {
 | 
	
		
			
				|  |  | +                    doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), releaseBuffer);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                @Override
 | 
	
		
			
				|  |  | +                public void onAfter() {
 | 
	
		
			
				|  |  | +                    Releasables.closeExpectNoException(releaseBuffer);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            });
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /**
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * @param handler response handler
 | 
	
		
			
				|  |  | +     * @param remoteAddress remote address that the message was sent from
 | 
	
		
			
				|  |  | +     * @param stream bytes stream for reading the message
 | 
	
		
			
				|  |  | +     * @param header message header
 | 
	
		
			
				|  |  | +     * @param releaseResponseBuffer releasable that will be released once the message has been read from the {@code stream}
 | 
	
		
			
				|  |  | +     * @param <T> response message type
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    private <T extends TransportResponse> void doHandleResponse(
 | 
	
		
			
				|  |  | +        TransportResponseHandler<T> handler,
 | 
	
		
			
				|  |  | +        InetSocketAddress remoteAddress,
 | 
	
		
			
				|  |  | +        final StreamInput stream,
 | 
	
		
			
				|  |  | +        final Header header,
 | 
	
		
			
				|  |  | +        Releasable releaseResponseBuffer
 | 
	
		
			
				|  |  |      ) {
 | 
	
		
			
				|  |  |          final T response;
 | 
	
		
			
				|  |  | -        try {
 | 
	
		
			
				|  |  | +        try (releaseResponseBuffer) {
 | 
	
		
			
				|  |  |              response = handler.read(stream);
 | 
	
		
			
				|  |  |              response.remoteAddress(remoteAddress);
 | 
	
		
			
				|  |  |          } catch (Exception e) {
 | 
	
	
		
			
				|  | @@ -348,24 +392,11 @@ public class InboundHandler {
 | 
	
		
			
				|  |  |              );
 | 
	
		
			
				|  |  |              logger.warn(() -> "Failed to deserialize response from [" + remoteAddress + "]", serializationException);
 | 
	
		
			
				|  |  |              assert ignoreDeserializationErrors : e;
 | 
	
		
			
				|  |  | -            handleException(handler, serializationException);
 | 
	
		
			
				|  |  | +            doHandleException(handler, serializationException);
 | 
	
		
			
				|  |  |              return;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -        final String executor = handler.executor();
 | 
	
		
			
				|  |  | -        if (ThreadPool.Names.SAME.equals(executor)) {
 | 
	
		
			
				|  |  | -            doHandleResponse(handler, response);
 | 
	
		
			
				|  |  | -        } else {
 | 
	
		
			
				|  |  | -            threadPool.executor(executor).execute(new ForkingResponseHandlerRunnable(handler, null) {
 | 
	
		
			
				|  |  | -                @Override
 | 
	
		
			
				|  |  | -                protected void doRun() {
 | 
	
		
			
				|  |  | -                    doHandleResponse(handler, response);
 | 
	
		
			
				|  |  | -                }
 | 
	
		
			
				|  |  | -            });
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    private static <T extends TransportResponse> void doHandleResponse(TransportResponseHandler<T> handler, T response) {
 | 
	
		
			
				|  |  |          try {
 | 
	
		
			
				|  |  | +            verifyResponseReadFully(header, handler, stream);
 | 
	
		
			
				|  |  |              handler.handleResponse(response);
 | 
	
		
			
				|  |  |          } catch (Exception e) {
 | 
	
		
			
				|  |  |              doHandleException(handler, new ResponseHandlerFailureTransportException(e));
 | 
	
	
		
			
				|  | @@ -374,10 +405,11 @@ public class InboundHandler {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private void handlerResponseError(StreamInput stream, final TransportResponseHandler<?> handler) {
 | 
	
		
			
				|  |  | +    private void handlerResponseError(StreamInput stream, InboundMessage message, final TransportResponseHandler<?> handler) {
 | 
	
		
			
				|  |  |          Exception error;
 | 
	
		
			
				|  |  |          try {
 | 
	
		
			
				|  |  |              error = stream.readException();
 | 
	
		
			
				|  |  | +            verifyResponseReadFully(message.getHeader(), handler, stream);
 | 
	
		
			
				|  |  |          } catch (Exception e) {
 | 
	
		
			
				|  |  |              error = new TransportSerializationException(
 | 
	
		
			
				|  |  |                  "Failed to deserialize exception response from stream for handler [" + handler + "]",
 |