Browse Source

Add JWT cache to JWT realm. (#84842)

Justin Cranford 3 years ago
parent
commit
070dec4603

+ 15 - 0
docs/reference/settings/security-settings.asciidoc

@@ -2059,6 +2059,21 @@ If this setting is used, then the JWT realm does not perform role
 mapping and instead loads the user from the listed realms.
 See <<authorization_realms>>.
 
+`jwt.cache.ttl`::
+(<<static-cluster-setting,Static>>)
+Specifies the time-to-live for JWT cache entries.
+JWT entries will be cached for this period of time.
+JWTs can only be cached if client authentication is successful (or disabled).
+Use the standard {es} <<time-units,time units>>.
+Defaults to `20m`. Zero disables JWT cache.
+If clients use a different JWT for every request, set to 0 to disable JWT cache.
+
+`jwt.cache.size`::
+(<<static-cluster-setting,Static>>)
+Specifies the maximum number of JWT cache entries.
+Defaults to `100000`. Zero disables JWT cache.
+If clients use a different JWT for every request, set to 0 to disable JWT cache.
+
 // tag::jwt-http-connect-timeout-tag[]
 `http.connect_timeout` {ess-icon}::
 (<<static-cluster-setting,Static>>)

+ 19 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java

@@ -82,13 +82,9 @@ public class JwtRealmSettings {
     private static final TimeValue DEFAULT_ALLOWED_CLOCK_SKEW = TimeValue.timeValueSeconds(60);
     private static final List<String> DEFAULT_ALLOWED_SIGNATURE_ALGORITHMS = Collections.singletonList("RS256");
     private static final boolean DEFAULT_POPULATE_USER_METADATA = true;
-    private static final String DEFAULT_JWT_VALIDATION_CACHE_HASH_ALGO = "ssha256";
-    private static final TimeValue DEFAULT_JWT_VALIDATION_CACHE_TTL = TimeValue.timeValueMinutes(20);
-    private static final int DEFAULT_JWT_VALIDATION_CACHE_MAX_USERS = 100_000;
-    private static final int MIN_JWT_VALIDATION_CACHE_MAX_USERS = 0;
-    private static final TimeValue DEFAULT_ROLES_LOOKUP_CACHE_TTL = TimeValue.timeValueMinutes(20);
-    private static final int DEFAULT_ROLES_LOOKUP_CACHE_MAX_USERS = 100_000;
-    private static final int MIN_ROLES_LOOKUP_CACHE_MAX_USERS = 0;
+    private static final TimeValue DEFAULT_JWT_CACHE_TTL = TimeValue.timeValueMinutes(20);
+    private static final int DEFAULT_JWT_CACHE_SIZE = 100_000;
+    private static final int MIN_JWT_CACHE_SIZE = 0;
     private static final TimeValue DEFAULT_HTTP_CONNECT_TIMEOUT = TimeValue.timeValueSeconds(5);
     private static final TimeValue DEFAULT_HTTP_CONNECTION_READ_TIMEOUT = TimeValue.timeValueSeconds(5);
     private static final TimeValue DEFAULT_HTTP_SOCKET_TIMEOUT = TimeValue.timeValueSeconds(5);
@@ -140,6 +136,8 @@ public class JwtRealmSettings {
         );
         // JWT Client settings
         set.addAll(List.of(CLIENT_AUTHENTICATION_TYPE));
+        // JWT Cache settings
+        set.addAll(List.of(JWT_CACHE_TTL, JWT_CACHE_SIZE));
         // Standard HTTP settings for outgoing connections to get JWT issuer jwkset_path
         set.addAll(
             List.of(
@@ -238,6 +236,20 @@ public class JwtRealmSettings {
         "client_authentication.shared_secret"
     );
 
+    // Individual Cache settings
+
+    public static final Setting.AffixSetting<TimeValue> JWT_CACHE_TTL = Setting.affixKeySetting(
+        RealmSettings.realmSettingPrefix(TYPE),
+        "jwt.cache.ttl",
+        key -> Setting.timeSetting(key, DEFAULT_JWT_CACHE_TTL, Setting.Property.NodeScope)
+    );
+
+    public static final Setting.AffixSetting<Integer> JWT_CACHE_SIZE = Setting.affixKeySetting(
+        RealmSettings.realmSettingPrefix(TYPE),
+        "jwt.cache.size",
+        key -> Setting.intSetting(key, DEFAULT_JWT_CACHE_SIZE, MIN_JWT_CACHE_SIZE, Setting.Property.NodeScope)
+    );
+
     // Individual outgoing HTTP settings
 
     public static final Setting.AffixSetting<TimeValue> HTTP_CONNECT_TIMEOUT = Setting.affixKeySetting(

+ 13 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/support/CacheIteratorHelper.java

@@ -57,4 +57,17 @@ public class CacheIteratorHelper<K, V> {
             }
         }
     }
+
+    public void removeValuesIf(Predicate<V> removeIf) {
+        // the cache cannot be modified while doing this operation per the terms of the cache iterator
+        try (ReleasableLock ignored = this.acquireForIterator()) {
+            Iterator<V> iterator = cache.values().iterator();
+            while (iterator.hasNext()) {
+                V value = iterator.next();
+                if (removeIf.test(value)) {
+                    iterator.remove();
+                }
+            }
+        }
+    }
 }

+ 128 - 35
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java

@@ -16,8 +16,12 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.cache.Cache;
+import org.elasticsearch.common.cache.CacheBuilder;
+import org.elasticsearch.common.hash.MessageDigests;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.SettingsException;
+import org.elasticsearch.common.util.concurrent.ReleasableLock;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
@@ -30,16 +34,20 @@ import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
 import org.elasticsearch.xpack.core.security.authc.support.CachingRealm;
 import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper;
+import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper;
 import org.elasticsearch.xpack.core.security.user.User;
 import org.elasticsearch.xpack.core.ssl.SSLService;
+import org.elasticsearch.xpack.security.authc.BytesKey;
 import org.elasticsearch.xpack.security.authc.support.ClaimParser;
 import org.elasticsearch.xpack.security.authc.support.DelegatedAuthorizationSupport;
 
 import java.io.IOException;
 import java.net.URI;
 import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Date;
 import java.util.List;
 import java.util.Map;
 
@@ -50,6 +58,8 @@ import java.util.Map;
 public class JwtRealm extends Realm implements CachingRealm, Releasable {
     private static final Logger LOGGER = LogManager.getLogger(JwtRealm.class);
 
+    record ExpiringUser(User user, Date exp) {}
+
     record JwksAlgs(List<JWK> jwks, List<String> algs) {
         boolean isEmpty() {
             return jwks.isEmpty() && algs.isEmpty();
@@ -77,6 +87,8 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     final ClaimParser claimParserName;
     final JwtRealmSettings.ClientAuthenticationType clientAuthenticationType;
     final SecureString clientAuthenticationSharedSecret;
+    final Cache<BytesKey, ExpiringUser> jwtCache;
+    final CacheIteratorHelper<BytesKey, ExpiringUser> jwtCacheHelper;
     DelegatedAuthorizationSupport delegatedAuthorizationSupport = null;
 
     public JwtRealm(final RealmConfig realmConfig, final SSLService sslService, final UserRoleMapper userRoleMapper)
@@ -96,6 +108,8 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
         this.clientAuthenticationType = realmConfig.getSetting(JwtRealmSettings.CLIENT_AUTHENTICATION_TYPE);
         final SecureString sharedSecret = realmConfig.getSetting(JwtRealmSettings.CLIENT_AUTHENTICATION_SHARED_SECRET);
         this.clientAuthenticationSharedSecret = Strings.hasText(sharedSecret) ? sharedSecret : null; // convert "" to null
+        this.jwtCache = this.buildJwtCache();
+        this.jwtCacheHelper = (this.jwtCache == null) ? null : new CacheIteratorHelper<>(this.jwtCache);
 
         // Validate Client Authentication settings. Throw SettingsException there was a problem.
         JwtUtil.validateClientAuthenticationSettings(
@@ -143,6 +157,15 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
         }
     }
 
+    private Cache<BytesKey, ExpiringUser> buildJwtCache() {
+        final TimeValue jwtCacheTtl = super.config.getSetting(JwtRealmSettings.JWT_CACHE_TTL);
+        final int jwtCacheSize = super.config.getSetting(JwtRealmSettings.JWT_CACHE_SIZE);
+        if ((jwtCacheTtl.getNanos() > 0) && (jwtCacheSize > 0)) {
+            return CacheBuilder.<BytesKey, ExpiringUser>builder().setExpireAfterWrite(jwtCacheTtl).setMaximumWeight(jwtCacheSize).build();
+        }
+        return null;
+    }
+
     // must call parseAlgsAndJwksHmac() before parseAlgsAndJwksPkc()
     private JwtRealm.JwksAlgs parseJwksAlgsHmac() {
         final JwtRealm.JwksAlgs jwksAlgsHmac;
@@ -252,8 +275,19 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
         this.delegatedAuthorizationSupport = new DelegatedAuthorizationSupport(allRealms, super.config, xpackLicenseState);
     }
 
+    /**
+     * Clean up JWT cache (if enabled).
+     * Clean up HTTPS client cache (if enabled).
+     */
     @Override
     public void close() {
+        if (this.jwtCache != null) {
+            try {
+                this.jwtCache.invalidateAll();
+            } catch (Exception e) {
+                LOGGER.warn("Exception invalidating JWT cache for realm [" + super.name() + "]", e);
+            }
+        }
         if (this.httpClient != null) {
             try {
                 this.httpClient.close();
@@ -272,11 +306,21 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     @Override
     public void expire(final String username) {
         this.ensureInitialized();
+        LOGGER.trace("Expiring JWT cache entries for realm [" + super.name() + "] principal=[" + username + "]");
+        if (this.jwtCacheHelper != null) {
+            this.jwtCacheHelper.removeValuesIf(expiringUser -> expiringUser.user.principal().equals(username));
+        }
     }
 
     @Override
     public void expireAll() {
         this.ensureInitialized();
+        if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
+            LOGGER.trace("Invalidating JWT cache for realm [" + super.name() + "]");
+            try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
+                this.jwtCache.invalidateAll();
+            }
+        }
     }
 
     @Override
@@ -321,22 +365,64 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
                 return; // FAILED (secret is missing or mismatched)
             }
 
-            // Parse JWT: Extract claims for logs and role-mapping.
+            // JWT cache
             final SecureString serializedJwt = jwtAuthenticationToken.getEndUserSignedJwt();
+            final BytesKey jwtCacheKey = (this.jwtCache == null) ? null : computeBytesKey(serializedJwt);
+            if (jwtCacheKey != null) {
+                final ExpiringUser expiringUser = this.jwtCache.get(jwtCacheKey);
+                if (expiringUser == null) {
+                    LOGGER.trace("Realm [" + super.name() + "] JWT cache miss token=[" + tokenPrincipal + "] key=[" + jwtCacheKey + "].");
+                } else {
+                    final User user = expiringUser.user;
+                    final Date exp = expiringUser.exp; // claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis()
+                    final String principal = user.principal();
+                    final Date now = new Date();
+                    if (now.getTime() < exp.getTime()) {
+                        LOGGER.trace(
+                            "Realm ["
+                                + super.name()
+                                + "] JWT cache hit token=["
+                                + tokenPrincipal
+                                + "] key=["
+                                + jwtCacheKey
+                                + "] principal=["
+                                + principal
+                                + "] exp=["
+                                + exp
+                                + "] now=["
+                                + now
+                                + "]."
+                        );
+                        if (this.delegatedAuthorizationSupport.hasDelegation()) {
+                            this.delegatedAuthorizationSupport.resolve(principal, listener);
+                        } else {
+                            listener.onResponse(AuthenticationResult.success(user));
+                        }
+                        return;
+                    }
+                    LOGGER.trace(
+                        "Realm ["
+                            + super.name()
+                            + "] JWT cache exp token=["
+                            + tokenPrincipal
+                            + "] key=["
+                            + jwtCacheKey
+                            + "] principal=["
+                            + principal
+                            + "] exp=["
+                            + exp
+                            + "] now=["
+                            + now
+                            + "]."
+                    );
+                }
+            }
+
+            // Validate JWT: Extract JWT and claims set, and validate JWT.
             final SignedJWT jwt;
             final JWTClaimsSet claimsSet;
             try {
                 jwt = SignedJWT.parse(serializedJwt.toString());
-                claimsSet = jwt.getJWTClaimsSet();
-            } catch (Exception e) {
-                final String msg = "Realm [" + super.name() + "] JWT parse failed for token=[" + tokenPrincipal + "].";
-                LOGGER.debug(msg);
-                listener.onResponse(AuthenticationResult.unsuccessful(msg, e));
-                return; // FAILED (JWT parse fail or regex parse fail)
-            }
-
-            // Validate JWT
-            try {
                 final String jwtAlg = jwt.getHeader().getAlgorithm().getName();
                 final boolean isJwtAlgHmac = JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(jwtAlg);
                 final JwtRealm.JwksAlgs jwksAndAlgs = isJwtAlgHmac ? this.jwksAlgsHmac : this.jwksAlgsPkc;
@@ -348,6 +434,7 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
                     jwksAndAlgs.algs,
                     jwksAndAlgs.jwks
                 );
+                claimsSet = jwt.getJWTClaimsSet();
                 LOGGER.trace("Realm [" + super.name() + "] JWT validation succeeded for token=[" + tokenPrincipal + "].");
             } catch (Exception e) {
                 final String msg = "Realm [" + super.name() + "] JWT validation failed for token=[" + tokenPrincipal + "].";
@@ -375,23 +462,25 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
                 return;
             }
 
-            // Delegated role lookup: If enabled, lookup in authz realms. Otherwise, fall through to JWT realm role mapping.
-            if (this.delegatedAuthorizationSupport.hasDelegation()) {
-                this.delegatedAuthorizationSupport.resolve(principal, ActionListener.wrap(result -> {
-                    if (result.isAuthenticated()) {
-                        // Intercept the delegated authorization listener response to log roles. Empty roles is OK.
-                        final User user = result.getValue();
-                        final String rolesString = Arrays.toString(user.roles());
-                        LOGGER.debug(
-                            "Realm [" + super.name() + "] delegated roles [" + rolesString + "] for principal=[" + principal + "]."
-                        );
+            // Roles listener: Log roles from delegated authz lookup or role mapping, and cache User if JWT cache is enabled.
+            final ActionListener<AuthenticationResult<User>> logAndCacheListener = ActionListener.wrap(result -> {
+                if (result.isAuthenticated()) {
+                    final User user = result.getValue();
+                    final String rolesString = Arrays.toString(user.roles());
+                    LOGGER.debug("Realm [" + super.name() + "] roles [" + rolesString + "] for principal=[" + principal + "].");
+                    if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
+                        try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
+                            final long expWallClockMillis = claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis();
+                            this.jwtCache.put(jwtCacheKey, new ExpiringUser(result.getValue(), new Date(expWallClockMillis)));
+                        }
                     }
-                    listener.onResponse(result);
-                }, e -> {
-                    final String msg = "Realm [" + super.name() + "] delegated roles failed for principal=[" + principal + "].";
-                    LOGGER.warn(msg, e);
-                    listener.onResponse(AuthenticationResult.unsuccessful(msg, e));
-                }));
+                }
+                listener.onResponse(result);
+            }, listener::onFailure);
+
+            // Delegated role lookup or Role mapping: Use the above listener to log roles and cache User.
+            if (this.delegatedAuthorizationSupport.hasDelegation()) {
+                this.delegatedAuthorizationSupport.resolve(principal, logAndCacheListener);
                 return;
             }
 
@@ -415,13 +504,8 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
             final UserRoleMapper.UserData userData = new UserRoleMapper.UserData(principal, dn, groups, userMetadata, super.config);
             this.userRoleMapper.resolveRoles(userData, ActionListener.wrap(rolesSet -> {
                 final User user = new User(principal, rolesSet.toArray(Strings.EMPTY_ARRAY), name, mail, userData.getMetadata(), true);
-                LOGGER.debug("Realm [" + super.name() + "] roles " + String.join(",", rolesSet) + " for principal=[" + principal + "].");
-                listener.onResponse(AuthenticationResult.success(user));
-            }, e -> {
-                final String msg = "Realm [" + super.name() + "] roles failed for principal=[" + principal + "].";
-                LOGGER.warn(msg, e);
-                listener.onResponse(AuthenticationResult.unsuccessful(msg, e));
-            }));
+                logAndCacheListener.onResponse(AuthenticationResult.success(user));
+            }, logAndCacheListener::onFailure));
         } else {
             final String className = (authenticationToken == null) ? "null" : authenticationToken.getClass().getCanonicalName();
             final String msg = "Realm [" + super.name() + "] does not support AuthenticationToken [" + className + "].";
@@ -433,6 +517,15 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     @Override
     public void usageStats(final ActionListener<Map<String, Object>> listener) {
         this.ensureInitialized();
-        super.usageStats(ActionListener.wrap(listener::onResponse, listener::onFailure));
+        super.usageStats(ActionListener.wrap(stats -> {
+            stats.put("jwt.cache", Collections.singletonMap("size", this.jwtCache == null ? -1 : this.jwtCache.count()));
+            listener.onResponse(stats);
+        }, listener::onFailure));
+    }
+
+    static BytesKey computeBytesKey(final CharSequence charSequence) {
+        final MessageDigest messageDigest = MessageDigests.sha256();
+        messageDigest.update(charSequence.toString().getBytes(StandardCharsets.UTF_8));
+        return new BytesKey(messageDigest.digest());
     }
 }

+ 1 - 5
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java

@@ -50,7 +50,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             new MinMax(1, 3), // usersRange
             new MinMax(0, 0), // rolesRange
             new MinMax(0, 1), // jwtCacheSizeRange
-            new MinMax(0, 1), // userCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
@@ -74,7 +73,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             new MinMax(1, 3), // usersRange
             new MinMax(0, 3), // rolesRange
             new MinMax(0, 1), // jwtCacheSizeRange
-            new MinMax(0, 1), // userCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
@@ -100,7 +98,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             new MinMax(1, 3), // usersRange
             new MinMax(0, 3), // rolesRange
             new MinMax(0, 1), // jwtCacheSizeRange
-            new MinMax(0, 1), // userCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
@@ -140,7 +137,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         final JwtIssuer jwtIssuer = this.createJwtIssuer(0, 12, 1, 1, 1, true);
         assertThat(jwtIssuer.httpsServer, is(notNullValue()));
         try {
-            final JwtRealmNameAndSettingsBuilder realmNameAndSettingsBuilder = this.createJwtRealmSettings(jwtIssuer, 0);
+            final JwtRealmNameAndSettingsBuilder realmNameAndSettingsBuilder = this.createJwtRealmSettings(jwtIssuer, 0, 0);
             final String configKey = RealmSettings.getFullSettingKey(realmNameAndSettingsBuilder.name(), JwtRealmSettings.PKC_JWKSET_PATH);
             final String configValue = jwtIssuer.httpsServer.url.replace("/valid/", "/invalid");
             realmNameAndSettingsBuilder.settingsBuilder().put(configKey, configValue);
@@ -168,7 +165,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             new MinMax(1, 1), // usersRange
             new MinMax(1, 1), // rolesRange
             new MinMax(0, 1), // jwtCacheSizeRange
-            new MinMax(0, 1), // userCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();

+ 17 - 5
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java

@@ -124,7 +124,6 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         MinMax usersRange,
         MinMax rolesRange,
         MinMax jwtCacheSizeRange,
-        MinMax userCacheSizeRange,
         boolean createHttpsServer
     ) throws Exception {
         assertThat(realmsRange.min(), is(greaterThanOrEqualTo(1)));
@@ -134,7 +133,6 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         assertThat(usersRange.min(), is(greaterThanOrEqualTo(1)));
         assertThat(rolesRange.min(), is(greaterThanOrEqualTo(0)));
         assertThat(jwtCacheSizeRange.min(), is(greaterThanOrEqualTo(0)));
-        assertThat(userCacheSizeRange.min(), is(greaterThanOrEqualTo(0)));
 
         // Create JWT authc realms and mocked authz realms. Initialize each JWT realm, and test ensureInitialized() before and after.
         final int realmsCount = randomIntBetween(realmsRange.min(), realmsRange.max());
@@ -147,12 +145,15 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
             final int usersCount = randomIntBetween(usersRange.min(), usersRange.max());
             final int rolesCount = randomIntBetween(rolesRange.min(), rolesRange.max());
             final int jwtCacheSize = randomIntBetween(jwtCacheSizeRange.min(), jwtCacheSizeRange.max());
-            final int usersCacheSize = randomIntBetween(userCacheSizeRange.min(), userCacheSizeRange.max());
 
             final JwtIssuer jwtIssuer = this.createJwtIssuer(i, algsCount, audiencesCount, usersCount, rolesCount, createHttpsServer);
             // If HTTPS server was created in JWT issuer, any exception after that point requires closing it to avoid a thread pool leak
             try {
-                final JwtRealmNameAndSettingsBuilder realmNameAndSettingsBuilder = this.createJwtRealmSettings(jwtIssuer, authzCount);
+                final JwtRealmNameAndSettingsBuilder realmNameAndSettingsBuilder = this.createJwtRealmSettings(
+                    jwtIssuer,
+                    authzCount,
+                    jwtCacheSize
+                );
                 final JwtRealm jwtRealm = this.createJwtRealm(allRealms, jwtIssuer, realmNameAndSettingsBuilder);
 
                 // verify exception before initialize()
@@ -212,7 +213,8 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         return new JwtIssuer(issuer, audiences, algJwkPairsPkc, algJwkPairsHmac, algJwkPairHmacOidc, users, createHttpsServer);
     }
 
-    protected JwtRealmNameAndSettingsBuilder createJwtRealmSettings(final JwtIssuer jwtIssuer, final int authzCount) throws Exception {
+    protected JwtRealmNameAndSettingsBuilder createJwtRealmSettings(final JwtIssuer jwtIssuer, final int authzCount, final int jwtCacheSize)
+        throws Exception {
         final String authcRealmName = "realm_" + jwtIssuer.issuer;
         final String[] authzRealmNames = IntStream.range(0, authzCount).mapToObj(z -> authcRealmName + "_authz" + z).toArray(String[]::new);
 
@@ -318,6 +320,16 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                 String.join(",", authzRealmNames)
             );
         }
+
+        // JWT cache (on/off controlled by jwtCacheSize)
+        if (randomBoolean()) {
+            authcSettings.put(
+                RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_TTL),
+                randomIntBetween(10, 120) + randomFrom("s", "m", "h")
+            );
+        }
+        authcSettings.put(RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.JWT_CACHE_SIZE), jwtCacheSize);
+
         // JWT authc realm secure settings
         final MockSecureSettings secureSettings = new MockSecureSettings();
         if (Strings.hasText(jwtIssuer.encodedJwkSetHmac)) {

+ 6 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTestCase.java

@@ -169,6 +169,12 @@ public abstract class JwtTestCase extends ESTestCase {
                 RealmSettings.getFullSettingKey(name, DelegatedAuthorizationSettings.AUTHZ_REALMS.apply(JwtRealmSettings.TYPE)),
                 randomBoolean() ? "" : "authz1, authz2"
             )
+            // Cache settings
+            .put(
+                RealmSettings.getFullSettingKey(name, JwtRealmSettings.JWT_CACHE_TTL),
+                randomBoolean() ? "-1" : randomBoolean() ? "0" : randomIntBetween(10, 120) + randomFrom("s", "m", "h")
+            )
+            .put(RealmSettings.getFullSettingKey(name, JwtRealmSettings.JWT_CACHE_SIZE), randomIntBetween(0, 1))
             // HTTP settings for outgoing connections
             .put(
                 RealmSettings.getFullSettingKey(name, JwtRealmSettings.HTTP_CONNECT_TIMEOUT),