Browse Source

Move outbound message handling to OutboundHandler (#40336)

Currently there are some components of message serializer and sending
that still occur in TcpTransport. This commit makes it possible to
send a message without the TcpTransport by moving all of the remaining
application logic to the OutboundHandler. Additionally, it adds unit
tests to ensure that this logic works as expected.
Tim Brooks 6 năm trước cách đây
mục cha
commit
9e11dfc7f0

+ 1 - 1
modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java

@@ -111,7 +111,7 @@ public class Netty4TransportIT extends ESNetty4IntegTestCase {
         }
 
         @Override
-        protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
+        protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException {
             super.handleRequest(channel, request, messageLengthBytes);
             channelProfileName = TransportSettings.DEFAULT_PROFILE;
         }

+ 1 - 1
plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java

@@ -113,7 +113,7 @@ public class NioTransportIT extends NioIntegTestCase {
         }
 
         @Override
-        protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
+        protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException {
             super.handleRequest(channel, request, messageLengthBytes);
             channelProfileName = TransportSettings.DEFAULT_PROFILE;
         }

+ 7 - 7
server/src/main/java/org/elasticsearch/transport/InboundMessage.java

@@ -101,9 +101,9 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
                 if (TransportStatus.isRequest(status)) {
                     final Set<String> features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(streamInput.readStringArray())));
                     final String action = streamInput.readString();
-                    message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput);
+                    message = new Request(threadContext, remoteVersion, status, requestId, action, features, streamInput);
                 } else {
-                    message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput);
+                    message = new Response(threadContext, remoteVersion, status, requestId, streamInput);
                 }
                 success = true;
                 return message;
@@ -133,13 +133,13 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
         }
     }
 
-    public static class RequestMessage extends InboundMessage {
+    public static class Request extends InboundMessage {
 
         private final String actionName;
         private final Set<String> features;
 
-        RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set<String> features,
-                       StreamInput streamInput) {
+        Request(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set<String> features,
+                StreamInput streamInput) {
             super(threadContext, version, status, requestId, streamInput);
             this.actionName = actionName;
             this.features = features;
@@ -154,9 +154,9 @@ public abstract class InboundMessage extends NetworkMessage implements Closeable
         }
     }
 
-    public static class ResponseMessage extends InboundMessage {
+    public static class Response extends InboundMessage {
 
-        ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
+        Response(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
             super(threadContext, version, status, requestId, streamInput);
         }
     }

+ 71 - 10
server/src/main/java/org/elasticsearch/transport/OutboundHandler.java

@@ -22,8 +22,10 @@ package org.elasticsearch.transport;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.NotifyOnceListener;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.CheckedSupplier;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
@@ -32,49 +34,100 @@ import org.elasticsearch.common.lease.Releasables;
 import org.elasticsearch.common.metrics.MeanMetric;
 import org.elasticsearch.common.network.CloseableChannel;
 import org.elasticsearch.common.transport.NetworkExceptionHelper;
+import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.io.IOException;
+import java.util.Set;
 
 final class OutboundHandler {
 
     private static final Logger logger = LogManager.getLogger(OutboundHandler.class);
 
     private final MeanMetric transmittedBytesMetric = new MeanMetric();
+
+    private final String nodeName;
+    private final Version version;
+    private final String[] features;
     private final ThreadPool threadPool;
     private final BigArrays bigArrays;
     private final TransportLogger transportLogger;
+    private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;
 
-    OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) {
+    OutboundHandler(String nodeName, Version version, String[] features, ThreadPool threadPool, BigArrays bigArrays,
+                    TransportLogger transportLogger) {
+        this.nodeName = nodeName;
+        this.version = version;
+        this.features = features;
         this.threadPool = threadPool;
         this.bigArrays = bigArrays;
         this.transportLogger = transportLogger;
     }
 
     void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> listener) {
-        channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
         SendContext sendContext = new SendContext(channel, () -> bytes, listener);
         try {
-            internalSendMessage(channel, sendContext);
+            internalSend(channel, sendContext);
         } catch (IOException e) {
             // This should not happen as the bytes are already serialized
             throw new AssertionError(e);
         }
     }
 
-    void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
-        channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
-        MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
-        SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
-        internalSendMessage(channel, sendContext);
+    /**
+     * Sends the request to the given channel. This method should be used to send {@link TransportRequest}
+     * objects back to the caller.
+     */
+    void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
+                     final TransportRequest request, final TransportRequestOptions options, final Version channelVersion,
+                     final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException {
+        Version version = Version.min(this.version, channelVersion);
+        OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action,
+            requestId, isHandshake, compressRequest);
+        ActionListener<Void> listener = ActionListener.wrap(() ->
+            messageListener.onRequestSent(node, requestId, action, request, options));
+        sendMessage(channel, message, listener);
+    }
+
+    /**
+     * Sends the response to the given channel. This method should be used to send {@link TransportResponse}
+     * objects back to the caller.
+     *
+     * @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses
+     */
+    void sendResponse(final Version nodeVersion, final Set<String> features, final TcpChannel channel,
+                      final long requestId, final String action, final TransportResponse response,
+                      final boolean compress, final boolean isHandshake) throws IOException {
+        Version version = Version.min(this.version, nodeVersion);
+        OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version,
+            requestId, isHandshake, compress);
+        ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
+        sendMessage(channel, message, listener);
     }
 
     /**
-     * sends a message to the given channel, using the given callbacks.
+     * Sends back an error response to the caller via the given channel
      */
-    private void internalSendMessage(TcpChannel channel,  SendContext sendContext) throws IOException {
+    void sendErrorResponse(final Version nodeVersion, final Set<String> features, final TcpChannel channel, final long requestId,
+                           final String action, final Exception error) throws IOException {
+        Version version = Version.min(this.version, nodeVersion);
+        TransportAddress address = new TransportAddress(channel.getLocalAddress());
+        RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
+        OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId,
+            false, false);
+        ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
+        sendMessage(channel, message, listener);
+    }
+
+    private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
+        MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
+        SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
+        internalSend(channel, sendContext);
+    }
+
+    private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException {
         channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
         BytesReference reference = sendContext.get();
         try {
@@ -91,6 +144,14 @@ final class OutboundHandler {
         return transmittedBytesMetric;
     }
 
+    void setMessageListener(TransportMessageListener listener) {
+        if (messageListener == TransportMessageListener.NOOP_LISTENER) {
+            messageListener = listener;
+        } else {
+            throw new IllegalStateException("Cannot set message listener twice");
+        }
+    }
+
     private static class MessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {
 
         private final OutboundMessage message;

+ 24 - 103
server/src/main/java/org/elasticsearch/transport/TcpTransport.java

@@ -106,19 +106,15 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
     private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9);
     private static final BytesReference EMPTY_BYTES_REFERENCE = new BytesArray(new byte[0]);
 
-    private final String[] features;
-
     protected final Settings settings;
     private final CircuitBreakerService circuitBreakerService;
-    private final Version version;
     protected final ThreadPool threadPool;
     protected final BigArrays bigArrays;
     protected final PageCacheRecycler pageCacheRecycler;
     protected final NetworkService networkService;
     protected final Set<ProfileSettings> profileSettings;
 
-    private static final TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {};
-    private volatile TransportMessageListener messageListener = NOOP_LISTENER;
+    private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;
 
     private final ConcurrentMap<String, BoundTransportAddress> profileBoundAddresses = newConcurrentMap();
     private final Map<String, List<TcpServerChannel>> serverChannels = newConcurrentMap();
@@ -137,34 +133,23 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
     private final TransportKeepAlive keepAlive;
     private final InboundMessage.Reader reader;
     private final OutboundHandler outboundHandler;
-    private final String nodeName;
 
     public TcpTransport(Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler,
                         CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry,
                         NetworkService networkService) {
         this.settings = settings;
         this.profileSettings = getProfileSettings(settings);
-        this.version = version;
         this.threadPool = threadPool;
         this.bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS);
         this.pageCacheRecycler = pageCacheRecycler;
         this.circuitBreakerService = circuitBreakerService;
         this.networkService = networkService;
         this.transportLogger = new TransportLogger();
-        this.outboundHandler = new OutboundHandler(threadPool, bigArrays, transportLogger);
-        this.handshaker = new TransportHandshaker(version, threadPool,
-            (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId,
-                TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
-                TransportRequestOptions.EMPTY, v, false, true),
-            (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId,
-                TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true));
-        this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
-        this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext());
-        this.nodeName = Node.NODE_NAME_SETTING.get(settings);
-
+        String nodeName = Node.NODE_NAME_SETTING.get(settings);
         final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings);
+        String[] features;
         if (defaultFeatures == null) {
-            this.features = new String[0];
+            features = new String[0];
         } else {
             defaultFeatures.names().forEach(key -> {
                 if (Booleans.parseBoolean(defaultFeatures.get(key)) == false) {
@@ -172,8 +157,18 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                 }
             });
             // use a sorted set to present the features in a consistent order
-            this.features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]);
+            features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]);
         }
+        this.outboundHandler = new OutboundHandler(nodeName, version, features, threadPool, bigArrays, transportLogger);
+
+        this.handshaker = new TransportHandshaker(version, threadPool,
+            (node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
+                TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
+                TransportRequestOptions.EMPTY, v, false, true),
+            (v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId,
+                TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
+        this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
+        this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext());
     }
 
     @Override
@@ -182,8 +177,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
 
     @Override
     public synchronized void setMessageListener(TransportMessageListener listener) {
-        if (messageListener == NOOP_LISTENER) {
+        if (messageListener == TransportMessageListener.NOOP_LISTENER) {
             messageListener = listener;
+            outboundHandler.setMessageListener(listener);
         } else {
             throw new IllegalStateException("Cannot set message listener twice");
         }
@@ -267,7 +263,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                 throw new NodeNotConnectedException(node, "connection already closed");
             }
             TcpChannel channel = channel(options.type());
-            sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress);
+            outboundHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), compress, false);
         }
     }
 
@@ -661,81 +657,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
      */
     protected abstract void stopInternal();
 
-    private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
-                                      final TransportRequest request, TransportRequestOptions options, Version channelVersion,
-                                      boolean compressRequest) throws IOException, TransportException {
-        sendRequestToChannel(node, channel, requestId, action, request, options, channelVersion, compressRequest, false);
-    }
-
-    private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
-                                      final TransportRequest request, TransportRequestOptions options, Version channelVersion,
-                                      boolean compressRequest, boolean isHandshake) throws IOException, TransportException {
-        Version version = Version.min(this.version, channelVersion);
-        OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action,
-            requestId, isHandshake, compressRequest);
-        ActionListener<Void> listener = ActionListener.wrap(() ->
-            messageListener.onRequestSent(node, requestId, action, request, options));
-        outboundHandler.sendMessage(channel, message, listener);
-    }
-
-    /**
-     * Sends back an error response to the caller via the given channel
-     *
-     * @param nodeVersion the caller node version
-     * @param features    the caller features
-     * @param channel     the channel to send the response to
-     * @param error       the error to return
-     * @param requestId   the request ID this response replies to
-     * @param action      the action this response replies to
-     */
-    public void sendErrorResponse(
-        final Version nodeVersion,
-        final Set<String> features,
-        final TcpChannel channel,
-        final Exception error,
-        final long requestId,
-        final String action) throws IOException {
-        Version version = Version.min(this.version, nodeVersion);
-        TransportAddress address = new TransportAddress(channel.getLocalAddress());
-        RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
-        OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId,
-            false, false);
-        ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
-        outboundHandler.sendMessage(channel, message, listener);
-    }
-
-    /**
-     * Sends the response to the given channel. This method should be used to send {@link TransportResponse} objects back to the caller.
-     *
-     * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending back errors to the caller
-     */
-    public void sendResponse(
-        final Version nodeVersion,
-        final Set<String> features,
-        final TcpChannel channel,
-        final TransportResponse response,
-        final long requestId,
-        final String action,
-        final boolean compress) throws IOException {
-        sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false);
-    }
-
-    private void sendResponse(
-        final Version nodeVersion,
-        final Set<String> features,
-        final TcpChannel channel,
-        final TransportResponse response,
-        final long requestId,
-        final String action,
-        boolean compress,
-        boolean isHandshake) throws IOException {
-        Version version = Version.min(this.version, nodeVersion);
-        OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version,
-            requestId, isHandshake, compress);
-        ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
-        outboundHandler.sendMessage(channel, message, listener);
-    }
-
     /**
      * Handles inbound message that has been decoded.
      *
@@ -913,7 +834,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
             message.getStoredContext().restore();
             threadContext.putTransient("_remote_address", remoteAddress);
             if (message.isRequest()) {
-                handleRequest(channel, (InboundMessage.RequestMessage) message, reference.length());
+                handleRequest(channel, (InboundMessage.Request) message, reference.length());
             } else {
                 final TransportResponseHandler<?> handler;
                 long requestId = message.getRequestId();
@@ -999,7 +920,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         });
     }
 
-    protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage message, int messageLengthBytes) throws IOException {
+    protected void handleRequest(TcpChannel channel, InboundMessage.Request message, int messageLengthBytes) throws IOException {
         final Set<String> features = message.getFeatures();
         final String profileName = channel.getProfile();
         final String action = message.getActionName();
@@ -1021,8 +942,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                 } else {
                     getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes);
                 }
-                transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features, profileName,
-                    messageLengthBytes, message.isCompress());
+                transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
+                    circuitBreakerService, messageLengthBytes, message.isCompress());
                 final TransportRequest request = reg.newRequest(stream);
                 request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
                 // in case we throw an exception, i.e. when the limit is hit, we don't want to verify
@@ -1032,8 +953,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         } catch (Exception e) {
             // the circuit breaker tripped
             if (transportChannel == null) {
-                transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features,
-                    profileName, 0, message.isCompress());
+                transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features,
+                    circuitBreakerService, 0, message.isCompress());
             }
             try {
                 transportChannel.sendResponse(e);

+ 15 - 13
server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java

@@ -20,6 +20,8 @@
 package org.elasticsearch.transport;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
 
 import java.io.IOException;
 import java.util.Set;
@@ -28,38 +30,38 @@ import java.util.concurrent.atomic.AtomicBoolean;
 public final class TcpTransportChannel implements TransportChannel {
 
     private final AtomicBoolean released = new AtomicBoolean();
-    private final TcpTransport transport;
-    private final Version version;
-    private final Set<String> features;
+    private final OutboundHandler outboundHandler;
+    private final TcpChannel channel;
     private final String action;
     private final long requestId;
-    private final String profileName;
+    private final Version version;
+    private final Set<String> features;
+    private final CircuitBreakerService breakerService;
     private final long reservedBytes;
-    private final TcpChannel channel;
     private final boolean compressResponse;
 
-    TcpTransportChannel(TcpTransport transport, TcpChannel channel, String action, long requestId, Version version, Set<String> features,
-                        String profileName, long reservedBytes, boolean compressResponse) {
+    TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
+                        Set<String> features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) {
         this.version = version;
         this.features = features;
         this.channel = channel;
-        this.transport = transport;
+        this.outboundHandler = outboundHandler;
         this.action = action;
         this.requestId = requestId;
-        this.profileName = profileName;
+        this.breakerService = breakerService;
         this.reservedBytes = reservedBytes;
         this.compressResponse = compressResponse;
     }
 
     @Override
     public String getProfileName() {
-        return profileName;
+        return channel.getProfile();
     }
 
     @Override
     public void sendResponse(TransportResponse response) throws IOException {
         try {
-            transport.sendResponse(version, features, channel, response, requestId, action, compressResponse);
+            outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false);
         } finally {
             release(false);
         }
@@ -68,7 +70,7 @@ public final class TcpTransportChannel implements TransportChannel {
     @Override
     public void sendResponse(Exception exception) throws IOException {
         try {
-            transport.sendErrorResponse(version, features, channel, exception, requestId, action);
+            outboundHandler.sendErrorResponse(version, features, channel, requestId, action, exception);
         } finally {
             release(true);
         }
@@ -79,7 +81,7 @@ public final class TcpTransportChannel implements TransportChannel {
     private void release(boolean isExceptionResponse) {
         if (released.compareAndSet(false, true)) {
             assert (releaseBy = new Exception()) != null; // easier to debug if it's already closed
-            transport.getInFlightRequestBreaker().addWithoutBreaking(-reservedBytes);
+            breakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS).addWithoutBreaking(-reservedBytes);
         } else if (isExceptionResponse == false) {
             // only fail if we are not sending an error - we might send the error triggered by the previous
             // sendResponse call

+ 0 - 1
server/src/main/java/org/elasticsearch/transport/Transport.java

@@ -30,7 +30,6 @@ import org.elasticsearch.common.transport.BoundTransportAddress;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.util.concurrent.ConcurrentMapLong;
-
 import java.io.Closeable;
 import java.io.IOException;
 import java.net.UnknownHostException;

+ 2 - 0
server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java

@@ -22,6 +22,8 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 
 public interface TransportMessageListener {
 
+    TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {};
+
     /**
      * Called once a request is received
      * @param requestId the internal request ID

+ 3 - 3
server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java

@@ -63,7 +63,7 @@ public class InboundMessageTests extends ESTestCase {
 
         InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
         BytesReference sliced = reference.slice(6, reference.length() - 6);
-        InboundMessage.RequestMessage inboundMessage = (InboundMessage.RequestMessage) reader.deserialize(sliced);
+        InboundMessage.Request inboundMessage = (InboundMessage.Request) reader.deserialize(sliced);
         // Check that deserialize does not overwrite current thread context.
         assertEquals("header_value2", threadContext.getHeader("header"));
         inboundMessage.getStoredContext().restore();
@@ -102,7 +102,7 @@ public class InboundMessageTests extends ESTestCase {
 
         InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
         BytesReference sliced = reference.slice(6, reference.length() - 6);
-        InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced);
+        InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced);
         // Check that deserialize does not overwrite current thread context.
         assertEquals("header_value2", threadContext.getHeader("header"));
         inboundMessage.getStoredContext().restore();
@@ -138,7 +138,7 @@ public class InboundMessageTests extends ESTestCase {
 
         InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
         BytesReference sliced = reference.slice(6, reference.length() - 6);
-        InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced);
+        InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced);
         // Check that deserialize does not overwrite current thread context.
         assertEquals("header_value2", threadContext.getHeader("header"));
         inboundMessage.getStoredContext().restore();

+ 201 - 41
server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java

@@ -19,14 +19,16 @@
 
 package org.elasticsearch.transport;
 
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.test.ESTestCase;
@@ -38,24 +40,34 @@ import org.junit.Before;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.instanceOf;
+
 public class OutboundHandlerTests extends ESTestCase {
 
+    private final String feature1 = "feature1";
+    private final String feature2 = "feature2";
     private final TestThreadPool threadPool = new TestThreadPool(getClass().getName());
     private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
+    private final TransportRequestOptions options = TransportRequestOptions.EMPTY;
     private OutboundHandler handler;
-    private FakeTcpChannel fakeTcpChannel;
+    private FakeTcpChannel channel;
+    private DiscoveryNode node;
 
     @Before
     public void setUp() throws Exception {
         super.setUp();
         TransportLogger transportLogger = new TransportLogger();
-        fakeTcpChannel = new FakeTcpChannel(randomBoolean());
-        handler = new OutboundHandler(threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger);
+        channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
+        TransportAddress transportAddress = buildNewFakeTransportAddress();
+        node = new DiscoveryNode("", transportAddress, Version.CURRENT);
+        String[] features = {feature1, feature2};
+        handler = new OutboundHandler("node", Version.CURRENT, features, threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger);
     }
 
     @After
@@ -70,10 +82,10 @@ public class OutboundHandlerTests extends ESTestCase {
         AtomicBoolean isSuccess = new AtomicBoolean(false);
         AtomicReference<Exception> exception = new AtomicReference<>();
         ActionListener<Void> listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set);
-        handler.sendBytes(fakeTcpChannel, bytesArray, listener);
+        handler.sendBytes(channel, bytesArray, listener);
 
-        BytesReference reference = fakeTcpChannel.getMessageCaptor().get();
-        ActionListener<Void> sendListener  = fakeTcpChannel.getListenerCaptor().get();
+        BytesReference reference = channel.getMessageCaptor().get();
+        ActionListener<Void> sendListener  = channel.getListenerCaptor().get();
         if (randomBoolean()) {
             sendListener.onResponse(null);
             assertTrue(isSuccess.get());
@@ -88,55 +100,118 @@ public class OutboundHandlerTests extends ESTestCase {
         assertEquals(bytesArray, reference);
     }
 
-    public void testSendMessage() throws IOException {
-        OutboundMessage message;
+    public void testSendRequest() throws IOException {
         ThreadContext threadContext = threadPool.getThreadContext();
-        Version version = Version.CURRENT;
-        String actionName = "handshake";
+        Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
+        String action = "handshake";
         long requestId = randomLongBetween(0, 300);
         boolean isHandshake = randomBoolean();
         boolean compress = randomBoolean();
         String value = "message";
         threadContext.putHeader("header", "header_value");
-        Writeable writeable = new Message(value);
+        Request request = new Request(value);
+
+        AtomicReference<DiscoveryNode> nodeRef = new AtomicReference<>();
+        AtomicLong requestIdRef = new AtomicLong();
+        AtomicReference<String> actionRef = new AtomicReference<>();
+        AtomicReference<TransportRequest> requestRef = new AtomicReference<>();
+        handler.setMessageListener(new TransportMessageListener() {
+            @Override
+            public void onRequestSent(DiscoveryNode node, long requestId, String action, TransportRequest request,
+                                      TransportRequestOptions options) {
+                nodeRef.set(node);
+                requestIdRef.set(requestId);
+                actionRef.set(action);
+                requestRef.set(request);
+            }
+        });
+        handler.sendRequest(node, channel, requestId, action, request, options, version, compress, isHandshake);
 
-        boolean isRequest = randomBoolean();
-        if (isRequest) {
-            message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake,
-                compress);
+        BytesReference reference = channel.getMessageCaptor().get();
+        ActionListener<Void> sendListener  = channel.getListenerCaptor().get();
+        if (randomBoolean()) {
+            sendListener.onResponse(null);
         } else {
-            message = new OutboundMessage.Response(threadContext, new HashSet<>(), writeable, version, requestId, isHandshake, compress);
+            sendListener.onFailure(new IOException("failed"));
         }
+        assertEquals(node, nodeRef.get());
+        assertEquals(requestId, requestIdRef.get());
+        assertEquals(action, actionRef.get());
+        assertEquals(request, requestRef.get());
 
-        AtomicBoolean isSuccess = new AtomicBoolean(false);
-        AtomicReference<Exception> exception = new AtomicReference<>();
-        ActionListener<Void> listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set);
-        handler.sendMessage(fakeTcpChannel, message, listener);
+        InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext());
+        try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) {
+            assertEquals(version, inboundMessage.getVersion());
+            assertEquals(requestId, inboundMessage.getRequestId());
+            assertTrue(inboundMessage.isRequest());
+            assertFalse(inboundMessage.isResponse());
+            if (isHandshake) {
+                assertTrue(inboundMessage.isHandshake());
+            } else {
+                assertFalse(inboundMessage.isHandshake());
+            }
+            if (compress) {
+                assertTrue(inboundMessage.isCompress());
+            } else {
+                assertFalse(inboundMessage.isCompress());
+            }
+            InboundMessage.Request inboundRequest = (InboundMessage.Request) inboundMessage;
+            assertThat(inboundRequest.getFeatures(), contains(feature1, feature2));
+
+            Request readMessage = new Request();
+            readMessage.readFrom(inboundMessage.getStreamInput());
+            assertEquals(value, readMessage.value);
 
-        BytesReference reference = fakeTcpChannel.getMessageCaptor().get();
-        ActionListener<Void> sendListener  = fakeTcpChannel.getListenerCaptor().get();
+            try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
+                ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext();
+                assertNull(threadContext.getHeader("header"));
+                storedContext.restore();
+                assertEquals("header_value", threadContext.getHeader("header"));
+            }
+        }
+    }
+
+    public void testSendResponse() throws IOException {
+        ThreadContext threadContext = threadPool.getThreadContext();
+        Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
+        String action = "handshake";
+        long requestId = randomLongBetween(0, 300);
+        boolean isHandshake = randomBoolean();
+        boolean compress = randomBoolean();
+        String value = "message";
+        threadContext.putHeader("header", "header_value");
+        Response response = new Response(value);
+
+        AtomicLong requestIdRef = new AtomicLong();
+        AtomicReference<String> actionRef = new AtomicReference<>();
+        AtomicReference<TransportResponse> responseRef = new AtomicReference<>();
+        handler.setMessageListener(new TransportMessageListener() {
+            @Override
+            public void onResponseSent(long requestId, String action, TransportResponse response) {
+                requestIdRef.set(requestId);
+                actionRef.set(action);
+                responseRef.set(response);
+            }
+        });
+        handler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake);
+
+        BytesReference reference = channel.getMessageCaptor().get();
+        ActionListener<Void> sendListener  = channel.getListenerCaptor().get();
         if (randomBoolean()) {
             sendListener.onResponse(null);
-            assertTrue(isSuccess.get());
-            assertNull(exception.get());
         } else {
-            IOException e = new IOException("failed");
-            sendListener.onFailure(e);
-            assertFalse(isSuccess.get());
-            assertSame(e, exception.get());
+            sendListener.onFailure(new IOException("failed"));
         }
+        assertEquals(requestId, requestIdRef.get());
+        assertEquals(action, actionRef.get());
+        assertEquals(response, responseRef.get());
 
         InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext());
         try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) {
             assertEquals(version, inboundMessage.getVersion());
             assertEquals(requestId, inboundMessage.getRequestId());
-            if (isRequest) {
-                assertTrue(inboundMessage.isRequest());
-                assertFalse(inboundMessage.isResponse());
-            } else {
-                assertTrue(inboundMessage.isResponse());
-                assertFalse(inboundMessage.isRequest());
-            }
+            assertFalse(inboundMessage.isRequest());
+            assertTrue(inboundMessage.isResponse());
             if (isHandshake) {
                 assertTrue(inboundMessage.isHandshake());
             } else {
@@ -147,7 +222,11 @@ public class OutboundHandlerTests extends ESTestCase {
             } else {
                 assertFalse(inboundMessage.isCompress());
             }
-            Message readMessage = new Message();
+
+            InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage;
+            assertFalse(inboundResponse.isError());
+
+            Response readMessage = new Response();
             readMessage.readFrom(inboundMessage.getStreamInput());
             assertEquals(value, readMessage.value);
 
@@ -160,14 +239,95 @@ public class OutboundHandlerTests extends ESTestCase {
         }
     }
 
-    private static final class Message extends TransportMessage {
+    public void testErrorResponse() throws IOException {
+        ThreadContext threadContext = threadPool.getThreadContext();
+        Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
+        String action = "handshake";
+        long requestId = randomLongBetween(0, 300);
+        threadContext.putHeader("header", "header_value");
+        ElasticsearchException error = new ElasticsearchException("boom");
+
+        AtomicLong requestIdRef = new AtomicLong();
+        AtomicReference<String> actionRef = new AtomicReference<>();
+        AtomicReference<Exception> responseRef = new AtomicReference<>();
+        handler.setMessageListener(new TransportMessageListener() {
+            @Override
+            public void onResponseSent(long requestId, String action, Exception error) {
+                requestIdRef.set(requestId);
+                actionRef.set(action);
+                responseRef.set(error);
+            }
+        });
+        handler.sendErrorResponse(version, Collections.emptySet(), channel, requestId, action, error);
+
+        BytesReference reference = channel.getMessageCaptor().get();
+        ActionListener<Void> sendListener  = channel.getListenerCaptor().get();
+        if (randomBoolean()) {
+            sendListener.onResponse(null);
+        } else {
+            sendListener.onFailure(new IOException("failed"));
+        }
+        assertEquals(requestId, requestIdRef.get());
+        assertEquals(action, actionRef.get());
+        assertEquals(error, responseRef.get());
+
+        InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext());
+        try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) {
+            assertEquals(version, inboundMessage.getVersion());
+            assertEquals(requestId, inboundMessage.getRequestId());
+            assertFalse(inboundMessage.isRequest());
+            assertTrue(inboundMessage.isResponse());
+            assertFalse(inboundMessage.isCompress());
+            assertFalse(inboundMessage.isHandshake());
+
+            InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage;
+            assertTrue(inboundResponse.isError());
+
+            RemoteTransportException remoteException = inboundMessage.getStreamInput().readException();
+            assertThat(remoteException.getCause(), instanceOf(ElasticsearchException.class));
+            assertEquals(remoteException.getCause().getMessage(), "boom");
+            assertEquals(action, remoteException.action());
+            assertEquals(channel.getLocalAddress(), remoteException.address().address());
+
+            try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
+                ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext();
+                assertNull(threadContext.getHeader("header"));
+                storedContext.restore();
+                assertEquals("header_value", threadContext.getHeader("header"));
+            }
+        }
+    }
+
+    private static final class Request extends TransportRequest {
+
+        public String value;
+
+        private Request() {
+        }
+
+        private Request(String value) {
+            this.value = value;
+        }
+
+        @Override
+        public void readFrom(StreamInput in) throws IOException {
+            value = in.readString();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(value);
+        }
+    }
+
+    private static final class Response extends TransportResponse {
 
         public String value;
 
-        private Message() {
+        private Response() {
         }
 
-        private Message(String value) {
+        private Response(String value) {
             this.value = value;
         }
 

+ 2 - 2
test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java

@@ -2008,12 +2008,12 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
             new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, namedWriteableRegistry,
             new NoneCircuitBreakerService()) {
             @Override
-            protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes)
+            protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes)
                 throws IOException {
                 // we flip the isHandshake bit back and act like the handler is not found
                 byte status = (byte) (request.status & ~(1 << 3));
                 Version version = request.getVersion();
-                InboundMessage.RequestMessage nonHandshakeRequest = new InboundMessage.RequestMessage(request.threadContext, version,
+                InboundMessage.Request nonHandshakeRequest = new InboundMessage.Request(request.threadContext, version,
                     status, request.getRequestId(), request.getActionName(), request.getFeatures(), request.getStreamInput());
                 super.handleRequest(channel, nonHandshakeRequest, messageLengthBytes);
             }

+ 4 - 0
test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java

@@ -44,6 +44,10 @@ public class FakeTcpChannel implements TcpChannel {
         this(isServer, "profile", new AtomicReference<>());
     }
 
+    public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSocketAddress remoteAddress) {
+        this(isServer, localAddress, remoteAddress, "profile", new AtomicReference<>());
+    }
+
     public FakeTcpChannel(boolean isServer, AtomicReference<BytesReference> messageCaptor) {
         this(isServer, "profile", messageCaptor);
     }