Browse Source

Remove key rotation code from `TokenService` (#87744)

The key rotation code in the TokenService was never used outside of
tests. I've made a conservative pass with this PR that removes key
rotation related code but maintains BWC.

I haven't touched things like the keyCache field and the code related
to the actual use of the encryption key. AFAICT that code is still
relevant from a BWC perspective, where upgrades from old versions are
concerned. Making BWC cuts is tracked separately (#87726).

Relates #87729
Nikolaj Volgushev 3 years ago
parent
commit
2d782f28db

+ 0 - 28
x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java

@@ -13,7 +13,6 @@ import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.WriteRequest;
-import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.action.update.UpdateResponse;
 import org.elasticsearch.client.Request;
@@ -117,33 +116,6 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
         assertNotNull(userTokenFuture.actionGet());
     }
 
-    public void testTokenServiceCanRotateKeys() throws Exception {
-        OAuth2Token response = createToken(TEST_USER_NAME, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
-        String masterName = internalCluster().getMasterName();
-        TokenService masterTokenService = internalCluster().getInstance(TokenService.class, masterName);
-        String activeKeyHash = masterTokenService.getActiveKeyHash();
-        for (TokenService tokenService : internalCluster().getInstances(TokenService.class)) {
-            PlainActionFuture<UserToken> userTokenFuture = new PlainActionFuture<>();
-            tokenService.decodeToken(response.accessToken(), userTokenFuture);
-            assertNotNull(userTokenFuture.actionGet());
-            assertEquals(activeKeyHash, tokenService.getActiveKeyHash());
-        }
-        client().admin().cluster().prepareHealth().execute().get();
-        PlainActionFuture<AcknowledgedResponse> rotateActionFuture = new PlainActionFuture<>();
-        logger.info("rotate on master: {}", masterName);
-        masterTokenService.rotateKeysOnMaster(rotateActionFuture);
-        assertTrue(rotateActionFuture.actionGet().isAcknowledged());
-        assertNotEquals(activeKeyHash, masterTokenService.getActiveKeyHash());
-
-        for (TokenService tokenService : internalCluster().getInstances(TokenService.class)) {
-            PlainActionFuture<UserToken> userTokenFuture = new PlainActionFuture<>();
-            tokenService.decodeToken(response.accessToken(), userTokenFuture);
-            assertNotNull(userTokenFuture.actionGet());
-            assertNotEquals(activeKeyHash, tokenService.getActiveKeyHash());
-        }
-        assertEquals(TEST_USER_NAME, response.principal());
-    }
-
     public void testExpiredTokensDeletedAfterExpiration() throws Exception {
         final RestHighLevelClient restClient = new TestRestHighLevelClient();
         OAuth2Token response = createToken(TEST_USER_NAME, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);

+ 0 - 117
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java

@@ -33,16 +33,12 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.TransportActions;
 import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
-import org.elasticsearch.action.support.master.AcknowledgedRequest;
-import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.action.update.UpdateRequestBuilder;
 import org.elasticsearch.action.update.UpdateResponse;
 import org.elasticsearch.client.internal.Client;
-import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateUpdateTask;
-import org.elasticsearch.cluster.ack.AckedRequest;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Priority;
 import org.elasticsearch.common.Strings;
@@ -60,7 +56,6 @@ import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Setting.Property;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -122,7 +117,6 @@ import java.util.Arrays;
 import java.util.Base64;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -2321,67 +2315,6 @@ public final class TokenService {
         }
     }
 
-    /**
-     * Creates a new key unless present that is newer than the current active key and returns the corresponding metadata. Note:
-     * this method doesn't modify the metadata used in this token service. See {@link #refreshMetadata(TokenMetadata)}
-     */
-    @SuppressWarnings("unchecked")
-    synchronized TokenMetadata generateSpareKey() {
-        KeyAndCache maxKey = keyCache.cache.values().stream().max(Comparator.comparingLong(v -> v.keyAndTimestamp.getTimestamp())).get();
-        KeyAndCache currentKey = keyCache.activeKeyCache;
-        if (currentKey == maxKey) {
-            long timestamp = createdTimeStamps.incrementAndGet();
-            while (true) {
-                byte[] saltArr = new byte[SALT_BYTES];
-                secureRandom.nextBytes(saltArr);
-                SecureString tokenKey = generateTokenKey();
-                KeyAndCache keyAndCache = new KeyAndCache(new KeyAndTimestamp(tokenKey, timestamp), new BytesKey(saltArr));
-                if (keyCache.cache.containsKey(keyAndCache.getKeyHash())) {
-                    continue; // collision -- generate a new key
-                }
-                return newTokenMetadata(keyCache.currentTokenKeyHash, CollectionUtils.appendToCopy(keyCache.cache.values(), keyAndCache));
-            }
-        }
-        return newTokenMetadata(keyCache.currentTokenKeyHash, keyCache.cache.values());
-    }
-
-    /**
-     * Rotate the current active key to the spare key created in the previous {@link #generateSpareKey()} call.
-     */
-    synchronized TokenMetadata rotateToSpareKey() {
-        KeyAndCache maxKey = keyCache.cache.values().stream().max(Comparator.comparingLong(v -> v.keyAndTimestamp.getTimestamp())).get();
-        if (maxKey == keyCache.activeKeyCache) {
-            throw new IllegalStateException("call generateSpareKey first");
-        }
-        return newTokenMetadata(maxKey.getKeyHash(), keyCache.cache.values());
-    }
-
-    /**
-     * Prunes the keys and keeps up to the latest N keys around
-     *
-     * @param numKeysToKeep the number of keys to keep.
-     */
-    synchronized TokenMetadata pruneKeys(int numKeysToKeep) {
-        if (keyCache.cache.size() <= numKeysToKeep) {
-            return getTokenMetadata(); // nothing to do
-        }
-        Map<BytesKey, KeyAndCache> map = Maps.newMapWithExpectedSize(keyCache.cache.size() + 1);
-        KeyAndCache currentKey = keyCache.get(keyCache.currentTokenKeyHash);
-        ArrayList<KeyAndCache> entries = new ArrayList<>(keyCache.cache.values());
-        Collections.sort(entries, (left, right) -> Long.compare(right.keyAndTimestamp.getTimestamp(), left.keyAndTimestamp.getTimestamp()));
-        for (KeyAndCache value : entries) {
-            if (map.size() < numKeysToKeep || value.keyAndTimestamp.getTimestamp() >= currentKey.keyAndTimestamp.getTimestamp()) {
-                logger.debug("keeping key {} ", value.getKeyHash());
-                map.put(value.getKeyHash(), value);
-            } else {
-                logger.debug("prune key {} ", value.getKeyHash());
-            }
-        }
-        assert map.isEmpty() == false;
-        assert map.containsKey(keyCache.currentTokenKeyHash);
-        return newTokenMetadata(keyCache.currentTokenKeyHash, map.values());
-    }
-
     /**
      * Returns the current in-use metdata of this {@link TokenService}
      */
@@ -2441,61 +2374,11 @@ public final class TokenService {
         }
     }
 
-    synchronized String getActiveKeyHash() {
-        return new BytesRef(Base64.getUrlEncoder().withoutPadding().encode(this.keyCache.currentTokenKeyHash.bytes)).utf8ToString();
-    }
-
     @SuppressForbidden(reason = "legacy usage of unbatched task") // TODO add support for batching here
     private void submitUnbatchedTask(@SuppressWarnings("SameParameterValue") String source, ClusterStateUpdateTask task) {
         clusterService.submitUnbatchedStateUpdateTask(source, task);
     }
 
-    void rotateKeysOnMaster(ActionListener<AcknowledgedResponse> listener) {
-        logger.info("rotate keys on master");
-        TokenMetadata tokenMetadata = generateSpareKey();
-        submitUnbatchedTask(
-            "publish next key to prepare key rotation",
-            new TokenMetadataPublishAction(tokenMetadata, ActionListener.wrap((res) -> {
-                if (res.isAcknowledged()) {
-                    TokenMetadata metadata = rotateToSpareKey();
-                    submitUnbatchedTask("publish next key to prepare key rotation", new TokenMetadataPublishAction(metadata, listener));
-                } else {
-                    listener.onFailure(new IllegalStateException("not acked"));
-                }
-            }, listener::onFailure))
-        );
-    }
-
-    private static final class TokenMetadataPublishAction extends AckedClusterStateUpdateTask {
-
-        private final TokenMetadata tokenMetadata;
-
-        protected TokenMetadataPublishAction(TokenMetadata tokenMetadata, ActionListener<AcknowledgedResponse> listener) {
-            super(new AckedRequest() {
-                @Override
-                public TimeValue ackTimeout() {
-                    return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT;
-                }
-
-                @Override
-                public TimeValue masterNodeTimeout() {
-                    return AcknowledgedRequest.DEFAULT_MASTER_NODE_TIMEOUT;
-                }
-            }, listener);
-            this.tokenMetadata = tokenMetadata;
-        }
-
-        @Override
-        public ClusterState execute(ClusterState currentState) throws Exception {
-            XPackPlugin.checkReadyForXPackCustomMetadata(currentState);
-
-            if (tokenMetadata.equals(currentState.custom(TokenMetadata.TYPE))) {
-                return currentState;
-            }
-            return ClusterState.builder(currentState).putCustom(TokenMetadata.TYPE, tokenMetadata).build();
-        }
-    }
-
     private void initialize(ClusterService clusterService) {
         clusterService.addListener(event -> {
             ClusterState state = event.state();

+ 0 - 204
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java

@@ -75,7 +75,6 @@ import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.Authentication.RealmRef;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTestHelper;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationTests;
-import org.elasticsearch.xpack.core.security.authc.TokenMetadata;
 import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
 import org.elasticsearch.xpack.core.security.user.User;
 import org.elasticsearch.xpack.core.watcher.watch.ClockMock;
@@ -299,209 +298,6 @@ public class TokenServiceTests extends ESTestCase {
         }
     }
 
-    public void testRotateKey() throws Exception {
-        TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC());
-        // This test only makes sense in mixed clusters with pre v7.2.0 nodes where the Key is actually used
-        if (null == oldNode) {
-            oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_7_0_0, Version.V_7_1_0));
-        }
-        Authentication authentication = AuthenticationTestHelper.builder()
-            .user(new User("joe", "admin"))
-            .realmRef(new RealmRef("native_realm", "native", "node1"))
-            .build(false);
-        PlainActionFuture<TokenService.CreateTokenResult> tokenFuture = new PlainActionFuture<>();
-        final String userTokenId = UUIDs.randomBase64UUID();
-        final String refreshToken = UUIDs.randomBase64UUID();
-        tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture);
-        final String accessToken = tokenFuture.get().getAccessToken();
-        assertNotNull(accessToken);
-        mockGetTokenFromId(tokenService, userTokenId, authentication, false);
-
-        ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
-        storeTokenHeader(requestContext, accessToken);
-
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            tokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(authentication, serialized.getAuthentication());
-        }
-        rotateKeys(tokenService);
-
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            tokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(authentication, serialized.getAuthentication());
-        }
-
-        PlainActionFuture<TokenService.CreateTokenResult> newTokenFuture = new PlainActionFuture<>();
-        final String newUserTokenId = UUIDs.randomBase64UUID();
-        final String newRefreshToken = UUIDs.randomBase64UUID();
-        tokenService.createOAuth2Tokens(
-            newUserTokenId,
-            newRefreshToken,
-            authentication,
-            authentication,
-            Collections.emptyMap(),
-            newTokenFuture
-        );
-        final String newAccessToken = newTokenFuture.get().getAccessToken();
-        assertNotNull(newAccessToken);
-        assertNotEquals(newAccessToken, accessToken);
-
-        requestContext = new ThreadContext(Settings.EMPTY);
-        storeTokenHeader(requestContext, newAccessToken);
-        mockGetTokenFromId(tokenService, newUserTokenId, authentication, false);
-
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            tokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(authentication, serialized.getAuthentication());
-        }
-    }
-
-    private void rotateKeys(TokenService tokenService) {
-        TokenMetadata tokenMetadata = tokenService.generateSpareKey();
-        tokenService.refreshMetadata(tokenMetadata);
-        tokenMetadata = tokenService.rotateToSpareKey();
-        tokenService.refreshMetadata(tokenMetadata);
-    }
-
-    public void testKeyExchange() throws Exception {
-        TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC());
-        // This test only makes sense in mixed clusters with pre v7.2.0 nodes where the Key is actually used
-        if (null == oldNode) {
-            oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_7_0_0, Version.V_7_1_0));
-        }
-        int numRotations = randomIntBetween(1, 5);
-        for (int i = 0; i < numRotations; i++) {
-            rotateKeys(tokenService);
-        }
-        TokenService otherTokenService = createTokenService(tokenServiceEnabledSettings, systemUTC());
-        otherTokenService.refreshMetadata(tokenService.getTokenMetadata());
-        Authentication authentication = AuthenticationTestHelper.builder()
-            .user(new User("joe", "admin"))
-            .realmRef(new RealmRef("native_realm", "native", "node1"))
-            .build(false);
-        PlainActionFuture<TokenService.CreateTokenResult> tokenFuture = new PlainActionFuture<>();
-        final String userTokenId = UUIDs.randomBase64UUID();
-        final String refreshToken = UUIDs.randomBase64UUID();
-        tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture);
-        final String accessToken = tokenFuture.get().getAccessToken();
-        assertNotNull(accessToken);
-        mockGetTokenFromId(tokenService, userTokenId, authentication, false);
-
-        ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
-        storeTokenHeader(requestContext, accessToken);
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            otherTokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(serialized.getAuthentication(), authentication);
-        }
-
-        rotateKeys(tokenService);
-
-        otherTokenService.refreshMetadata(tokenService.getTokenMetadata());
-
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            otherTokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(serialized.getAuthentication(), authentication);
-        }
-    }
-
-    public void testPruneKeys() throws Exception {
-        TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC());
-        // This test only makes sense in mixed clusters with pre v7.2.0 nodes where the Key is actually used
-        if (null == oldNode) {
-            oldNode = addAnotherDataNodeWithVersion(this.clusterService, randomFrom(Version.V_7_0_0, Version.V_7_1_0));
-        }
-        Authentication authentication = AuthenticationTestHelper.builder()
-            .user(new User("joe", "admin"))
-            .realmRef(new RealmRef("native_realm", "native", "node1"))
-            .build(false);
-        PlainActionFuture<TokenService.CreateTokenResult> tokenFuture = new PlainActionFuture<>();
-        final String userTokenId = UUIDs.randomBase64UUID();
-        final String refreshToken = UUIDs.randomBase64UUID();
-        tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, Collections.emptyMap(), tokenFuture);
-        final String accessToken = tokenFuture.get().getAccessToken();
-        assertNotNull(accessToken);
-        mockGetTokenFromId(tokenService, userTokenId, authentication, false);
-
-        ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
-        storeTokenHeader(requestContext, accessToken);
-
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            tokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(authentication, serialized.getAuthentication());
-        }
-        TokenMetadata metadata = tokenService.pruneKeys(randomIntBetween(0, 100));
-        tokenService.refreshMetadata(metadata);
-
-        int numIterations = scaledRandomIntBetween(1, 5);
-        for (int i = 0; i < numIterations; i++) {
-            rotateKeys(tokenService);
-        }
-
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            tokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(authentication, serialized.getAuthentication());
-        }
-
-        PlainActionFuture<TokenService.CreateTokenResult> newTokenFuture = new PlainActionFuture<>();
-        final String newUserTokenId = UUIDs.randomBase64UUID();
-        final String newRefreshToken = UUIDs.randomBase64UUID();
-        tokenService.createOAuth2Tokens(
-            newUserTokenId,
-            newRefreshToken,
-            authentication,
-            authentication,
-            Collections.emptyMap(),
-            newTokenFuture
-        );
-        final String newAccessToken = newTokenFuture.get().getAccessToken();
-        assertNotNull(newAccessToken);
-        assertNotEquals(newAccessToken, accessToken);
-
-        metadata = tokenService.pruneKeys(1);
-        tokenService.refreshMetadata(metadata);
-
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            tokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertNull(serialized);
-        }
-
-        requestContext = new ThreadContext(Settings.EMPTY);
-        storeTokenHeader(requestContext, newAccessToken);
-        mockGetTokenFromId(tokenService, newUserTokenId, authentication, false);
-        try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
-            PlainActionFuture<UserToken> future = new PlainActionFuture<>();
-            final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
-            tokenService.tryAuthenticateToken(bearerToken, future);
-            UserToken serialized = future.get();
-            assertAuthentication(authentication, serialized.getAuthentication());
-        }
-
-    }
-
     public void testPassphraseWorks() throws Exception {
         TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC());
         // This test only makes sense in mixed clusters with pre v7.1.0 nodes where the Key is actually used