소스 검색

Validate build hash in handshake (#65732)

There is no guarantee of wire compatibility between nodes running
different builds of the same version, but today we do not validate
whether two communicating nodes are compatible or not. This results in
confusing failures that look like serialization bugs, and it usually
takes nontrivial effort to determine that the failure is in fact due to
the user running incompatible builds.

This commit adds the build hash to the transport service handshake and
validates that matching versions have matching build hashes.

Closes #65249
David Turner 4 년 전
부모
커밋
aba2f3eb33

+ 108 - 16
server/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -22,6 +22,7 @@ package org.elasticsearch.transport;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.Build;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListenerResponseHandler;
@@ -34,6 +35,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.lease.Releasable;
+import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.common.regex.Regex;
 import org.elasticsearch.common.settings.ClusterSettings;
@@ -68,10 +70,27 @@ import java.util.function.Function;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 
-public class TransportService extends AbstractLifecycleComponent implements ReportingService<TransportInfo>, TransportMessageListener,
-    TransportConnectionListener {
+public class TransportService extends AbstractLifecycleComponent
+        implements ReportingService<TransportInfo>, TransportMessageListener, TransportConnectionListener {
+
     private static final Logger logger = LogManager.getLogger(TransportService.class);
 
+    // System property which, when set to "true", lets a handshake with a same-version node of a different build hash proceed.
+    private static final String PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS_KEY = "es.unsafely_permit_handshake_from_incompatible_builds";
+    private static final boolean PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS;
+
+    static {
+        // The property must either be absent (flag is false) or parse as true (flag is true); every other value, including
+        // "false", is rejected with an exception while this class is being initialised.
+        final String configured = System.getProperty(PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS_KEY);
+        if (configured != null && Boolean.parseBoolean(configured) == false) {
+            throw new IllegalArgumentException("invalid value [" + configured + "] for system property ["
+                    + PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS_KEY + "]");
+        }
+        PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS = configured != null;
+    }
+
+
     public static final String DIRECT_RESPONSE_PROFILE = ".direct";
     public static final String HANDSHAKE_ACTION_NAME = "internal:transport/handshake";
 
@@ -182,7 +201,14 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
             false, false,
             HandshakeRequest::new,
             (request, channel, task) -> channel.sendResponse(
-                new HandshakeResponse(localNode, clusterName, localNode.getVersion())));
+                new HandshakeResponse(localNode.getVersion(), Build.CURRENT.hash(), localNode, clusterName)));
+
+        if (PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS) {
+            logger.warn("transport handshakes from incompatible builds are unsafely permitted on this node; remove system property [" +
+                    PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS_KEY + "] to resolve this warning");
+            DeprecationLogger.getLogger(TransportService.class).deprecate("permit_handshake_from_incompatible_builds",
+                "system property [" + PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS_KEY + "] is deprecated and should be removed");
+        }
     }
 
     public RemoteClusterService getRemoteClusterService() {
@@ -440,8 +466,8 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
                     public void onFailure(Exception e) {
                         listener.onFailure(e);
                     }
-                }
-                , HandshakeResponse::new, ThreadPool.Names.GENERIC
+                },
+                HandshakeResponse::new, ThreadPool.Names.GENERIC
             ));
     }
 
@@ -463,28 +489,89 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
     }
 
     public static class HandshakeResponse extends TransportResponse {
+
+        private static final Version BUILD_HASH_HANDSHAKE_VERSION = Version.V_8_0_0;
+
+        private final Version version;
+
+        @Nullable // if version < BUILD_HASH_HANDSHAKE_VERSION
+        private final String buildHash;
+
         private final DiscoveryNode discoveryNode;
+
         private final ClusterName clusterName;
-        private final Version version;
 
-        public HandshakeResponse(DiscoveryNode discoveryNode, ClusterName clusterName, Version version) {
-            this.discoveryNode = discoveryNode;
-            this.version = version;
-            this.clusterName = clusterName;
+        public HandshakeResponse(Version version, String buildHash, DiscoveryNode discoveryNode, ClusterName clusterName) {
+            this.buildHash = Objects.requireNonNull(buildHash);
+            this.discoveryNode = Objects.requireNonNull(discoveryNode);
+            this.version = Objects.requireNonNull(version);
+            this.clusterName = Objects.requireNonNull(clusterName);
         }
 
         public HandshakeResponse(StreamInput in) throws IOException {
             super(in);
-            discoveryNode = in.readOptionalWriteable(DiscoveryNode::new);
-            clusterName = new ClusterName(in);
-            version = Version.readVersion(in);
+            if (in.getVersion().onOrAfter(BUILD_HASH_HANDSHAKE_VERSION)) {
+                // the first two fields need only VInts and raw (ASCII) characters, so we cross our fingers and hope that they appear
+                // on the wire as we expect them to even if this turns out to be an incompatible build
+                version = Version.readVersion(in);
+                buildHash = in.readString();
+
+                try {
+                    // If the remote node is incompatible then make an effort to identify it anyway, so we can mention it in the exception
+                    // message, but recognise that this may fail
+                    discoveryNode = new DiscoveryNode(in);
+                } catch (Exception e) {
+                    if (isIncompatibleBuild(version, buildHash)) {
+                        throw new IllegalArgumentException("unidentifiable remote node is build [" + buildHash +
+                                "] of version [" + version + "] but this node is build [" + Build.CURRENT.hash() +
+                                "] of version [" + Version.CURRENT + "] which has an incompatible wire format", e);
+                    } else {
+                        throw e;
+                    }
+                }
+
+                if (isIncompatibleBuild(version, buildHash)) {
+                    if (PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS) {
+                        logger.warn("remote node [{}] is build [{}] of version [{}] but this node is build [{}] of version [{}] " +
+                                        "which may not be compatible; remove system property [{}] to resolve this warning",
+                                discoveryNode, buildHash, version, Build.CURRENT.hash(), Version.CURRENT,
+                                PERMIT_HANDSHAKES_FROM_INCOMPATIBLE_BUILDS_KEY);
+                    } else {
+                        throw new IllegalArgumentException("remote node [" + discoveryNode + "] is build [" + buildHash +
+                                "] of version [" + version + "] but this node is build [" + Build.CURRENT.hash() +
+                                "] of version [" + Version.CURRENT + "] which has an incompatible wire format");
+                    }
+                }
+
+                clusterName = new ClusterName(in);
+            } else {
+                discoveryNode = in.readOptionalWriteable(DiscoveryNode::new);
+                clusterName = new ClusterName(in);
+                version = Version.readVersion(in);
+                buildHash = null;
+            }
         }
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            out.writeOptionalWriteable(discoveryNode);
-            clusterName.writeTo(out);
-            Version.writeVersion(version, out);
+            if (out.getVersion().onOrAfter(BUILD_HASH_HANDSHAKE_VERSION)) {
+                Version.writeVersion(version, out);
+                out.writeString(buildHash);
+                discoveryNode.writeTo(out);
+                clusterName.writeTo(out);
+            } else {
+                out.writeOptionalWriteable(discoveryNode);
+                clusterName.writeTo(out);
+                Version.writeVersion(version, out);
+            }
+        }
+
+        /**
+         * @return the version carried by this handshake response
+         */
+        public Version getVersion() {
+            return version;
+        }
+
+        /**
+         * @return the build hash carried by this handshake response, or {@code null} if the response was read from a stream whose
+         *         version is before {@code BUILD_HASH_HANDSHAKE_VERSION} (such streams carry no build hash)
+         */
+        public String getBuildHash() {
+            return buildHash;
+        }
 
         public DiscoveryNode getDiscoveryNode() {
@@ -494,6 +581,10 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
         public ClusterName getClusterName() {
             return clusterName;
         }
+
+        /**
+         * Decides whether a remote node must be treated as an incompatible build of this node's own version.
+         *
+         * @param version   the version reported by the remote node
+         * @param buildHash the build hash reported by the remote node
+         * @return {@code true} only if {@code version} equals {@link Version#CURRENT} and {@code buildHash} differs from
+         *         {@code Build.CURRENT.hash()}; nodes reporting any other version are never reported as incompatible here
+         */
+        private static boolean isIncompatibleBuild(Version version, String buildHash) {
+            // compare by value rather than by reference so the result does not depend on the caller holding the canonical instance
+            return Version.CURRENT.equals(version) && Build.CURRENT.hash().equals(buildHash) == false;
+        }
     }
 
     public void disconnectFromNode(DiscoveryNode node) {
@@ -1293,4 +1384,5 @@ public class TransportService extends AbstractLifecycleComponent implements Repo
             }
         }
     }
+
 }

+ 2 - 1
server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java

@@ -21,6 +21,7 @@ package org.elasticsearch.cluster;
 
 import org.apache.logging.log4j.Level;
 import org.apache.logging.log4j.LogManager;
+import org.elasticsearch.Build;
 import org.elasticsearch.ElasticsearchTimeoutException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
@@ -479,7 +480,7 @@ public class NodeConnectionsServiceTests extends ESTestCase {
         @Override
         public void handshake(Transport.Connection connection, TimeValue timeout, Predicate<ClusterName> clusterNamePredicate,
                               ActionListener<HandshakeResponse> listener) {
-            listener.onResponse(new HandshakeResponse(connection.getNode(), new ClusterName(""), Version.CURRENT));
+            listener.onResponse(new HandshakeResponse(Version.CURRENT, Build.CURRENT.hash(), connection.getNode(), new ClusterName("")));
         }
 
         @Override

+ 6 - 1
server/src/test/java/org/elasticsearch/cluster/coordination/FollowersCheckerTests.java

@@ -18,6 +18,7 @@
  */
 package org.elasticsearch.cluster.coordination;
 
+import org.elasticsearch.Build;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.ClusterName;
@@ -227,7 +228,11 @@ public class FollowersCheckerTests extends ESTestCase {
             protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
                 assertFalse(node.equals(localNode));
                 if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                    handleResponse(requestId, new TransportService.HandshakeResponse(node, ClusterName.DEFAULT, Version.CURRENT));
+                    handleResponse(requestId, new TransportService.HandshakeResponse(
+                            Version.CURRENT,
+                            Build.CURRENT.hash(),
+                            node,
+                            ClusterName.DEFAULT));
                     return;
                 }
                 deterministicTaskQueue.scheduleNow(new Runnable() {

+ 11 - 2
server/src/test/java/org/elasticsearch/cluster/coordination/LeaderCheckerTests.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.cluster.coordination;
 
+import org.elasticsearch.Build;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.ClusterName;
@@ -222,7 +223,11 @@ public class LeaderCheckerTests extends ESTestCase {
             @Override
             protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
                 if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                    handleResponse(requestId, new TransportService.HandshakeResponse(node, ClusterName.DEFAULT, Version.CURRENT));
+                    handleResponse(requestId, new TransportService.HandshakeResponse(
+                            Version.CURRENT,
+                            Build.CURRENT.hash(),
+                            node,
+                            ClusterName.DEFAULT));
                     return;
                 }
                 assertThat(action, equalTo(LEADER_CHECK_ACTION_NAME));
@@ -332,7 +337,11 @@ public class LeaderCheckerTests extends ESTestCase {
             @Override
             protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
                 if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                    handleResponse(requestId, new TransportService.HandshakeResponse(node, ClusterName.DEFAULT, Version.CURRENT));
+                    handleResponse(requestId, new TransportService.HandshakeResponse(
+                            Version.CURRENT,
+                            Build.CURRENT.hash(),
+                            node,
+                            ClusterName.DEFAULT));
                     return;
                 }
                 assertThat(action, equalTo(LEADER_CHECK_ACTION_NAME));

+ 7 - 2
server/src/test/java/org/elasticsearch/cluster/coordination/NodeJoinTests.java

@@ -19,6 +19,7 @@
 package org.elasticsearch.cluster.coordination;
 
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.Build;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.ClusterName;
@@ -161,8 +162,12 @@ public class NodeJoinTests extends ESTestCase {
             @Override
             protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode destination) {
                 if (action.equals(HANDSHAKE_ACTION_NAME)) {
-                    handleResponse(requestId, new TransportService.HandshakeResponse(destination, initialState.getClusterName(),
-                        destination.getVersion()));
+                    handleResponse(requestId, new TransportService.HandshakeResponse(
+                            destination.getVersion(),
+                            Build.CURRENT.hash(),
+                            destination,
+                            initialState.getClusterName()
+                    ));
                 } else if (action.equals(JoinHelper.VALIDATE_JOIN_ACTION_NAME)) {
                     handleResponse(requestId, new TransportResponse.Empty());
                 } else {

+ 6 - 1
server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java

@@ -23,6 +23,7 @@ import org.apache.logging.log4j.Level;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.Build;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
@@ -98,7 +99,11 @@ public class HandshakingTransportAddressConnectorTests extends ESTestCase {
                     if (fullConnectionFailure != null && node.getAddress().equals(remoteNode.getAddress())) {
                         handleError(requestId, fullConnectionFailure);
                     } else {
-                        handleResponse(requestId, new HandshakeResponse(remoteNode, new ClusterName(remoteClusterName), Version.CURRENT));
+                        handleResponse(requestId, new HandshakeResponse(
+                                Version.CURRENT,
+                                Build.CURRENT.hash(),
+                                remoteNode,
+                                new ClusterName(remoteClusterName)));
                     }
                 }
             }

+ 6 - 1
server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.transport;
 
+import org.elasticsearch.Build;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.ClusterName;
@@ -58,7 +59,11 @@ public class TransportServiceDeserializationFailureTests extends ESTestCase {
             @Override
             protected void onSendRequest(long requestId, String action, TransportRequest request, DiscoveryNode node) {
                 if (action.equals(TransportService.HANDSHAKE_ACTION_NAME)) {
-                    handleResponse(requestId, new TransportService.HandshakeResponse(otherNode, new ClusterName(""), Version.CURRENT));
+                    handleResponse(requestId, new TransportService.HandshakeResponse(
+                            Version.CURRENT,
+                            Build.CURRENT.hash(),
+                            otherNode,
+                            new ClusterName("")));
                 }
             }
         };

+ 100 - 5
server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.transport;
 
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -47,6 +48,7 @@ import java.util.concurrent.TimeUnit;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.emptySet;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.instanceOf;
 
 public class TransportServiceHandshakeTests extends ESTestCase {
 
@@ -65,8 +67,9 @@ public class TransportServiceHandshakeTests extends ESTestCase {
                 new MockNioTransport(settings, Version.CURRENT, threadPool, new NetworkService(Collections.emptyList()),
                     PageCacheRecycler.NON_RECYCLING_INSTANCE, new NamedWriteableRegistry(Collections.emptyList()),
                     new NoneCircuitBreakerService());
+        final DisruptingTransportInterceptor transportInterceptor = new DisruptingTransportInterceptor();
         TransportService transportService = new MockTransportService(settings, transport, threadPool,
-            TransportService.NOOP_TRANSPORT_INTERCEPTOR, (boundAddress) -> new DiscoveryNode(
+            transportInterceptor, (boundAddress) -> new DiscoveryNode(
             nodeNameAndId,
             nodeNameAndId,
             boundAddress.publishAddress(),
@@ -76,7 +79,7 @@ public class TransportServiceHandshakeTests extends ESTestCase {
         transportService.start();
         transportService.acceptIncomingRequests();
         transportServices.add(transportService);
-        return new NetworkHandle(transportService, transportService.getLocalNode());
+        return new NetworkHandle(transportService, transportService.getLocalNode(), transportInterceptor);
     }
 
     @After
@@ -180,13 +183,105 @@ public class TransportServiceHandshakeTests extends ESTestCase {
         assertFalse(handleA.transportService.nodeConnected(discoveryNode));
     }
 
+    /**
+     * Starts two services on {@link Version#CURRENT}, makes both rewrite the build hash in the handshake responses they send, and
+     * checks that a handshake from A to B fails with the "incompatible wire format" message and leaves the nodes unconnected.
+     */
+    public void testRejectsMismatchedBuildHash() {
+        final Settings settings = Settings.builder().put("cluster.name", "a").build();
+        final NetworkHandle handleA = startServices("TS_A", settings, Version.CURRENT);
+        final NetworkHandle handleB = startServices("TS_B", settings, Version.CURRENT);
+        // describe B by its address only, with an empty node id
+        final DiscoveryNode discoveryNode = new DiscoveryNode(
+                "",
+                handleB.discoveryNode.getAddress(),
+                emptyMap(),
+                emptySet(),
+                Version.CURRENT.minimumCompatibilityVersion());
+        // from here on each service appends "-modified" to the build hash of every handshake response it sends
+        handleA.transportInterceptor.setModifyBuildHash(true);
+        handleB.transportInterceptor.setModifyBuildHash(true);
+        TransportSerializationException ex = expectThrows(TransportSerializationException.class, () -> {
+            try (Transport.Connection connection =
+                     AbstractSimpleTransportTestCase.openConnection(handleA.transportService, discoveryNode, TestProfiles.LIGHT_PROFILE)) {
+                PlainActionFuture.get(fut -> handleA.transportService.handshake(connection, timeout, fut.map(x -> null)));
+            }
+        });
+        // the IllegalArgumentException extracted from the failure must name the wire-format incompatibility
+        assertThat(
+                ExceptionsHelper.unwrap(ex, IllegalArgumentException.class).getMessage(),
+                containsString("which has an incompatible wire format"));
+        assertFalse(handleA.transportService.nodeConnected(discoveryNode));
+    }
+
+    /**
+     * Starts service A on {@link Version#CURRENT} and service B on the minimum compatible version, makes both rewrite the build
+     * hash in the handshake responses they send, and checks that A can still connect to B.
+     */
+    public void testAcceptsMismatchedBuildHashFromDifferentVersion() {
+        final NetworkHandle handleA = startServices(
+                "TS_A",
+                Settings.builder().put("cluster.name", "a").build(),
+                Version.CURRENT);
+        final NetworkHandle handleB = startServices(
+                "TS_B",
+                Settings.builder().put("cluster.name", "a").build(),
+                Version.CURRENT.minimumCompatibilityVersion());
+        // from here on each service appends "-modified" to the build hash of every handshake response it sends
+        handleA.transportInterceptor.setModifyBuildHash(true);
+        handleB.transportInterceptor.setModifyBuildHash(true);
+        AbstractSimpleTransportTestCase.connectToNode(handleA.transportService, handleB.discoveryNode, TestProfiles.LIGHT_PROFILE);
+        assertTrue(handleA.transportService.nodeConnected(handleB.discoveryNode));
+    }
+
     private static class NetworkHandle {
-        private TransportService transportService;
-        private DiscoveryNode discoveryNode;
+        final TransportService transportService;
+        final DiscoveryNode discoveryNode;
+        final DisruptingTransportInterceptor transportInterceptor;
 
-        NetworkHandle(TransportService transportService, DiscoveryNode discoveryNode) {
+        NetworkHandle(TransportService transportService, DiscoveryNode discoveryNode, DisruptingTransportInterceptor transportInterceptor) {
             this.transportService = transportService;
             this.discoveryNode = discoveryNode;
+            this.transportInterceptor = transportInterceptor;
+        }
+    }
+
+    /**
+     * Test interceptor that wraps the handler of the handshake action so that, while {@code modifyBuildHash} is set, the
+     * handshake response is re-created with "-modified" appended to its build hash before being sent. Handlers of all other
+     * actions are returned unchanged.
+     */
+    private static class DisruptingTransportInterceptor implements TransportInterceptor {
+
+        private boolean modifyBuildHash;
+
+        public void setModifyBuildHash(boolean modifyBuildHash) {
+            this.modifyBuildHash = modifyBuildHash;
+        }
+
+        @Override
+        public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(
+                String action, String executor, boolean forceExecution, TransportRequestHandler<T> actualHandler) {
+
+            if (TransportService.HANDSHAKE_ACTION_NAME.equals(action) == false) {
+                // only the handshake action is disrupted
+                return actualHandler;
+            }
+
+            return (request, channel, task) -> {
+                // delegate everything to the real channel, except that handshake responses may be rewritten on the way out
+                final TransportChannel disruptingChannel = new TransportChannel() {
+                    @Override
+                    public String getProfileName() {
+                        return channel.getProfileName();
+                    }
+
+                    @Override
+                    public String getChannelType() {
+                        return channel.getChannelType();
+                    }
+
+                    @Override
+                    public void sendResponse(TransportResponse response) throws IOException {
+                        assertThat(response, instanceOf(TransportService.HandshakeResponse.class));
+                        if (modifyBuildHash == false) {
+                            channel.sendResponse(response);
+                            return;
+                        }
+                        final TransportService.HandshakeResponse original = (TransportService.HandshakeResponse) response;
+                        final TransportService.HandshakeResponse tampered = new TransportService.HandshakeResponse(
+                                original.getVersion(),
+                                original.getBuildHash() + "-modified",
+                                original.getDiscoveryNode(),
+                                original.getClusterName());
+                        channel.sendResponse(tampered);
+                    }
+
+                    @Override
+                    public void sendResponse(Exception exception) throws IOException {
+                        channel.sendResponse(exception);
+                    }
+                };
+                actualHandler.messageReceived(request, disruptingChannel, task);
+            };
+        }
+    }