Browse Source

Extract message serialization from `TcpTransport` (#37034)

This commit introduces a NetworkMessage class. This class has two
subclasses - InboundMessage and OutboundMessage. These messages can
be serialized and deserialized independent of the transport. This allows
more granular testing. Additionally, the serialization mechanism is now
a simple Supplier. This builds the framework to eventually move the
serialization of transport messages to the network thread. This is the
one serialization component that is not currently performed on the
network thread (transport deserialization and http serialization and
deserialization are all on the network thread).
Tim Brooks 6 years ago
parent
commit
21838d73b5
18 changed files with 1104 additions and 400 deletions
  1. 3 0
      libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java
  2. 3 7
      modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java
  3. 3 7
      plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java
  4. 4 7
      server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java
  5. 168 0
      server/src/main/java/org/elasticsearch/transport/InboundMessage.java
  6. 76 0
      server/src/main/java/org/elasticsearch/transport/NetworkMessage.java
  7. 167 0
      server/src/main/java/org/elasticsearch/transport/OutboundHandler.java
  8. 155 0
      server/src/main/java/org/elasticsearch/transport/OutboundMessage.java
  9. 71 328
      server/src/main/java/org/elasticsearch/transport/TcpTransport.java
  10. 4 0
      server/src/main/java/org/elasticsearch/transport/TransportLogger.java
  11. 5 2
      server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java
  12. 2 2
      server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java
  13. 7 2
      server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java
  14. 228 0
      server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java
  15. 184 0
      server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java
  16. 5 35
      server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java
  17. 7 4
      test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java
  18. 12 6
      test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java

+ 3 - 0
libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java

@@ -34,14 +34,17 @@ public abstract class BytesWriteHandler implements ReadWriteHandler {
         return new FlushReadyWrite(context, (ByteBuffer[]) message, listener);
     }
 
+    @Override
     public List<FlushOperation> writeToBytes(WriteOperation writeOperation) {
         assert writeOperation instanceof FlushReadyWrite : "Write operation must be flush ready";
         return Collections.singletonList((FlushReadyWrite) writeOperation);
     }
 
+    @Override
     public List<FlushOperation> pollFlushOperations() {
         return EMPTY_LIST;
     }
 
+    @Override
     public void close() {}
 }

+ 3 - 7
modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java

@@ -36,12 +36,12 @@ import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
 import org.elasticsearch.test.ESIntegTestCase.Scope;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.InboundMessage;
 import org.elasticsearch.transport.TcpChannel;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.transport.TransportSettings;
 
 import java.io.IOException;
-import java.net.InetSocketAddress;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
@@ -111,13 +111,9 @@ public class Netty4TransportIT extends ESNetty4IntegTestCase {
         }
 
         @Override
-        protected String handleRequest(TcpChannel channel, String profileName,
-                                       StreamInput stream, long requestId, int messageLengthBytes, Version version,
-                                       InetSocketAddress remoteAddress, byte status) throws IOException {
-            String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version,
-                    remoteAddress, status);
+        protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
+            super.handleRequest(channel, request, messageLengthBytes);
             channelProfileName = TransportSettings.DEFAULT_PROFILE;
-            return action;
         }
 
         @Override

+ 3 - 7
plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java

@@ -38,12 +38,12 @@ import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
 import org.elasticsearch.test.ESIntegTestCase.Scope;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.InboundMessage;
 import org.elasticsearch.transport.TcpChannel;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.transport.TransportSettings;
 
 import java.io.IOException;
-import java.net.InetSocketAddress;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
@@ -113,13 +113,9 @@ public class NioTransportIT extends NioIntegTestCase {
         }
 
         @Override
-        protected String handleRequest(TcpChannel channel, String profileName,
-                                       StreamInput stream, long requestId, int messageLengthBytes, Version version,
-                                       InetSocketAddress remoteAddress, byte status) throws IOException {
-            String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version,
-                    remoteAddress, status);
+        protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException {
+            super.handleRequest(channel, request, messageLengthBytes);
             channelProfileName = TransportSettings.DEFAULT_PROFILE;
-            return action;
         }
 
         @Override

+ 4 - 7
server/src/main/java/org/elasticsearch/transport/CompressibleBytesOutputStream.java

@@ -39,8 +39,8 @@ import java.util.zip.DeflaterOutputStream;
  * written to this stream. If compression is enabled, the proper EOS bytes will be written at that point.
  * The underlying {@link BytesReference} will be returned.
  *
- * {@link CompressibleBytesOutputStream#close()} should be called when the bytes are no longer needed and
- * can be safely released.
+ * {@link CompressibleBytesOutputStream#close()} will NOT close the underlying stream. The byte stream passed
+ * in the constructor must be closed individually.
  */
 final class CompressibleBytesOutputStream extends StreamOutput {
 
@@ -92,12 +92,9 @@ final class CompressibleBytesOutputStream extends StreamOutput {
 
     @Override
     public void close() throws IOException {
-        if (stream == bytesStreamOutput) {
-            assert shouldCompress == false : "If the streams are the same we should not be compressing";
-            IOUtils.close(stream);
-        } else {
+        if (stream != bytesStreamOutput) {
             assert shouldCompress : "If the streams are different we should be compressing";
-            IOUtils.close(stream, bytesStreamOutput);
+            IOUtils.close(stream);
         }
     }
 

+ 168 - 0
server/src/main/java/org/elasticsearch/transport/InboundMessage.java

@@ -0,0 +1,168 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.transport;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.compress.Compressor;
+import org.elasticsearch.common.compress.CompressorFactory;
+import org.elasticsearch.common.compress.NotCompressedException;
+import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.internal.io.IOUtils;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Set;
+import java.util.TreeSet;
+
+public abstract class InboundMessage extends NetworkMessage implements Closeable {
+
+    private final StreamInput streamInput;
+
+    InboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
+        super(threadContext, version, status, requestId);
+        this.streamInput = streamInput;
+    }
+
+    StreamInput getStreamInput() {
+        return streamInput;
+    }
+
+    static class Reader {
+
+        private final Version version;
+        private final NamedWriteableRegistry namedWriteableRegistry;
+        private final ThreadContext threadContext;
+
+        Reader(Version version, NamedWriteableRegistry namedWriteableRegistry, ThreadContext threadContext) {
+            this.version = version;
+            this.namedWriteableRegistry = namedWriteableRegistry;
+            this.threadContext = threadContext;
+        }
+
+        InboundMessage deserialize(BytesReference reference) throws IOException {
+            int messageLengthBytes = reference.length();
+            final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
+            // we have additional bytes to read, outside of the header
+            boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0;
+            StreamInput streamInput = reference.streamInput();
+            boolean success = false;
+            try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
+                long requestId = streamInput.readLong();
+                byte status = streamInput.readByte();
+                Version remoteVersion = Version.fromId(streamInput.readInt());
+                final boolean isHandshake = TransportStatus.isHandshake(status);
+                ensureVersionCompatibility(remoteVersion, version, isHandshake);
+                if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamInput.available() > 0) {
+                    Compressor compressor;
+                    try {
+                        final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE;
+                        compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed));
+                    } catch (NotCompressedException ex) {
+                        int maxToRead = Math.min(reference.length(), 10);
+                        StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [")
+                            .append(maxToRead).append("] content bytes out of [").append(reference.length())
+                            .append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are [");
+                        for (int i = 0; i < maxToRead; i++) {
+                            sb.append(reference.get(i)).append(",");
+                        }
+                        sb.append("]");
+                        throw new IllegalStateException(sb.toString());
+                    }
+                    streamInput = compressor.streamInput(streamInput);
+                }
+                streamInput = new NamedWriteableAwareStreamInput(streamInput, namedWriteableRegistry);
+                streamInput.setVersion(remoteVersion);
+
+                threadContext.readHeaders(streamInput);
+
+                InboundMessage message;
+                if (TransportStatus.isRequest(status)) {
+                    final Set<String> features;
+                    if (remoteVersion.onOrAfter(Version.V_6_3_0)) {
+                        features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(streamInput.readStringArray())));
+                    } else {
+                        features = Collections.emptySet();
+                    }
+                    final String action = streamInput.readString();
+                    message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput);
+                } else {
+                    message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput);
+                }
+                success = true;
+                return message;
+            } finally {
+                if (success == false) {
+                    IOUtils.closeWhileHandlingException(streamInput);
+                }
+            }
+        }
+    }
+
+    @Override
+    public void close() throws IOException {
+        streamInput.close();
+    }
+
+    private static void ensureVersionCompatibility(Version version, Version currentVersion, boolean isHandshake) {
+        // for handshakes we are compatible with N-2 since otherwise we can't figure out our initial version
+        // since we are compatible with N-1 and N+1 so we always send our minCompatVersion as the initial version in the
+        // handshake. This looks odd but it's required to establish the connection correctly we check for real compatibility
+        // once the connection is established
+        final Version compatibilityVersion = isHandshake ? currentVersion.minimumCompatibilityVersion() : currentVersion;
+        if (version.isCompatible(compatibilityVersion) == false) {
+            final Version minCompatibilityVersion = isHandshake ? compatibilityVersion : compatibilityVersion.minimumCompatibilityVersion();
+            String msg = "Received " + (isHandshake ? "handshake " : "") + "message from unsupported version: [";
+            throw new IllegalStateException(msg + version + "] minimal compatible version is: [" + minCompatibilityVersion + "]");
+        }
+    }
+
+    public static class RequestMessage extends InboundMessage {
+
+        private final String actionName;
+        private final Set<String> features;
+
+        RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set<String> features,
+                       StreamInput streamInput) {
+            super(threadContext, version, status, requestId, streamInput);
+            this.actionName = actionName;
+            this.features = features;
+        }
+
+        String getActionName() {
+            return actionName;
+        }
+
+        Set<String> getFeatures() {
+            return features;
+        }
+    }
+
+    public static class ResponseMessage extends InboundMessage {
+
+        ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) {
+            super(threadContext, version, status, requestId, streamInput);
+        }
+    }
+}

+ 76 - 0
server/src/main/java/org/elasticsearch/transport/NetworkMessage.java

@@ -0,0 +1,76 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.transport;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+
+/**
+ * Represents a transport message sent over the network. Subclasses implement serialization and
+ * deserialization.
+ */
+public abstract class NetworkMessage {
+
+    protected final Version version;
+    protected final ThreadContext threadContext;
+    protected final ThreadContext.StoredContext storedContext;
+    protected final long requestId;
+    protected final byte status;
+
+    NetworkMessage(ThreadContext threadContext, Version version, byte status, long requestId) {
+        this.threadContext = threadContext;
+        storedContext = threadContext.stashContext();
+        storedContext.restore();
+        this.version = version;
+        this.requestId = requestId;
+        this.status = status;
+    }
+
+    public Version getVersion() {
+        return version;
+    }
+
+    public long getRequestId() {
+        return requestId;
+    }
+
+    boolean isCompress() {
+        return TransportStatus.isCompress(status);
+    }
+
+    ThreadContext.StoredContext getStoredContext() {
+        return storedContext;
+    }
+
+    boolean isResponse() {
+        return TransportStatus.isRequest(status) == false;
+    }
+
+    boolean isRequest() {
+        return TransportStatus.isRequest(status);
+    }
+
+    boolean isHandshake() {
+        return TransportStatus.isHandshake(status);
+    }
+
+    boolean isError() {
+        return TransportStatus.isError(status);
+    }
+}

+ 167 - 0
server/src/main/java/org/elasticsearch/transport/OutboundHandler.java

@@ -0,0 +1,167 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.transport;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.NotifyOnceListener;
+import org.elasticsearch.common.CheckedSupplier;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
+import org.elasticsearch.common.lease.Releasable;
+import org.elasticsearch.common.lease.Releasables;
+import org.elasticsearch.common.metrics.MeanMetric;
+import org.elasticsearch.common.network.CloseableChannel;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.core.internal.io.IOUtils;
+import org.elasticsearch.threadpool.ThreadPool;
+
+import java.io.IOException;
+
+final class OutboundHandler {
+
+    private static final Logger logger = LogManager.getLogger(OutboundHandler.class);
+
+    private final MeanMetric transmittedBytesMetric = new MeanMetric();
+    private final ThreadPool threadPool;
+    private final BigArrays bigArrays;
+    private final TransportLogger transportLogger;
+
+    OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) {
+        this.threadPool = threadPool;
+        this.bigArrays = bigArrays;
+        this.transportLogger = transportLogger;
+    }
+
+    void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> listener) {
+        channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
+        SendContext sendContext = new SendContext(channel, () -> bytes, listener);
+        try {
+            internalSendMessage(channel, sendContext);
+        } catch (IOException e) {
+            // This should not happen as the bytes are already serialized
+            throw new AssertionError(e);
+        }
+    }
+
+    void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
+        channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
+        MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
+        SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
+        internalSendMessage(channel, sendContext);
+    }
+
+    /**
+     * sends a message to the given channel, using the given callbacks.
+     */
+    private void internalSendMessage(TcpChannel channel,  SendContext sendContext) throws IOException {
+        channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
+        BytesReference reference = sendContext.get();
+        try {
+            channel.sendMessage(reference, sendContext);
+        } catch (RuntimeException ex) {
+            sendContext.onFailure(ex);
+            CloseableChannel.closeChannel(channel);
+            throw ex;
+        }
+
+    }
+
+    MeanMetric getTransmittedBytes() {
+        return transmittedBytesMetric;
+    }
+
+    private static class MessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {
+
+        private final OutboundMessage message;
+        private final BigArrays bigArrays;
+        private volatile ReleasableBytesStreamOutput bytesStreamOutput;
+
+        private MessageSerializer(OutboundMessage message, BigArrays bigArrays) {
+            this.message = message;
+            this.bigArrays = bigArrays;
+        }
+
+        @Override
+        public BytesReference get() throws IOException {
+            bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays);
+            return message.serialize(bytesStreamOutput);
+        }
+
+        @Override
+        public void close() {
+            IOUtils.closeWhileHandlingException(bytesStreamOutput);
+        }
+    }
+
+    private class SendContext extends NotifyOnceListener<Void> implements CheckedSupplier<BytesReference, IOException> {
+
+        private final TcpChannel channel;
+        private final CheckedSupplier<BytesReference, IOException> messageSupplier;
+        private final ActionListener<Void> listener;
+        private final Releasable optionalReleasable;
+        private long messageSize = -1;
+
+        private SendContext(TcpChannel channel, CheckedSupplier<BytesReference, IOException> messageSupplier,
+                            ActionListener<Void> listener) {
+            this(channel, messageSupplier, listener, null);
+        }
+
+        private SendContext(TcpChannel channel, CheckedSupplier<BytesReference, IOException> messageSupplier,
+                            ActionListener<Void> listener, Releasable optionalReleasable) {
+            this.channel = channel;
+            this.messageSupplier = messageSupplier;
+            this.listener = listener;
+            this.optionalReleasable = optionalReleasable;
+        }
+
+        public BytesReference get() throws IOException {
+            BytesReference message;
+            try {
+                message = messageSupplier.get();
+                messageSize = message.length();
+                transportLogger.logOutboundMessage(channel, message);
+                return message;
+            } catch (Exception e) {
+                onFailure(e);
+                throw e;
+            }
+        }
+
+        @Override
+        protected void innerOnResponse(Void v) {
+            assert messageSize != -1 : "If onResponse is being called, the message should have been serialized";
+            transmittedBytesMetric.inc(messageSize);
+            closeAndCallback(() -> listener.onResponse(v));
+        }
+
+        @Override
+        protected void innerOnFailure(Exception e) {
+            logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e);
+            closeAndCallback(() -> listener.onFailure(e));
+        }
+
+        private void closeAndCallback(Runnable runnable) {
+            Releasables.close(optionalReleasable, runnable::run);
+        }
+    }
+}

+ 155 - 0
server/src/main/java/org/elasticsearch/transport/OutboundMessage.java

@@ -0,0 +1,155 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.transport;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.bytes.CompositeBytesReference;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+
+import java.io.IOException;
+import java.util.Set;
+
+abstract class OutboundMessage extends NetworkMessage implements Writeable {
+
+    private final Writeable message;
+
+    OutboundMessage(ThreadContext threadContext, Version version, byte status, long requestId, Writeable message) {
+        super(threadContext, version, status, requestId);
+        this.message = message;
+    }
+
+    BytesReference serialize(BytesStreamOutput bytesStream) throws IOException {
+        storedContext.restore();
+        bytesStream.setVersion(version);
+        bytesStream.skip(TcpHeader.HEADER_SIZE);
+
+        // The compressible bytes stream will not close the underlying bytes stream
+        BytesReference reference;
+        try (CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bytesStream, TransportStatus.isCompress(status))) {
+            stream.setVersion(version);
+            threadContext.writeTo(stream);
+            writeTo(stream);
+            reference = writeMessage(stream);
+        }
+        bytesStream.seek(0);
+        TcpHeader.writeHeader(bytesStream, requestId, status, version, reference.length() - TcpHeader.HEADER_SIZE);
+        return reference;
+    }
+
+    private BytesReference writeMessage(CompressibleBytesOutputStream stream) throws IOException {
+        final BytesReference zeroCopyBuffer;
+        if (message instanceof BytesTransportRequest) {
+            BytesTransportRequest bRequest = (BytesTransportRequest) message;
+            bRequest.writeThin(stream);
+            zeroCopyBuffer = bRequest.bytes;
+        } else if (message instanceof RemoteTransportException) {
+            stream.writeException((RemoteTransportException) message);
+            zeroCopyBuffer = BytesArray.EMPTY;
+        } else {
+            message.writeTo(stream);
+            zeroCopyBuffer = BytesArray.EMPTY;
+        }
+        // we have to call materializeBytes() here before accessing the bytes. A CompressibleBytesOutputStream
+        // might be implementing compression. And materializeBytes() ensures that some marker bytes (EOS marker)
+        // are written. Otherwise we barf on the decompressing end when we read past EOF on purpose in the
+        // #validateRequest method. this might be a problem in deflate after all but it's important to write
+        // the marker bytes.
+        final BytesReference message = stream.materializeBytes();
+        if (zeroCopyBuffer.length() == 0) {
+            return message;
+        } else {
+            return new CompositeBytesReference(message, zeroCopyBuffer);
+        }
+    }
+
+    static class Request extends OutboundMessage {
+
+        private final String[] features;
+        private final String action;
+
+        Request(ThreadContext threadContext, String[] features, Writeable message, Version version, String action, long requestId,
+                boolean isHandshake, boolean compress) {
+            super(threadContext, version, setStatus(compress, isHandshake, message), requestId, message);
+            this.features = features;
+            this.action = action;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            if (version.onOrAfter(Version.V_6_3_0)) {
+                out.writeStringArray(features);
+            }
+            out.writeString(action);
+        }
+
+        private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) {
+            byte status = 0;
+            status = TransportStatus.setRequest(status);
+            if (compress && OutboundMessage.canCompress(message)) {
+                status = TransportStatus.setCompress(status);
+            }
+            if (isHandshake) {
+                status = TransportStatus.setHandshake(status);
+            }
+
+            return status;
+        }
+    }
+
+    static class Response extends OutboundMessage {
+
+        private final Set<String> features;
+
+        Response(ThreadContext threadContext, Set<String> features, Writeable message, Version version, long requestId,
+                 boolean isHandshake, boolean compress) {
+            super(threadContext, version, setStatus(compress, isHandshake, message), requestId, message);
+            this.features = features;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.setFeatures(features);
+        }
+
+        private static byte setStatus(boolean compress, boolean isHandshake, Writeable message) {
+            byte status = 0;
+            status = TransportStatus.setResponse(status);
+            if (message instanceof RemoteTransportException) {
+                status = TransportStatus.setError(status);
+            }
+            if (compress) {
+                status = TransportStatus.setCompress(status);
+            }
+            if (isHandshake) {
+                status = TransportStatus.setHandshake(status);
+            }
+
+            return status;
+        }
+    }
+
+    private static boolean canCompress(Writeable message) {
+        return message instanceof BytesTransportRequest == false;
+    }
+}

+ 71 - 328
server/src/main/java/org/elasticsearch/transport/TcpTransport.java

@@ -26,24 +26,16 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.NotifyOnceListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Booleans;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.bytes.CompositeBytesReference;
 import org.elasticsearch.common.collect.MapBuilder;
 import org.elasticsearch.common.component.AbstractLifecycleComponent;
 import org.elasticsearch.common.component.Lifecycle;
-import org.elasticsearch.common.compress.Compressor;
-import org.elasticsearch.common.compress.CompressorFactory;
-import org.elasticsearch.common.compress.NotCompressedException;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
-import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.lease.Releasable;
 import org.elasticsearch.common.metrics.MeanMetric;
@@ -64,17 +56,14 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.monitor.jvm.JvmInfo;
 import org.elasticsearch.node.Node;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.threadpool.ThreadPool;
 
-import java.io.Closeable;
 import java.io.IOException;
 import java.io.StreamCorruptedException;
-import java.io.UncheckedIOException;
 import java.net.BindException;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
@@ -136,8 +125,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
     private final Map<String, List<TcpServerChannel>> serverChannels = newConcurrentMap();
     private final Set<TcpChannel> acceptedChannels = ConcurrentCollections.newConcurrentSet();
 
-    private final NamedWriteableRegistry namedWriteableRegistry;
-
     // this lock is here to make sure we close this transport and disconnect all the client nodes
     // connections while no connect operations is going on
     private final ReadWriteLock closeLock = new ReentrantReadWriteLock();
@@ -145,15 +132,16 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
     private final String transportName;
 
     private final MeanMetric readBytesMetric = new MeanMetric();
-    private final MeanMetric transmittedBytesMetric = new MeanMetric();
     private volatile Map<String, RequestHandlerRegistry<? extends TransportRequest>> requestHandlers = Collections.emptyMap();
     private final ResponseHandlers responseHandlers = new ResponseHandlers();
     private final TransportLogger transportLogger;
     private final TransportHandshaker handshaker;
     private final TransportKeepAlive keepAlive;
+    private final InboundMessage.Reader reader;
+    private final OutboundHandler outboundHandler;
     private final String nodeName;
 
-    public TcpTransport(String transportName, Settings settings,  Version version, ThreadPool threadPool,
+    public TcpTransport(String transportName, Settings settings, Version version, ThreadPool threadPool,
                         PageCacheRecycler pageCacheRecycler, CircuitBreakerService circuitBreakerService,
                         NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) {
         this.settings = settings;
@@ -163,17 +151,18 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         this.bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS);
         this.pageCacheRecycler = pageCacheRecycler;
         this.circuitBreakerService = circuitBreakerService;
-        this.namedWriteableRegistry = namedWriteableRegistry;
         this.networkService = networkService;
         this.transportName = transportName;
         this.transportLogger = new TransportLogger();
+        this.outboundHandler = new OutboundHandler(threadPool, bigArrays, transportLogger);
         this.handshaker = new TransportHandshaker(version, threadPool,
             (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId,
                 TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
-                TransportRequestOptions.EMPTY, v, false, TransportStatus.setHandshake((byte) 0)),
+                TransportRequestOptions.EMPTY, v, false, true),
             (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId,
-                TransportHandshaker.HANDSHAKE_ACTION_NAME, false, TransportStatus.setHandshake((byte) 0)));
-        this.keepAlive = new TransportKeepAlive(threadPool, this::internalSendMessage);
+                TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true));
+        this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
+        this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext());
         this.nodeName = Node.NODE_NAME_SETTING.get(settings);
 
         final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings);
@@ -280,7 +269,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                 throw new NodeNotConnectedException(node, "connection already closed");
             }
             TcpChannel channel = channel(options.type());
-            sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress, (byte) 0);
+            sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress);
         }
     }
 
@@ -573,7 +562,8 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                 for (Map.Entry<String, List<TcpServerChannel>> entry : serverChannels.entrySet()) {
                     String profile = entry.getKey();
                     List<TcpServerChannel> channels = entry.getValue();
-                    ActionListener<Void> closeFailLogger = ActionListener.wrap(c -> {},
+                    ActionListener<Void> closeFailLogger = ActionListener.wrap(c -> {
+                        },
                         e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e));
                     channels.forEach(c -> c.addCloseListener(closeFailLogger));
                     CloseableChannel.closeChannels(channels, true);
@@ -628,26 +618,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
             // in case we are able to return data, serialize the exception content and sent it back to the client
             if (channel.isOpen()) {
                 BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8));
-                ActionListener<Void> listener = new ActionListener<Void>() {
-                    @Override
-                    public void onResponse(Void aVoid) {
-                        CloseableChannel.closeChannel(channel);
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {
-                        logger.debug("failed to send message to httpOnTransport channel", e);
-                        CloseableChannel.closeChannel(channel);
-                    }
-                };
-                // We do not call internalSendMessage because we are not sending a message that is an
-                // elasticsearch binary message. We are just serializing an exception here. Not formatting it
-                // as an elasticsearch transport message.
-                try {
-                    channel.sendMessage(message, new SendListener(channel, message.length(), listener));
-                } catch (Exception ex) {
-                    listener.onFailure(ex);
-                }
+                outboundHandler.sendBytes(channel, message, ActionListener.wrap(() -> CloseableChannel.closeChannel(channel)));
             }
         } else {
             logger.warn(() -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e);
@@ -691,65 +662,21 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
      */
     protected abstract void stopInternal();
 
-    private boolean canCompress(TransportRequest request) {
-        return request instanceof BytesTransportRequest == false;
-    }
-
     private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
                                       final TransportRequest request, TransportRequestOptions options, Version channelVersion,
-                                      boolean compressRequest, byte status) throws IOException, TransportException {
-
-        // only compress if asked and the request is not bytes. Otherwise only
-        // the header part is compressed, and the "body" can't be extracted as compressed
-        final boolean compressMessage = compressRequest && canCompress(request);
-
-        status = TransportStatus.setRequest(status);
-        ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays);
-        final CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bStream, compressMessage);
-        boolean addedReleaseListener = false;
-        try {
-            if (compressMessage) {
-                status = TransportStatus.setCompress(status);
-            }
-
-            // we pick the smallest of the 2, to support both backward and forward compatibility
-            // note, this is the only place we need to do this, since from here on, we use the serialized version
-            // as the version to use also when the node receiving this request will send the response with
-            Version version = Version.min(this.version, channelVersion);
-
-            stream.setVersion(version);
-            threadPool.getThreadContext().writeTo(stream);
-            if (version.onOrAfter(Version.V_6_3_0)) {
-                stream.writeStringArray(features);
-            }
-            stream.writeString(action);
-            BytesReference message = buildMessage(requestId, status, node.getVersion(), request, stream);
-            final TransportRequestOptions finalOptions = options;
-            // this might be called in a different thread
-            ReleaseListener releaseListener = new ReleaseListener(stream,
-                () -> messageListener.onRequestSent(node, requestId, action, request, finalOptions));
-            internalSendMessage(channel, message, releaseListener);
-            addedReleaseListener = true;
-        } finally {
-            if (!addedReleaseListener) {
-                IOUtils.close(stream);
-            }
-        }
+                                      boolean compressRequest) throws IOException, TransportException {
+        sendRequestToChannel(node, channel, requestId, action, request, options, channelVersion, compressRequest, false);
     }
 
-    /**
-     * sends a message to the given channel, using the given callbacks.
-     */
-    private void internalSendMessage(TcpChannel channel, BytesReference message, ActionListener<Void> listener) {
-        channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
-        transportLogger.logOutboundMessage(channel, message);
-        try {
-            channel.sendMessage(message, new SendListener(channel, message.length(), listener));
-        } catch (Exception ex) {
-            // call listener to ensure that any resources are released
-            listener.onFailure(ex);
-            onException(channel, ex);
-        }
+    private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
+                                      final TransportRequest request, TransportRequestOptions options, Version channelVersion,
+                                      boolean compressRequest, boolean isHandshake) throws IOException, TransportException {
+        Version version = Version.min(this.version, channelVersion);
+        OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action,
+            requestId, isHandshake, compressRequest);
+        ActionListener<Void> listener = ActionListener.wrap(() ->
+            messageListener.onRequestSent(node, requestId, action, request, options));
+        outboundHandler.sendMessage(channel, message, listener);
     }
 
     /**
@@ -769,23 +696,13 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         final Exception error,
         final long requestId,
         final String action) throws IOException {
-        try (BytesStreamOutput stream = new BytesStreamOutput()) {
-            stream.setVersion(nodeVersion);
-            stream.setFeatures(features);
-            RemoteTransportException tx = new RemoteTransportException(
-                nodeName, new TransportAddress(channel.getLocalAddress()), action, error);
-            threadPool.getThreadContext().writeTo(stream);
-            stream.writeException(tx);
-            byte status = 0;
-            status = TransportStatus.setResponse(status);
-            status = TransportStatus.setError(status);
-            final BytesReference bytes = stream.bytes();
-            final BytesReference header = buildHeader(requestId, status, nodeVersion, bytes.length());
-            CompositeBytesReference message = new CompositeBytesReference(header, bytes);
-            ReleaseListener releaseListener = new ReleaseListener(null,
-                () -> messageListener.onResponseSent(requestId, action, error));
-            internalSendMessage(channel, message, releaseListener);
-        }
+        Version version = Version.min(this.version, nodeVersion);
+        TransportAddress address = new TransportAddress(channel.getLocalAddress());
+        RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
+        OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId,
+            false, false);
+        ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
+        outboundHandler.sendMessage(channel, message, listener);
     }
 
     /**
@@ -801,7 +718,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         final long requestId,
         final String action,
         final boolean compress) throws IOException {
-        sendResponse(nodeVersion, features, channel, response, requestId, action, compress, (byte) 0);
+        sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false);
     }
 
     private void sendResponse(
@@ -812,82 +729,18 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         final long requestId,
         final String action,
         boolean compress,
-        byte status) throws IOException {
-
-        status = TransportStatus.setResponse(status);
-        ReleasableBytesStreamOutput bStream = new ReleasableBytesStreamOutput(bigArrays);
-        CompressibleBytesOutputStream stream = new CompressibleBytesOutputStream(bStream, compress);
-        boolean addedReleaseListener = false;
-        try {
-            if (compress) {
-                status = TransportStatus.setCompress(status);
-            }
-            threadPool.getThreadContext().writeTo(stream);
-            stream.setVersion(nodeVersion);
-            stream.setFeatures(features);
-            BytesReference message = buildMessage(requestId, status, nodeVersion, response, stream);
-
-            // this might be called in a different thread
-            ReleaseListener releaseListener = new ReleaseListener(stream,
-                () -> messageListener.onResponseSent(requestId, action, response));
-            internalSendMessage(channel, message, releaseListener);
-            addedReleaseListener = true;
-        } finally {
-            if (!addedReleaseListener) {
-                IOUtils.close(stream);
-            }
-        }
-    }
-
-    /**
-     * Writes the Tcp message header into a bytes reference.
-     *
-     * @param requestId       the request ID
-     * @param status          the request status
-     * @param protocolVersion the protocol version used to serialize the data in the message
-     * @param length          the payload length in bytes
-     * @see TcpHeader
-     */
-    private BytesReference buildHeader(long requestId, byte status, Version protocolVersion, int length) throws IOException {
-        try (BytesStreamOutput headerOutput = new BytesStreamOutput(TcpHeader.HEADER_SIZE)) {
-            headerOutput.setVersion(protocolVersion);
-            TcpHeader.writeHeader(headerOutput, requestId, status, protocolVersion, length);
-            final BytesReference bytes = headerOutput.bytes();
-            assert bytes.length() == TcpHeader.HEADER_SIZE : "header size mismatch expected: " + TcpHeader.HEADER_SIZE + " but was: "
-                + bytes.length();
-            return bytes;
-        }
-    }
-
-    /**
-     * Serializes the given message into a bytes representation
-     */
-    private BytesReference buildMessage(long requestId, byte status, Version nodeVersion, TransportMessage message,
-                                        CompressibleBytesOutputStream stream) throws IOException {
-        final BytesReference zeroCopyBuffer;
-        if (message instanceof BytesTransportRequest) { // what a shitty optimization - we should use a direct send method instead
-            BytesTransportRequest bRequest = (BytesTransportRequest) message;
-            assert nodeVersion.equals(bRequest.version());
-            bRequest.writeThin(stream);
-            zeroCopyBuffer = bRequest.bytes;
-        } else {
-            message.writeTo(stream);
-            zeroCopyBuffer = BytesArray.EMPTY;
-        }
-        // we have to call materializeBytes() here before accessing the bytes. A CompressibleBytesOutputStream
-        // might be implementing compression. And materializeBytes() ensures that some marker bytes (EOS marker)
-        // are written. Otherwise we barf on the decompressing end when we read past EOF on purpose in the
-        // #validateRequest method. this might be a problem in deflate after all but it's important to write
-        // the marker bytes.
-        final BytesReference messageBody = stream.materializeBytes();
-        final BytesReference header = buildHeader(requestId, status, stream.getVersion(), messageBody.length() + zeroCopyBuffer.length());
-        return new CompositeBytesReference(header, messageBody, zeroCopyBuffer);
+        boolean isHandshake) throws IOException {
+        Version version = Version.min(this.version, nodeVersion);
+        OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version,
+            requestId, isHandshake, compress);
+        ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
+        outboundHandler.sendMessage(channel, message, listener);
     }
 
     /**
      * Handles inbound message that has been decoded.
      *
-     * @param channel the channel the message if fomr
+     * @param channel the channel the message is from
      * @param message the message
      */
     public void inboundMessage(TcpChannel channel, BytesReference message) {
@@ -1055,53 +908,26 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
      * This method handles the message receive part for both request and responses
      */
     public final void messageReceived(BytesReference reference, TcpChannel channel) throws IOException {
-        String profileName = channel.getProfile();
+        readBytesMetric.inc(reference.length() + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE);
         InetSocketAddress remoteAddress = channel.getRemoteAddress();
-        int messageLengthBytes = reference.length();
-        final int totalMessageSize = messageLengthBytes + TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE;
-        readBytesMetric.inc(totalMessageSize);
-        // we have additional bytes to read, outside of the header
-        boolean hasMessageBytesToRead = (totalMessageSize - TcpHeader.HEADER_SIZE) > 0;
-        StreamInput streamIn = reference.streamInput();
-        boolean success = false;
-        try (ThreadContext.StoredContext tCtx = threadPool.getThreadContext().stashContext()) {
-            long requestId = streamIn.readLong();
-            byte status = streamIn.readByte();
-            Version version = Version.fromId(streamIn.readInt());
-            if (TransportStatus.isCompress(status) && hasMessageBytesToRead && streamIn.available() > 0) {
-                Compressor compressor;
-                try {
-                    final int bytesConsumed = TcpHeader.REQUEST_ID_SIZE + TcpHeader.STATUS_SIZE + TcpHeader.VERSION_ID_SIZE;
-                    compressor = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed));
-                } catch (NotCompressedException ex) {
-                    int maxToRead = Math.min(reference.length(), 10);
-                    StringBuilder sb = new StringBuilder("stream marked as compressed, but no compressor found, first [").append(maxToRead)
-                        .append("] content bytes out of [").append(reference.length())
-                        .append("] readable bytes with message size [").append(messageLengthBytes).append("] ").append("] are [");
-                    for (int i = 0; i < maxToRead; i++) {
-                        sb.append(reference.get(i)).append(",");
-                    }
-                    sb.append("]");
-                    throw new IllegalStateException(sb.toString());
-                }
-                streamIn = compressor.streamInput(streamIn);
-            }
-            final boolean isHandshake = TransportStatus.isHandshake(status);
-            ensureVersionCompatibility(version, this.version, isHandshake);
-            streamIn = new NamedWriteableAwareStreamInput(streamIn, namedWriteableRegistry);
-            streamIn.setVersion(version);
-            threadPool.getThreadContext().readHeaders(streamIn);
-            threadPool.getThreadContext().putTransient("_remote_address", remoteAddress);
-            if (TransportStatus.isRequest(status)) {
-                handleRequest(channel, profileName, streamIn, requestId, messageLengthBytes, version, remoteAddress, status);
+
+        ThreadContext threadContext = threadPool.getThreadContext();
+        try (ThreadContext.StoredContext existing = threadContext.stashContext();
+             InboundMessage message = reader.deserialize(reference)) {
+            // Place the context with the headers from the message
+            message.getStoredContext().restore();
+            threadContext.putTransient("_remote_address", remoteAddress);
+            if (message.isRequest()) {
+                handleRequest(channel, (InboundMessage.RequestMessage) message, reference.length());
             } else {
                 final TransportResponseHandler<?> handler;
-                if (isHandshake) {
+                long requestId = message.getRequestId();
+                if (message.isHandshake()) {
                     handler = handshaker.removeHandlerForHandshake(requestId);
                 } else {
                     TransportResponseHandler<? extends TransportResponse> theHandler =
                         responseHandlers.onResponseReceived(requestId, messageListener);
-                    if (theHandler == null && TransportStatus.isError(status)) {
+                    if (theHandler == null && message.isError()) {
                         handler = handshaker.removeHandlerForHandshake(requestId);
                     } else {
                         handler = theHandler;
@@ -1109,40 +935,20 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                 }
                 // ignore if its null, the service logs it
                 if (handler != null) {
-                    if (TransportStatus.isError(status)) {
-                        handlerResponseError(streamIn, handler);
+                    if (message.isError()) {
+                        handlerResponseError(message.getStreamInput(), handler);
                     } else {
-                        handleResponse(remoteAddress, streamIn, handler);
+                        handleResponse(remoteAddress, message.getStreamInput(), handler);
                     }
                     // Check the entire message has been read
-                    final int nextByte = streamIn.read();
+                    final int nextByte = message.getStreamInput().read();
                     // calling read() is useful to make sure the message is fully read, even if there is an EOS marker
                     if (nextByte != -1) {
                         throw new IllegalStateException("Message not fully read (response) for requestId [" + requestId + "], handler ["
-                            + handler + "], error [" + TransportStatus.isError(status) + "]; resetting");
+                            + handler + "], error [" + message.isError() + "]; resetting");
                     }
                 }
             }
-            success = true;
-        } finally {
-            if (success) {
-                IOUtils.close(streamIn);
-            } else {
-                IOUtils.closeWhileHandlingException(streamIn);
-            }
-        }
-    }
-
-    static void ensureVersionCompatibility(Version version, Version currentVersion, boolean isHandshake) {
-        // for handshakes we are compatible with N-2 since otherwise we can't figure out our initial version
-        // since we are compatible with N-1 and N+1 so we always send our minCompatVersion as the initial version in the
-        // handshake. This looks odd but it's required to establish the connection correctly we check for real compatibility
-        // once the connection is established
-        final Version compatibilityVersion = isHandshake ? currentVersion.minimumCompatibilityVersion() : currentVersion;
-        if (version.isCompatible(compatibilityVersion) == false) {
-            final Version minCompatibilityVersion = isHandshake ? compatibilityVersion : compatibilityVersion.minimumCompatibilityVersion();
-            String msg = "Received " + (isHandshake ? "handshake " : "") + "message from unsupported version: [";
-            throw new IllegalStateException(msg + version + "] minimal compatible version is: [" + minCompatibilityVersion + "]");
         }
     }
 
@@ -1198,20 +1004,17 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         });
     }
 
-    protected String handleRequest(TcpChannel channel, String profileName, final StreamInput stream, long requestId,
-                                   int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status)
-        throws IOException {
-        final Set<String> features;
-        if (version.onOrAfter(Version.V_6_3_0)) {
-            features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(stream.readStringArray())));
-        } else {
-            features = Collections.emptySet();
-        }
-        final String action = stream.readString();
+    protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage message, int messageLengthBytes) throws IOException {
+        final Set<String> features = message.getFeatures();
+        final String profileName = channel.getProfile();
+        final String action = message.getActionName();
+        final long requestId = message.getRequestId();
+        final StreamInput stream = message.getStreamInput();
+        final Version version = message.getVersion();
         messageListener.onRequestReceived(requestId, action);
         TransportChannel transportChannel = null;
         try {
-            if (TransportStatus.isHandshake(status)) {
+            if (message.isHandshake()) {
                 handshaker.handleHandshake(version, features, channel, requestId, stream);
             } else {
                 final RequestHandlerRegistry reg = getRequestHandler(action);
@@ -1224,9 +1027,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                     getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes);
                 }
                 transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, features, profileName,
-                    messageLengthBytes, TransportStatus.isCompress(status));
+                    messageLengthBytes, message.isCompress());
                 final TransportRequest request = reg.newRequest(stream);
-                request.remoteAddress(new TransportAddress(remoteAddress));
+                request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
                 // in case we throw an exception, i.e. when the limit is hit, we don't want to verify
                 validateRequest(stream, requestId, action);
                 threadPool.executor(reg.getExecutor()).execute(new RequestHandler(reg, request, transportChannel));
@@ -1235,7 +1038,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
             // the circuit breaker tripped
             if (transportChannel == null) {
                 transportChannel = new TcpTransportChannel(this, channel, transportName, action, requestId, version, features,
-                    profileName, 0, TransportStatus.isCompress(status));
+                    profileName, 0, message.isCompress());
             }
             try {
                 transportChannel.sendResponse(e);
@@ -1244,7 +1047,6 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
                 logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", action), inner);
             }
         }
-        return action;
     }
 
     // This template method is needed to inject custom error checking logic in tests.
@@ -1321,70 +1123,11 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         }
     }
 
-    /**
-     * This listener increments the transmitted bytes metric on success.
-     */
-    private class SendListener extends NotifyOnceListener<Void> {
-
-        private final TcpChannel channel;
-        private final long messageSize;
-        private final ActionListener<Void> delegateListener;
-
-        private SendListener(TcpChannel channel, long messageSize, ActionListener<Void> delegateListener) {
-            this.channel = channel;
-            this.messageSize = messageSize;
-            this.delegateListener = delegateListener;
-        }
-
-        @Override
-        protected void innerOnResponse(Void v) {
-            transmittedBytesMetric.inc(messageSize);
-            delegateListener.onResponse(v);
-        }
-
-        @Override
-        protected void innerOnFailure(Exception e) {
-            logger.warn(() -> new ParameterizedMessage("send message failed [channel: {}]", channel), e);
-            delegateListener.onFailure(e);
-        }
-    }
-
-    private class ReleaseListener implements ActionListener<Void> {
-
-        private final Closeable optionalCloseable;
-        private final Runnable transportAdaptorCallback;
-
-        private ReleaseListener(Closeable optionalCloseable, Runnable transportAdaptorCallback) {
-            this.optionalCloseable = optionalCloseable;
-            this.transportAdaptorCallback = transportAdaptorCallback;
-        }
-
-        @Override
-        public void onResponse(Void aVoid) {
-            closeAndCallback(null);
-        }
-
-        @Override
-        public void onFailure(Exception e) {
-            closeAndCallback(e);
-        }
-
-        private void closeAndCallback(final Exception e) {
-            try {
-                IOUtils.close(optionalCloseable, transportAdaptorCallback::run);
-            } catch (final IOException inner) {
-                if (e != null) {
-                    inner.addSuppressed(e);
-                }
-                throw new UncheckedIOException(inner);
-            }
-        }
-    }
-
     @Override
     public final TransportStats getStats() {
-        return new TransportStats(acceptedChannels.size(), readBytesMetric.count(), readBytesMetric.sum(), transmittedBytesMetric.count(),
-            transmittedBytesMetric.sum());
+        MeanMetric transmittedBytes = outboundHandler.getTransmittedBytes();
+        return new TransportStats(acceptedChannels.size(), readBytesMetric.count(), readBytesMetric.sum(), transmittedBytes.count(),
+            transmittedBytes.sum());
     }
 
     /**
@@ -1559,7 +1302,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         public void onTimeout() {
             if (countDown.fastForward()) {
                 CloseableChannel.closeChannels(channels, false);
-                listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout()  + "]"));
+                listener.onFailure(new ConnectTransportException(node, "connect_timeout[" + connectionProfile.getConnectTimeout() + "]"));
             }
         }
     }

+ 4 - 0
server/src/main/java/org/elasticsearch/transport/TransportLogger.java

@@ -51,6 +51,10 @@ public final class TransportLogger {
     void logOutboundMessage(TcpChannel channel, BytesReference message) {
         if (logger.isTraceEnabled()) {
             try {
+                if (message.get(0) != 'E') {
+                    // This is not an Elasticsearch transport message.
+                    return;
+                }
                 BytesReference withoutHeader = message.slice(HEADER_SIZE, message.length() - HEADER_SIZE);
                 String logMessage = format(channel, withoutHeader, "WRITE");
                 logger.trace(logMessage);

+ 5 - 2
server/src/test/java/org/elasticsearch/discovery/DiscoveryModuleTests.java

@@ -29,6 +29,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.network.NetworkService;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.discovery.zen.UnicastHostsProvider;
 import org.elasticsearch.discovery.zen.ZenDiscovery;
@@ -53,6 +54,7 @@ import java.util.function.BiConsumer;
 import java.util.function.Supplier;
 
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class DiscoveryModuleTests extends ESTestCase {
 
@@ -87,11 +89,12 @@ public class DiscoveryModuleTests extends ESTestCase {
 
     @Before
     public void setupDummyServices() {
-        transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, null, null);
+        threadPool = mock(ThreadPool.class);
+        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
+        transportService = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null);
         masterService = mock(MasterService.class);
         namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
         clusterApplier = mock(ClusterApplier.class);
-        threadPool = mock(ThreadPool.class);
         clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
         gatewayMetaState = mock(GatewayMetaState.class);
     }

+ 2 - 2
server/src/test/java/org/elasticsearch/discovery/zen/FileBasedUnicastHostsProviderTests.java

@@ -62,6 +62,7 @@ public class FileBasedUnicastHostsProviderTests extends ESTestCase {
         super.setUp();
         threadPool = new TestThreadPool(FileBasedUnicastHostsProviderTests.class.getName());
         executorService = Executors.newSingleThreadExecutor();
+        createTransportSvc();
     }
 
     @After
@@ -77,8 +78,7 @@ public class FileBasedUnicastHostsProviderTests extends ESTestCase {
         }
     }
 
-    @Before
-    public void createTransportSvc() {
+    private void createTransportSvc() {
         final MockNioTransport transport = new MockNioTransport(Settings.EMPTY, Version.CURRENT, threadPool,
             new NetworkService(Collections.emptyList()),
             PageCacheRecycler.NON_RECYCLING_INSTANCE,

+ 7 - 2
server/src/test/java/org/elasticsearch/transport/CompressibleBytesOutputStreamTests.java

@@ -39,6 +39,8 @@ public class CompressibleBytesOutputStreamTests extends ESTestCase {
         stream.write(expectedBytes);
 
         BytesReference bytesRef = stream.materializeBytes();
+        // Closing compression stream does not close underlying stream
+        stream.close();
 
         assertFalse(CompressorFactory.COMPRESSOR.isCompressed(bytesRef));
 
@@ -48,7 +50,8 @@ public class CompressibleBytesOutputStreamTests extends ESTestCase {
 
         assertEquals(-1, streamInput.read());
         assertArrayEquals(expectedBytes, actualBytes);
-        stream.close();
+
+        bStream.close();
 
         // The bytes should be zeroed out on close
         for (byte b : bytesRef.toBytesRef().bytes) {
@@ -64,6 +67,7 @@ public class CompressibleBytesOutputStreamTests extends ESTestCase {
         stream.write(expectedBytes);
 
         BytesReference bytesRef = stream.materializeBytes();
+        stream.close();
 
         assertTrue(CompressorFactory.COMPRESSOR.isCompressed(bytesRef));
 
@@ -73,7 +77,8 @@ public class CompressibleBytesOutputStreamTests extends ESTestCase {
 
         assertEquals(-1, streamInput.read());
         assertArrayEquals(expectedBytes, actualBytes);
-        stream.close();
+
+        bStream.close();
 
         // The bytes should be zeroed out on close
         for (byte b : bytesRef.toBytesRef().bytes) {

+ 228 - 0
server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java

@@ -0,0 +1,228 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.transport;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.VersionUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+
+public class InboundMessageTests extends ESTestCase {
+
+    private final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
+    private final NamedWriteableRegistry registry = new NamedWriteableRegistry(Collections.emptyList());
+
+    public void testReadRequest() throws IOException {
+        String[] features = {"feature1", "feature2"};
+        String value = randomAlphaOfLength(10);
+        Message message = new Message(value);
+        String action = randomAlphaOfLength(10);
+        long requestId = randomLong();
+        boolean isHandshake = randomBoolean();
+        boolean compress = randomBoolean();
+        threadContext.putHeader("header", "header_value");
+        Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
+        OutboundMessage.Request request = new OutboundMessage.Request(threadContext, features, message, version, action, requestId,
+            isHandshake, compress);
+        BytesReference reference;
+        try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
+            reference = request.serialize(streamOutput);
+        }
+        // Check that the thread context is not deleted.
+        assertEquals("header_value", threadContext.getHeader("header"));
+
+        threadContext.stashContext();
+        threadContext.putHeader("header", "header_value2");
+
+        InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
+        BytesReference sliced = reference.slice(6, reference.length() - 6);
+        InboundMessage.RequestMessage inboundMessage = (InboundMessage.RequestMessage) reader.deserialize(sliced);
+        // Check that deserialize does not overwrite current thread context.
+        assertEquals("header_value2", threadContext.getHeader("header"));
+        inboundMessage.getStoredContext().restore();
+        assertEquals("header_value", threadContext.getHeader("header"));
+        assertEquals(isHandshake, inboundMessage.isHandshake());
+        assertEquals(compress, inboundMessage.isCompress());
+        assertEquals(version, inboundMessage.getVersion());
+        assertEquals(action, inboundMessage.getActionName());
+        assertEquals(new HashSet<>(Arrays.asList(features)), inboundMessage.getFeatures());
+        assertTrue(inboundMessage.isRequest());
+        assertFalse(inboundMessage.isResponse());
+        assertFalse(inboundMessage.isError());
+        assertEquals(value, new Message(inboundMessage.getStreamInput()).value);
+    }
+
+    public void testReadResponse() throws IOException {
+        HashSet<String> features = new HashSet<>(Arrays.asList("feature1", "feature2"));
+        String value = randomAlphaOfLength(10);
+        Message message = new Message(value);
+        long requestId = randomLong();
+        boolean isHandshake = randomBoolean();
+        boolean compress = randomBoolean();
+        threadContext.putHeader("header", "header_value");
+        Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
+        OutboundMessage.Response request = new OutboundMessage.Response(threadContext, features, message, version, requestId, isHandshake,
+            compress);
+        BytesReference reference;
+        try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
+            reference = request.serialize(streamOutput);
+        }
+        // Check that the thread context is not deleted.
+        assertEquals("header_value", threadContext.getHeader("header"));
+
+        threadContext.stashContext();
+        threadContext.putHeader("header", "header_value2");
+
+        InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
+        BytesReference sliced = reference.slice(6, reference.length() - 6);
+        InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced);
+        // Check that deserialize does not overwrite current thread context.
+        assertEquals("header_value2", threadContext.getHeader("header"));
+        inboundMessage.getStoredContext().restore();
+        assertEquals("header_value", threadContext.getHeader("header"));
+        assertEquals(isHandshake, inboundMessage.isHandshake());
+        assertEquals(compress, inboundMessage.isCompress());
+        assertEquals(version, inboundMessage.getVersion());
+        assertTrue(inboundMessage.isResponse());
+        assertFalse(inboundMessage.isRequest());
+        assertFalse(inboundMessage.isError());
+        assertEquals(value, new Message(inboundMessage.getStreamInput()).value);
+    }
+
+    public void testReadErrorResponse() throws IOException {
+        HashSet<String> features = new HashSet<>(Arrays.asList("feature1", "feature2"));
+        RemoteTransportException exception = new RemoteTransportException("error", new IOException());
+        long requestId = randomLong();
+        boolean isHandshake = randomBoolean();
+        boolean compress = randomBoolean();
+        threadContext.putHeader("header", "header_value");
+        Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion());
+        OutboundMessage.Response request = new OutboundMessage.Response(threadContext, features, exception, version, requestId,
+            isHandshake, compress);
+        BytesReference reference;
+        try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
+            reference = request.serialize(streamOutput);
+        }
+        // Check that the thread context is not deleted.
+        assertEquals("header_value", threadContext.getHeader("header"));
+
+        threadContext.stashContext();
+        threadContext.putHeader("header", "header_value2");
+
+        InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext);
+        BytesReference sliced = reference.slice(6, reference.length() - 6);
+        InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced);
+        // Check that deserialize does not overwrite current thread context.
+        assertEquals("header_value2", threadContext.getHeader("header"));
+        inboundMessage.getStoredContext().restore();
+        assertEquals("header_value", threadContext.getHeader("header"));
+        assertEquals(isHandshake, inboundMessage.isHandshake());
+        assertEquals(compress, inboundMessage.isCompress());
+        assertEquals(version, inboundMessage.getVersion());
+        assertTrue(inboundMessage.isResponse());
+        assertFalse(inboundMessage.isRequest());
+        assertTrue(inboundMessage.isError());
+        assertEquals("[error]", inboundMessage.getStreamInput().readException().getMessage());
+    }
+
+    public void testEnsureVersionCompatibility() throws IOException {
+        testVersionIncompatibility(VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(),
+            Version.CURRENT), Version.CURRENT, randomBoolean());
+
+        final Version version = Version.fromString("7.0.0");
+        testVersionIncompatibility(Version.fromString("6.0.0"), version, true);
+        IllegalStateException ise = expectThrows(IllegalStateException.class, () ->
+            testVersionIncompatibility(Version.fromString("6.0.0"), version, false));
+        assertEquals("Received message from unsupported version: [6.0.0] minimal compatible version is: ["
+            + version.minimumCompatibilityVersion() + "]", ise.getMessage());
+
+        // For handshake we are compatible with N-2
+        testVersionIncompatibility(Version.fromString("5.6.0"), version, true);
+        ise = expectThrows(IllegalStateException.class, () ->
+            testVersionIncompatibility(Version.fromString("5.6.0"), version, false));
+        assertEquals("Received message from unsupported version: [5.6.0] minimal compatible version is: ["
+            + version.minimumCompatibilityVersion() + "]", ise.getMessage());
+
+        ise = expectThrows(IllegalStateException.class, () ->
+            testVersionIncompatibility(Version.fromString("2.3.0"), version, true));
+        assertEquals("Received handshake message from unsupported version: [2.3.0] minimal compatible version is: ["
+            + version.minimumCompatibilityVersion() + "]", ise.getMessage());
+
+        ise = expectThrows(IllegalStateException.class, () ->
+            testVersionIncompatibility(Version.fromString("2.3.0"), version, false));
+        assertEquals("Received message from unsupported version: [2.3.0] minimal compatible version is: ["
+            + version.minimumCompatibilityVersion() + "]", ise.getMessage());
+    }
+
+    private void testVersionIncompatibility(Version version, Version currentVersion, boolean isHandshake) throws IOException {
+        String[] features = {};
+        String value = randomAlphaOfLength(10);
+        Message message = new Message(value);
+        String action = randomAlphaOfLength(10);
+        long requestId = randomLong();
+        boolean compress = randomBoolean();
+        OutboundMessage.Request request = new OutboundMessage.Request(threadContext, features, message, version, action, requestId,
+            isHandshake, compress);
+        BytesReference reference;
+        try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
+            reference = request.serialize(streamOutput);
+        }
+
+        BytesReference sliced = reference.slice(6, reference.length() - 6);
+        InboundMessage.Reader reader = new InboundMessage.Reader(currentVersion, registry, threadContext);
+        reader.deserialize(sliced);
+    }
+
+    private static final class Message extends TransportMessage {
+
+        public String value;
+
+        private Message() {
+        }
+
+        private Message(StreamInput in) throws IOException {
+            readFrom(in);
+        }
+
+        private Message(String value) {
+            this.value = value;
+        }
+
+        @Override
+        public void readFrom(StreamInput in) throws IOException {
+            value = in.readString();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(value);
+        }
+    }
+}

+ 184 - 0
server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java

@@ -0,0 +1,184 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.transport;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+public class OutboundHandlerTests extends ESTestCase {
+
+    private final TestThreadPool threadPool = new TestThreadPool(getClass().getName());
+    private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
+    private OutboundHandler handler;
+    private FakeTcpChannel fakeTcpChannel;
+
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        TransportLogger transportLogger = new TransportLogger();
+        fakeTcpChannel = new FakeTcpChannel(randomBoolean());
+        handler = new OutboundHandler(threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
+        super.tearDown();
+    }
+
+    public void testSendRawBytes() {
+        BytesArray bytesArray = new BytesArray("message".getBytes(StandardCharsets.UTF_8));
+
+        AtomicBoolean isSuccess = new AtomicBoolean(false);
+        AtomicReference<Exception> exception = new AtomicReference<>();
+        ActionListener<Void> listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set);
+        handler.sendBytes(fakeTcpChannel, bytesArray, listener);
+
+        BytesReference reference = fakeTcpChannel.getMessageCaptor().get();
+        ActionListener<Void> sendListener  = fakeTcpChannel.getListenerCaptor().get();
+        if (randomBoolean()) {
+            sendListener.onResponse(null);
+            assertTrue(isSuccess.get());
+            assertNull(exception.get());
+        } else {
+            IOException e = new IOException("failed");
+            sendListener.onFailure(e);
+            assertFalse(isSuccess.get());
+            assertSame(e, exception.get());
+        }
+
+        assertEquals(bytesArray, reference);
+    }
+
+    public void testSendMessage() throws IOException {
+        OutboundMessage message;
+        ThreadContext threadContext = threadPool.getThreadContext();
+        Version version = Version.CURRENT;
+        String actionName = "handshake";
+        long requestId = randomLongBetween(0, 300);
+        boolean isHandshake = randomBoolean();
+        boolean compress = randomBoolean();
+        String value = "message";
+        threadContext.putHeader("header", "header_value");
+        Writeable writeable = new Message(value);
+
+        boolean isRequest = randomBoolean();
+        if (isRequest) {
+            message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake,
+                compress);
+        } else {
+            message = new OutboundMessage.Response(threadContext, new HashSet<>(), writeable, version, requestId, isHandshake, compress);
+        }
+
+        AtomicBoolean isSuccess = new AtomicBoolean(false);
+        AtomicReference<Exception> exception = new AtomicReference<>();
+        ActionListener<Void> listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set);
+        handler.sendMessage(fakeTcpChannel, message, listener);
+
+        BytesReference reference = fakeTcpChannel.getMessageCaptor().get();
+        ActionListener<Void> sendListener  = fakeTcpChannel.getListenerCaptor().get();
+        if (randomBoolean()) {
+            sendListener.onResponse(null);
+            assertTrue(isSuccess.get());
+            assertNull(exception.get());
+        } else {
+            IOException e = new IOException("failed");
+            sendListener.onFailure(e);
+            assertFalse(isSuccess.get());
+            assertSame(e, exception.get());
+        }
+
+        InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext());
+        try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) {
+            assertEquals(version, inboundMessage.getVersion());
+            assertEquals(requestId, inboundMessage.getRequestId());
+            if (isRequest) {
+                assertTrue(inboundMessage.isRequest());
+                assertFalse(inboundMessage.isResponse());
+            } else {
+                assertTrue(inboundMessage.isResponse());
+                assertFalse(inboundMessage.isRequest());
+            }
+            if (isHandshake) {
+                assertTrue(inboundMessage.isHandshake());
+            } else {
+                assertFalse(inboundMessage.isHandshake());
+            }
+            if (compress) {
+                assertTrue(inboundMessage.isCompress());
+            } else {
+                assertFalse(inboundMessage.isCompress());
+            }
+            Message readMessage = new Message();
+            readMessage.readFrom(inboundMessage.getStreamInput());
+            assertEquals(value, readMessage.value);
+
+            try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
+                ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext();
+                assertNull(threadContext.getHeader("header"));
+                storedContext.restore();
+                assertEquals("header_value", threadContext.getHeader("header"));
+            }
+        }
+    }
+
+    private static final class Message extends TransportMessage {
+
+        public String value;
+
+        private Message() {
+        }
+
+        private Message(String value) {
+            this.value = value;
+        }
+
+        @Override
+        public void readFrom(StreamInput in) throws IOException {
+            value = in.readString();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(value);
+        }
+    }
+}

+ 5 - 35
server/src/test/java/org/elasticsearch/transport/TcpTransportTests.java

@@ -37,7 +37,6 @@ import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 
@@ -158,41 +157,12 @@ public class TcpTransportTests extends ESTestCase {
         assertEquals(102, addresses[2].getPort());
     }
 
-    public void testEnsureVersionCompatibility() {
-        TcpTransport.ensureVersionCompatibility(VersionUtils.randomVersionBetween(random(), Version.CURRENT.minimumCompatibilityVersion(),
-            Version.CURRENT), Version.CURRENT, randomBoolean());
-
-        final Version version = Version.fromString("7.0.0");
-        TcpTransport.ensureVersionCompatibility(Version.fromString("6.0.0"), version, true);
-        IllegalStateException ise = expectThrows(IllegalStateException.class, () ->
-            TcpTransport.ensureVersionCompatibility(Version.fromString("6.0.0"), version, false));
-        assertEquals("Received message from unsupported version: [6.0.0] minimal compatible version is: ["
-            + version.minimumCompatibilityVersion() + "]", ise.getMessage());
-
-        // For handshake we are compatible with N-2
-        TcpTransport.ensureVersionCompatibility(Version.fromString("5.6.0"), version, true);
-        ise = expectThrows(IllegalStateException.class, () ->
-            TcpTransport.ensureVersionCompatibility(Version.fromString("5.6.0"), version, false));
-        assertEquals("Received message from unsupported version: [5.6.0] minimal compatible version is: ["
-                + version.minimumCompatibilityVersion() + "]", ise.getMessage());
-
-        ise = expectThrows(IllegalStateException.class, () ->
-            TcpTransport.ensureVersionCompatibility(Version.fromString("2.3.0"), version, true));
-        assertEquals("Received handshake message from unsupported version: [2.3.0] minimal compatible version is: ["
-            + version.minimumCompatibilityVersion() + "]", ise.getMessage());
-
-        ise = expectThrows(IllegalStateException.class, () ->
-            TcpTransport.ensureVersionCompatibility(Version.fromString("2.3.0"), version, false));
-        assertEquals("Received message from unsupported version: [2.3.0] minimal compatible version is: ["
-            + version.minimumCompatibilityVersion() + "]", ise.getMessage());
-    }
-
     @SuppressForbidden(reason = "Allow accessing localhost")
     public void testCompressRequestAndResponse() throws IOException {
         final boolean compressed = randomBoolean();
         Req request = new Req(randomRealisticUnicodeOfLengthBetween(10, 100));
         ThreadPool threadPool = new TestThreadPool(TcpTransportTests.class.getName());
-        AtomicReference<BytesReference> requestCaptor = new AtomicReference<>();
+        AtomicReference<BytesReference> messageCaptor = new AtomicReference<>();
         try {
             TcpTransport transport = new TcpTransport("test", Settings.EMPTY, Version.CURRENT, threadPool,
                 PageCacheRecycler.NON_RECYCLING_INSTANCE, new NoneCircuitBreakerService(), null, null) {
@@ -204,7 +174,7 @@ public class TcpTransportTests extends ESTestCase {
 
                 @Override
                 protected FakeTcpChannel initiateChannel(DiscoveryNode node) throws IOException {
-                    return new FakeTcpChannel(false, requestCaptor);
+                    return new FakeTcpChannel(false);
                 }
 
                 @Override
@@ -219,7 +189,7 @@ public class TcpTransportTests extends ESTestCase {
                     int numConnections = profile.getNumConnections();
                     ArrayList<TcpChannel> fakeChannels = new ArrayList<>(numConnections);
                     for (int i = 0; i < numConnections; ++i) {
-                        fakeChannels.add(new FakeTcpChannel(false, requestCaptor));
+                        fakeChannels.add(new FakeTcpChannel(false, messageCaptor));
                     }
                     listener.onResponse(new NodeChannels(node, fakeChannels, profile, Version.CURRENT));
                     return () -> CloseableChannel.closeChannels(fakeChannels, false);
@@ -241,12 +211,12 @@ public class TcpTransportTests extends ESTestCase {
                 (request1, channel, task) -> channel.sendResponse(TransportResponse.Empty.INSTANCE), ThreadPool.Names.SAME,
                 true, true));
 
-            BytesReference reference = requestCaptor.get();
+            BytesReference reference = messageCaptor.get();
             assertNotNull(reference);
 
             AtomicReference<BytesReference> responseCaptor = new AtomicReference<>();
             InetSocketAddress address = new InetSocketAddress(InetAddress.getLocalHost(), 0);
-            FakeTcpChannel responseChannel = new FakeTcpChannel(true, address, address, responseCaptor);
+            FakeTcpChannel responseChannel = new FakeTcpChannel(true, address, address, "profile", responseCaptor);
             transport.messageReceived(reference.slice(6, reference.length() - 6), responseChannel);
 
 

+ 7 - 4
test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java

@@ -2038,11 +2038,14 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
             new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, namedWriteableRegistry,
             new NoneCircuitBreakerService()) {
             @Override
-            protected String handleRequest(TcpChannel mockChannel, String profileName, StreamInput stream, long requestId,
-                                           int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status)
+            protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes)
                 throws IOException {
-                return super.handleRequest(mockChannel, profileName, stream, requestId, messageLengthBytes, version, remoteAddress,
-                    (byte) (status & ~(1 << 3))); // we flip the isHandshake bit back and act like the handler is not found
+                // we flip the isHandshake bit back and act like the handler is not found
+                byte status = (byte) (request.status & ~(1 << 3));
+                Version version = request.getVersion();
+                InboundMessage.RequestMessage nonHandshakeRequest = new InboundMessage.RequestMessage(request.threadContext, version,
+                    status, request.getRequestId(), request.getActionName(), request.getFeatures(), request.getStreamInput());
+                super.handleRequest(channel, nonHandshakeRequest, messageLengthBytes);
             }
         };
 

+ 12 - 6
test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java

@@ -31,9 +31,10 @@ public class FakeTcpChannel implements TcpChannel {
     private final InetSocketAddress localAddress;
     private final InetSocketAddress remoteAddress;
     private final String profile;
-    private final AtomicReference<BytesReference> messageCaptor;
     private final ChannelStats stats = new ChannelStats();
     private final CompletableContext<Void> closeContext = new CompletableContext<>();
+    private final AtomicReference<BytesReference> messageCaptor;
+    private final AtomicReference<ActionListener<Void>> listenerCaptor;
 
     public FakeTcpChannel() {
         this(false, "profile", new AtomicReference<>());
@@ -47,11 +48,6 @@ public class FakeTcpChannel implements TcpChannel {
         this(isServer, "profile", messageCaptor);
     }
 
-    public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSocketAddress remoteAddress,
-                          AtomicReference<BytesReference> messageCaptor) {
-        this(isServer, localAddress, remoteAddress,"profile", messageCaptor);
-    }
-
 
     public FakeTcpChannel(boolean isServer, String profile, AtomicReference<BytesReference> messageCaptor) {
         this(isServer, null, null, profile, messageCaptor);
@@ -64,6 +60,7 @@ public class FakeTcpChannel implements TcpChannel {
         this.remoteAddress = remoteAddress;
         this.profile = profile;
         this.messageCaptor = messageCaptor;
+        this.listenerCaptor = new AtomicReference<>();
     }
 
     @Override
@@ -89,6 +86,7 @@ public class FakeTcpChannel implements TcpChannel {
     @Override
     public void sendMessage(BytesReference reference, ActionListener<Void> listener) {
         messageCaptor.set(reference);
+        listenerCaptor.set(listener);
     }
 
     @Override
@@ -115,4 +113,12 @@ public class FakeTcpChannel implements TcpChannel {
     public ChannelStats getChannelStats() {
         return stats;
     }
+
+    public AtomicReference<BytesReference> getMessageCaptor() {
+        return messageCaptor;
+    }
+
+    public AtomicReference<ActionListener<Void>> getListenerCaptor() {
+        return listenerCaptor;
+    }
 }