瀏覽代碼

Use TransportChannel in TransportHandshaker (#54684)

Currently the TransportHandshaker has a specialized codepath for sending
a response. In other work, we are going to start having handshakes
contribute to circuit breaking (while not being breakable). This commit
moves in that direction by allowing the handshaker to responding using a
standard TcpTransportChannel similar to other requests.
Tim Brooks 5 年之前
父節點
當前提交
2c5951ae1c

+ 6 - 3
server/src/main/java/org/elasticsearch/transport/InboundHandler.java

@@ -155,7 +155,10 @@ public class InboundHandler {
         try {
             messageListener.onRequestReceived(requestId, action);
             if (header.isHandshake()) {
-                handshaker.handleHandshake(version, channel, requestId, stream);
+                // Handshakes are not currently circuit broken
+                transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
+                    circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
+                handshaker.handleHandshake(transportChannel, requestId, stream);
             } else {
                 final RequestHandlerRegistry<T> reg = getRequestHandler(action);
                 if (reg == null) {
@@ -168,7 +171,7 @@ public class InboundHandler {
                     breaker.addWithoutBreaking(messageLengthBytes);
                 }
                 transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
-                    circuitBreakerService, messageLengthBytes, header.isCompressed());
+                    circuitBreakerService, messageLengthBytes, header.isCompressed(), header.isHandshake());
                 final T request = reg.newRequest(stream);
                 request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
                 // in case we throw an exception, i.e. when the limit is hit, we don't want to verify
@@ -184,7 +187,7 @@ public class InboundHandler {
             // the circuit breaker tripped
             if (transportChannel == null) {
                 transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
-                    circuitBreakerService, 0, header.isCompressed());
+                    circuitBreakerService, 0, header.isCompressed(), header.isHandshake());
             }
             try {
                 transportChannel.sendResponse(e);

+ 1 - 3
server/src/main/java/org/elasticsearch/transport/TcpTransport.java

@@ -144,9 +144,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
         this.handshaker = new TransportHandshaker(version, threadPool,
             (node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
                 TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
-                TransportRequestOptions.EMPTY, v, false, true),
-            (v, channel, response, requestId) -> outboundHandler.sendResponse(v, channel, requestId,
-                TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true));
+                TransportRequestOptions.EMPTY, v, false, true));
         this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
         this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, circuitBreakerService, handshaker,
             keepAlive);

+ 4 - 3
server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java

@@ -37,9 +37,10 @@ public final class TcpTransportChannel implements TransportChannel {
     private final CircuitBreakerService breakerService;
     private final long reservedBytes;
     private final boolean compressResponse;
+    private final boolean isHandshake;
 
     TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
-                        CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) {
+                        CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse, boolean isHandshake) {
         this.version = version;
         this.channel = channel;
         this.outboundHandler = outboundHandler;
@@ -48,6 +49,7 @@ public final class TcpTransportChannel implements TransportChannel {
         this.breakerService = breakerService;
         this.reservedBytes = reservedBytes;
         this.compressResponse = compressResponse;
+        this.isHandshake = isHandshake;
     }
 
     @Override
@@ -58,7 +60,7 @@ public final class TcpTransportChannel implements TransportChannel {
     @Override
     public void sendResponse(TransportResponse response) throws IOException {
         try {
-            outboundHandler.sendResponse(version, channel, requestId, action, response, compressResponse, false);
+            outboundHandler.sendResponse(version, channel, requestId, action, response, compressResponse, isHandshake);
         } finally {
             release(false);
         }
@@ -99,6 +101,5 @@ public final class TcpTransportChannel implements TransportChannel {
     public TcpChannel getChannel() {
         return channel;
     }
-
 }
 

+ 3 - 15
server/src/main/java/org/elasticsearch/transport/TransportHandshaker.java

@@ -49,14 +49,11 @@ final class TransportHandshaker {
     private final Version version;
     private final ThreadPool threadPool;
     private final HandshakeRequestSender handshakeRequestSender;
-    private final HandshakeResponseSender handshakeResponseSender;
 
-    TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
-                        HandshakeResponseSender handshakeResponseSender) {
+    TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender) {
         this.version = version;
         this.threadPool = threadPool;
         this.handshakeRequestSender = handshakeRequestSender;
-        this.handshakeResponseSender = handshakeResponseSender;
     }
 
     void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeValue timeout, ActionListener<Version> listener) {
@@ -88,7 +85,7 @@ final class TransportHandshaker {
         }
     }
 
-    void handleHandshake(Version version, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
+    void handleHandshake(TransportChannel channel, long requestId, StreamInput stream) throws IOException {
         // Must read the handshake request to exhaust the stream
         HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
         final int nextByte = stream.read();
@@ -96,8 +93,7 @@ final class TransportHandshaker {
             throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
                 + TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
         }
-        HandshakeResponse response = new HandshakeResponse(this.version);
-        handshakeResponseSender.sendResponse(version, channel, response, requestId);
+        channel.sendResponse(new HandshakeResponse(this.version));
     }
 
     TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
@@ -228,12 +224,4 @@ final class TransportHandshaker {
 
         void sendRequest(DiscoveryNode node, TcpChannel channel, long requestId, Version version) throws IOException;
     }
-
-    @FunctionalInterface
-    interface HandshakeResponseSender {
-
-        void sendResponse(Version version, TcpChannel channel, TransportResponse response, long requestId) throws IOException;
-
-    }
-
 }

+ 4 - 31
server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java

@@ -57,7 +57,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.transport.CapturingTransport;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.TransportChannel;
+import org.elasticsearch.transport.TestTransportChannel;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
 import org.junit.After;
@@ -366,14 +366,15 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
         final TransportBroadcastByNodeAction.BroadcastByNodeTransportRequestHandler handler =
                 action.new BroadcastByNodeTransportRequestHandler();
 
-        TestTransportChannel channel = new TestTransportChannel();
+        final PlainActionFuture<TransportResponse> future = PlainActionFuture.newFuture();
+        TestTransportChannel channel = new TestTransportChannel(future);
 
         handler.messageReceived(action.new NodeRequest(nodeId, new Request(), new ArrayList<>(shards)), channel, null);
 
         // check the operation was executed only on the expected shards
         assertEquals(shards, action.getResults().keySet());
 
-        TransportResponse response = channel.getCapturedResponse();
+        TransportResponse response = future.actionGet();
         assertTrue(response instanceof TransportBroadcastByNodeAction.NodeResponse);
         TransportBroadcastByNodeAction.NodeResponse nodeResponse = (TransportBroadcastByNodeAction.NodeResponse) response;
 
@@ -469,32 +470,4 @@ public class TransportBroadcastByNodeActionTests extends ESTestCase {
         assertEquals("failed shards", totalFailedShards, response.getFailedShards());
         assertEquals("accumulated exceptions", totalFailedShards, response.getShardFailures().length);
     }
-
-    public class TestTransportChannel implements TransportChannel {
-        private TransportResponse capturedResponse;
-
-        public TransportResponse getCapturedResponse() {
-            return capturedResponse;
-        }
-
-        @Override
-        public String getProfileName() {
-            return "";
-        }
-
-        @Override
-        public void sendResponse(TransportResponse response) throws IOException {
-            capturedResponse = response;
-        }
-
-        @Override
-        public void sendResponse(Exception exception) throws IOException {
-        }
-
-        @Override
-        public String getChannelType() {
-            return "test";
-        }
-
-    }
 }

+ 8 - 28
server/src/test/java/org/elasticsearch/action/support/replication/TransportReplicationActionTests.java

@@ -78,6 +78,7 @@ import org.elasticsearch.test.transport.CapturingTransport;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TestTransportChannel;
 import org.elasticsearch.transport.Transport;
 import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportException;
@@ -817,7 +818,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         Request request = new Request(shardId);
         TransportReplicationAction.ConcreteShardRequest<Request> concreteShardRequest =
             new TransportReplicationAction.ConcreteShardRequest<>(request, routingEntry.allocationId().getId(), primaryTerm);
-        PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+        PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
 
 
         final IndexShard shard = mockIndexShard(shardId, clusterService);
@@ -981,7 +982,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         setState(clusterService, state(index, true, ShardRoutingState.STARTED));
         final ShardRouting primary = clusterService.state().routingTable().shardRoutingTable(shardId).primaryShard();
         final long primaryTerm = clusterService.state().metadata().index(shardId.getIndexName()).primaryTerm(shardId.id());
-        PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+        PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
         final boolean wrongAllocationId = randomBoolean();
         final long requestTerm = wrongAllocationId && randomBoolean() ? primaryTerm : primaryTerm + randomIntBetween(1, 10);
         Request request = new Request(shardId).timeout("1ms");
@@ -1018,7 +1019,7 @@ public class TransportReplicationActionTests extends ESTestCase {
         state = ClusterState.builder(state).nodes(DiscoveryNodes.builder(state.nodes()).localNodeId(replica.currentNodeId())).build();
         setState(clusterService, state);
 
-        PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+        PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
         Request request = new Request(shardId).timeout("1ms");
         action.handleReplicaRequest(
             new TransportReplicationAction.ConcreteReplicaRequest<>(request, "_not_a_valid_aid_", randomNonNegativeLong(),
@@ -1062,7 +1063,7 @@ public class TransportReplicationActionTests extends ESTestCase {
                 return new ReplicaResult();
             }
         };
-        final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+        final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
         final Request request = new Request(shardId);
         final long checkpoint = randomNonNegativeLong();
         final long maxSeqNoOfUpdatesOrDeletes = randomNonNegativeLong();
@@ -1130,7 +1131,7 @@ public class TransportReplicationActionTests extends ESTestCase {
                 return new ReplicaResult();
             }
         };
-        final PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
+        final PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
         final Request request = new Request(shardId);
         final long checkpoint = randomNonNegativeLong();
         final long maxSeqNoOfUpdates = randomNonNegativeLong();
@@ -1371,29 +1372,8 @@ public class TransportReplicationActionTests extends ESTestCase {
     /**
      * Transport channel that is needed for replica operation testing.
      */
-    public TransportChannel createTransportChannel(final PlainActionFuture<TestResponse> listener) {
-        return new TransportChannel() {
-
-            @Override
-            public String getProfileName() {
-                return "";
-            }
-
-            @Override
-            public void sendResponse(TransportResponse response) {
-                listener.onResponse(((TestResponse) response));
-            }
-
-            @Override
-            public void sendResponse(Exception exception) {
-                listener.onFailure(exception);
-            }
-
-            @Override
-            public String getChannelType() {
-                return "replica_test";
-            }
-        };
+    public TransportChannel createTransportChannel(final PlainActionFuture<TransportResponse> listener) {
+        return new TestTransportChannel(listener);
     }
 
 }

+ 17 - 41
server/src/test/java/org/elasticsearch/cluster/coordination/NodeJoinTests.java

@@ -20,6 +20,7 @@ package org.elasticsearch.cluster.coordination;
 
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ESAllocationTestCase;
@@ -44,8 +45,8 @@ import org.elasticsearch.test.transport.CapturingTransport;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.RequestHandlerRegistry;
+import org.elasticsearch.transport.TestTransportChannel;
 import org.elasticsearch.transport.Transport;
-import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportRequest;
 import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
@@ -229,29 +230,22 @@ public class NodeJoinTests extends ESTestCase {
         try {
             final RequestHandlerRegistry<JoinRequest> joinHandler = (RequestHandlerRegistry<JoinRequest>)
                 transport.getRequestHandler(JoinHelper.JOIN_ACTION_NAME);
-            joinHandler.processMessageReceived(joinRequest, new TransportChannel() {
-                @Override
-                public String getProfileName() {
-                    return "dummy";
-                }
-
-                @Override
-                public String getChannelType() {
-                    return "dummy";
-                }
+            final ActionListener<TransportResponse> listener = new ActionListener<>() {
 
                 @Override
-                public void sendResponse(TransportResponse response) {
+                public void onResponse(TransportResponse transportResponse) {
                     logger.debug("{} completed", future);
                     future.markAsDone();
                 }
 
                 @Override
-                public void sendResponse(Exception e) {
+                public void onFailure(Exception e) {
                     logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
                     future.markAsFailed(e);
                 }
-            });
+            };
+
+            joinHandler.processMessageReceived(joinRequest, new TestTransportChannel(listener));
         } catch (Exception e) {
             logger.error(() -> new ParameterizedMessage("unexpected error for {}", future), e);
             future.markAsFailed(e);
@@ -402,27 +396,17 @@ public class NodeJoinTests extends ESTestCase {
     private void handleStartJoinFrom(DiscoveryNode node, long term) throws Exception {
         final RequestHandlerRegistry<StartJoinRequest> startJoinHandler = (RequestHandlerRegistry<StartJoinRequest>)
             transport.getRequestHandler(JoinHelper.START_JOIN_ACTION_NAME);
-        startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TransportChannel() {
+        startJoinHandler.processMessageReceived(new StartJoinRequest(node, term), new TestTransportChannel(new ActionListener<>() {
             @Override
-            public String getProfileName() {
-                return "dummy";
-            }
+            public void onResponse(TransportResponse transportResponse) {
 
-            @Override
-            public String getChannelType() {
-                return "dummy";
             }
 
             @Override
-            public void sendResponse(TransportResponse response) {
-
-            }
-
-            @Override
-            public void sendResponse(Exception exception) {
+            public void onFailure(Exception e) {
                 fail();
             }
-        });
+        }));
         deterministicTaskQueue.runAllRunnableTasks();
         assertFalse(isLocalNodeElectedMaster());
         assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.CANDIDATE));
@@ -432,27 +416,19 @@ public class NodeJoinTests extends ESTestCase {
         final RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest> followerCheckHandler =
             (RequestHandlerRegistry<FollowersChecker.FollowerCheckRequest>)
             transport.getRequestHandler(FollowersChecker.FOLLOWER_CHECK_ACTION_NAME);
-        followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), new TransportChannel() {
-            @Override
-            public String getProfileName() {
-                return "dummy";
-            }
-
-            @Override
-            public String getChannelType() {
-                return "dummy";
-            }
-
+        final TestTransportChannel channel = new TestTransportChannel(new ActionListener<>() {
             @Override
-            public void sendResponse(TransportResponse response) {
+            public void onResponse(TransportResponse transportResponse) {
 
             }
 
             @Override
-            public void sendResponse(Exception exception) {
+            public void onFailure(Exception e) {
                 fail();
             }
         });
+        followerCheckHandler.processMessageReceived(new FollowersChecker.FollowerCheckRequest(term, node), channel);
+        // Will throw exception if failed
         deterministicTaskQueue.runAllRunnableTasks();
         assertFalse(isLocalNodeElectedMaster());
         assertThat(coordinator.getMode(), equalTo(Coordinator.Mode.FOLLOWER));

+ 1 - 1
server/src/test/java/org/elasticsearch/transport/InboundHandlerTests.java

@@ -58,7 +58,7 @@ public class InboundHandlerTests extends ESTestCase {
         taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
         channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address());
         NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
-        TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {}, (v, c, r, r_id) -> { });
+        TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {});
         TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
         OutboundHandler outboundHandler = new OutboundHandler("node", version, new StatsTracker(), threadPool,
             BigArrays.NON_RECYCLING_INSTANCE);

+ 9 - 21
server/src/test/java/org/elasticsearch/transport/TransportHandshakerTests.java

@@ -27,14 +27,12 @@ import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
-import org.mockito.ArgumentCaptor;
 
 import java.io.IOException;
 import java.util.Collections;
 import java.util.concurrent.TimeUnit;
 
 import static org.hamcrest.Matchers.containsString;
-import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -46,7 +44,6 @@ public class TransportHandshakerTests extends ESTestCase {
     private TcpChannel channel;
     private TestThreadPool threadPool;
     private TransportHandshaker.HandshakeRequestSender requestSender;
-    private TransportHandshaker.HandshakeResponseSender responseSender;
 
     @Override
     public void setUp() throws Exception {
@@ -54,11 +51,10 @@ public class TransportHandshakerTests extends ESTestCase {
         String nodeId = "node-id";
         channel = mock(TcpChannel.class);
         requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
-        responseSender = mock(TransportHandshaker.HandshakeResponseSender.class);
         node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
             Collections.emptySet(), Version.CURRENT);
         threadPool = new TestThreadPool("thread-poll");
-        handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
+        handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender);
     }
 
     @Override
@@ -76,20 +72,16 @@ public class TransportHandshakerTests extends ESTestCase {
 
         assertFalse(versionFuture.isDone());
 
-        TcpChannel mockChannel = mock(TcpChannel.class);
         TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
         BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
         handshakeRequest.writeTo(bytesStreamOutput);
         StreamInput input = bytesStreamOutput.bytes().streamInput();
-        handshaker.handleHandshake(Version.CURRENT, mockChannel, reqId, input);
-
-
-        ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
-        verify(responseSender).sendResponse(eq(Version.CURRENT), eq(mockChannel), responseCaptor.capture(),
-            eq(reqId));
+        final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
+        final TestTransportChannel channel = new TestTransportChannel(responseFuture);
+        handshaker.handleHandshake(channel, reqId, input);
 
         TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
-        handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue());
+        handler.handleResponse((TransportHandshaker.HandshakeResponse) responseFuture.actionGet());
 
         assertTrue(versionFuture.isDone());
         assertEquals(Version.CURRENT, versionFuture.actionGet());
@@ -101,7 +93,6 @@ public class TransportHandshakerTests extends ESTestCase {
 
         verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());
 
-        TcpChannel mockChannel = mock(TcpChannel.class);
         TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
         BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput();
         handshakeRequest.writeTo(currentHandshakeBytes);
@@ -121,15 +112,12 @@ public class TransportHandshakerTests extends ESTestCase {
         // Otherwise, we need to update the test.
         assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length());
         assertEquals(1031, futureHandshakeStream.available());
-        handshaker.handleHandshake(Version.CURRENT, mockChannel, reqId, futureHandshakeStream);
+        final PlainActionFuture<TransportResponse> responseFuture = PlainActionFuture.newFuture();
+        final TestTransportChannel channel = new TestTransportChannel(responseFuture);
+        handshaker.handleHandshake(channel, reqId, futureHandshakeStream);
         assertEquals(0, futureHandshakeStream.available());
 
-
-        ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
-        verify(responseSender).sendResponse(eq(Version.CURRENT), eq(mockChannel), responseCaptor.capture(),
-            eq(reqId));
-
-        TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();
+        TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseFuture.actionGet();
 
         assertEquals(Version.CURRENT, response.getResponseVersion());
     }

+ 51 - 0
test/framework/src/main/java/org/elasticsearch/transport/TestTransportChannel.java

@@ -0,0 +1,51 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.transport;
+
+import org.elasticsearch.action.ActionListener;
+
+public class TestTransportChannel implements TransportChannel {
+
+    private final ActionListener<TransportResponse> listener;
+
+    public TestTransportChannel(ActionListener<TransportResponse> listener) {
+        this.listener = listener;
+    }
+
+    @Override
+    public String getProfileName() {
+        return "default";
+    }
+
+    @Override
+    public void sendResponse(TransportResponse response) {
+        listener.onResponse(response);
+    }
+
+    @Override
+    public void sendResponse(Exception exception) {
+        listener.onFailure(exception);
+    }
+
+    @Override
+    public String getChannelType() {
+        return "test";
+    }
+}