1
0
Эх сурвалжийг харах

If signature validation fails, reload JWKs and retry if new JWKs are found (#88023)

Co-authored-by: Niels Dewulf
Justin Cranford 3 жил өмнө
parent
commit
89e54be954
13 өөрчлөгдсөн 950 нэмэгдсэн , 559 устгасан
  1. 5 0
      docs/changelog/88023.yaml
  2. 375 171
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java
  3. 35 26
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java
  4. 30 52
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtValidateUtil.java
  5. 17 35
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtilTests.java
  6. 12 10
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationTokenTests.java
  7. 44 36
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java
  8. 11 2
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuerHttpsServer.java
  9. 170 9
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java
  10. 21 20
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java
  11. 146 108
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java
  12. 76 84
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTestCase.java
  13. 8 6
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtValidateUtilTests.java

+ 5 - 0
docs/changelog/88023.yaml

@@ -0,0 +1,5 @@
+pr: 88023
+summary: "If signature validation fails, reload JWKs and retry if new JWKs are found"
+area: Authentication
+type: enhancement
+issues: []

+ 375 - 171
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java

@@ -6,6 +6,7 @@
  */
 package org.elasticsearch.xpack.security.authc.jwt;
 
+import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.OctetSequenceKey;
 import com.nimbusds.jwt.JWTClaimsSet;
@@ -15,12 +16,15 @@ import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.cache.Cache;
 import org.elasticsearch.common.cache.CacheBuilder;
 import org.elasticsearch.common.hash.MessageDigests;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.SettingsException;
+import org.elasticsearch.common.util.concurrent.ListenableFuture;
 import org.elasticsearch.common.util.concurrent.ReleasableLock;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.Releasable;
@@ -37,18 +41,19 @@ import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper;
 import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper;
 import org.elasticsearch.xpack.core.security.user.User;
 import org.elasticsearch.xpack.core.ssl.SSLService;
-import org.elasticsearch.xpack.security.authc.BytesKey;
 import org.elasticsearch.xpack.security.authc.support.ClaimParser;
 import org.elasticsearch.xpack.security.authc.support.DelegatedAuthorizationSupport;
 
 import java.io.IOException;
 import java.net.URI;
 import java.nio.charset.StandardCharsets;
-import java.security.MessageDigest;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Date;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static java.lang.String.join;
 import static org.elasticsearch.core.Strings.format;
@@ -60,11 +65,34 @@ import static org.elasticsearch.core.Strings.format;
 public class JwtRealm extends Realm implements CachingRealm, Releasable {
     private static final Logger LOGGER = LogManager.getLogger(JwtRealm.class);
 
-    record ExpiringUser(User user, Date exp) {}
+    // Cached authenticated users, and adjusted JWT expiration date (=exp+skew) for checking if the JWT expired before the cache entry
+    record ExpiringUser(User user, Date exp) {
+        ExpiringUser {
+            Objects.requireNonNull(user, "User must not be null");
+            Objects.requireNonNull(exp, "Expiration date must not be null");
+        }
+    }
 
+    // Original PKC/HMAC JWKSet or HMAC JWK content (for comparison during refresh), and filtered JWKs and Algs
+    record ContentAndJwksAlgs(byte[] sha256, JwksAlgs jwksAlgs) {
+        ContentAndJwksAlgs {
+            Objects.requireNonNull(jwksAlgs, "Filters JWKs and Algs must not be null");
+        }
+
+        boolean isEmpty() {
+            return ((this.sha256 == null) || this.sha256.length == 0) && this.jwksAlgs.isEmpty();
+        }
+    }
+
+    // Filtered JWKs and Algs
     record JwksAlgs(List<JWK> jwks, List<String> algs) {
+        JwksAlgs {
+            Objects.requireNonNull(jwks, "JWKs must not be null");
+            Objects.requireNonNull(algs, "Algs must not be null");
+        }
+
         boolean isEmpty() {
-            return jwks.isEmpty() && algs.isEmpty();
+            return this.jwks.isEmpty() && this.algs.isEmpty();
         }
     }
 
@@ -78,9 +106,11 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     final String allowedIssuer;
     final List<String> allowedAudiences;
     final String jwkSetPath;
-    final CloseableHttpAsyncClient httpClient;
-    final JwtRealm.JwksAlgs jwksAlgsHmac;
-    final JwtRealm.JwksAlgs jwksAlgsPkc;
+    final boolean isConfiguredJwkSetPkc;
+    final boolean isConfiguredJwkSetHmac;
+    final boolean isConfiguredJwkOidcHmac;
+    private final CloseableHttpAsyncClient httpClient;
+    final JwkSetLoader jwkSetLoader;
     final TimeValue allowedClockSkew;
     final Boolean populateUserMetadata;
     final ClaimParser claimParserPrincipal;
@@ -90,9 +120,14 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     final ClaimParser claimParserName;
     final JwtRealmSettings.ClientAuthenticationType clientAuthenticationType;
     final SecureString clientAuthenticationSharedSecret;
-    final Cache<BytesKey, ExpiringUser> jwtCache;
-    final CacheIteratorHelper<BytesKey, ExpiringUser> jwtCacheHelper;
+    final Cache<BytesArray, ExpiringUser> jwtCache;
+    final CacheIteratorHelper<BytesArray, ExpiringUser> jwtCacheHelper;
+    final List<String> allowedJwksAlgsPkc;
+    final List<String> allowedJwksAlgsHmac;
     DelegatedAuthorizationSupport delegatedAuthorizationSupport = null;
+    ContentAndJwksAlgs contentAndJwksAlgsPkc;
+    ContentAndJwksAlgs contentAndJwksAlgsHmac;
+    final URI jwkSetPathUri;
 
     JwtRealm(
         final RealmConfig realmConfig,
@@ -127,9 +162,17 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
             this.clientAuthenticationSharedSecret
         );
 
-        if (config.hasSetting(JwtRealmSettings.HMAC_KEY) == false
-            && config.hasSetting(JwtRealmSettings.HMAC_JWKSET) == false
-            && config.hasSetting(JwtRealmSettings.PKC_JWKSET_PATH) == false) {
+        // Split configured signature algorithms by PKC and HMAC. Useful during validation, error logging, and JWK vs Alg filtering.
+        final List<String> algs = super.config.getSetting(JwtRealmSettings.ALLOWED_SIGNATURE_ALGORITHMS);
+        this.allowedJwksAlgsHmac = algs.stream().filter(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC::contains).toList();
+        this.allowedJwksAlgsPkc = algs.stream().filter(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_PKC::contains).toList();
+
+        // PKC JWKSet can be URL, file, or not set; only initialize HTTP client if PKC JWKSet is a URL.
+        this.jwkSetPath = super.config.getSetting(JwtRealmSettings.PKC_JWKSET_PATH);
+        this.isConfiguredJwkSetPkc = Strings.hasText(this.jwkSetPath);
+        this.isConfiguredJwkSetHmac = Strings.hasText(super.config.getSetting(JwtRealmSettings.HMAC_JWKSET));
+        this.isConfiguredJwkOidcHmac = Strings.hasText(super.config.getSetting(JwtRealmSettings.HMAC_KEY));
+        if (this.isConfiguredJwkSetPkc == false && this.isConfiguredJwkSetHmac == false && this.isConfiguredJwkOidcHmac == false) {
             throw new SettingsException(
                 "At least one of ["
                     + RealmSettings.getFullSettingKey(realmConfig, JwtRealmSettings.HMAC_KEY)
@@ -141,44 +184,39 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
             );
         }
 
-        // PKC JWKSet can be URL, file, or not set; only initialize HTTP client if PKC JWKSet is a URL.
-        this.jwkSetPath = super.config.getSetting(JwtRealmSettings.PKC_JWKSET_PATH);
-        if (Strings.hasText(this.jwkSetPath)) {
-            final URI jwkSetPathPkcUri = JwtUtil.parseHttpsUri(this.jwkSetPath);
-            if (jwkSetPathPkcUri == null) {
-                this.httpClient = null; // local file means no HTTP client
+        if (this.isConfiguredJwkSetPkc) {
+            final URI jwkSetPathUri = JwtUtil.parseHttpsUri(jwkSetPath);
+            if (jwkSetPathUri == null) {
+                this.jwkSetPathUri = null; // local file path
+                this.httpClient = null;
             } else {
-                this.httpClient = JwtUtil.createHttpClient(super.config, sslService);
+                this.jwkSetPathUri = jwkSetPathUri; // HTTPS URL
+                this.httpClient = JwtUtil.createHttpClient(this.config, sslService);
             }
+            this.jwkSetLoader = new JwkSetLoader(); // PKC JWKSet loader for HTTPS URL or local file path
         } else {
-            this.httpClient = null; // no setting means no HTTP client
+            this.jwkSetPathUri = null; // not configured
+            this.httpClient = null;
+            this.jwkSetLoader = null;
         }
 
-        // If HTTPS client was created in JWT realm, any exception after that point requires closing it to avoid a thread pool leak
+        // Any exception during loading requires closing JwkSetLoader's HTTP client to avoid a thread pool leak
         try {
-            this.jwksAlgsHmac = this.parseJwksAlgsHmac();
-            this.jwksAlgsPkc = this.parseJwksAlgsPkc();
+            this.contentAndJwksAlgsHmac = this.parseJwksAlgsHmac();
+            this.contentAndJwksAlgsPkc = this.parseJwksAlgsPkc();
             this.verifyAnyAvailableJwkAndAlgPair();
         } catch (Throwable t) {
+            // ASSUME: Tests or startup only. Catch and rethrow Throwable here, in case some code throws an uncaught RuntimeException.
             this.close();
             throw t;
         }
     }
 
-    private Cache<BytesKey, ExpiringUser> buildJwtCache() {
-        final TimeValue jwtCacheTtl = super.config.getSetting(JwtRealmSettings.JWT_CACHE_TTL);
-        final int jwtCacheSize = super.config.getSetting(JwtRealmSettings.JWT_CACHE_SIZE);
-        if ((jwtCacheTtl.getNanos() > 0) && (jwtCacheSize > 0)) {
-            return CacheBuilder.<BytesKey, ExpiringUser>builder().setExpireAfterWrite(jwtCacheTtl).setMaximumWeight(jwtCacheSize).build();
-        }
-        return null;
-    }
-
-    // must call parseAlgsAndJwksHmac() before parseAlgsAndJwksPkc()
-    private JwtRealm.JwksAlgs parseJwksAlgsHmac() {
+    private ContentAndJwksAlgs parseJwksAlgsHmac() {
         final JwtRealm.JwksAlgs jwksAlgsHmac;
         final SecureString hmacJwkSetContents = super.config.getSetting(JwtRealmSettings.HMAC_JWKSET);
         final SecureString hmacKeyContents = super.config.getSetting(JwtRealmSettings.HMAC_KEY);
+        byte[] hmacStringContentsSha256 = null;
         if (Strings.hasText(hmacJwkSetContents) && Strings.hasText(hmacKeyContents)) {
             // HMAC Key vs HMAC JWKSet settings must be mutually exclusive
             throw new SettingsException(
@@ -195,68 +233,52 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
             // At this point, one-and-only-one of the HMAC Key or HMAC JWKSet settings are set
             List<JWK> jwksHmac;
             if (Strings.hasText(hmacJwkSetContents)) {
+                hmacStringContentsSha256 = JwtUtil.sha256(hmacJwkSetContents.toString());
                 jwksHmac = JwkValidateUtil.loadJwksFromJwkSetString(
                     RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.HMAC_JWKSET),
                     hmacJwkSetContents.toString()
                 );
             } else {
                 final OctetSequenceKey hmacKey = JwkValidateUtil.loadHmacJwkFromJwkString(
-                    RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.HMAC_JWKSET),
+                    RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.HMAC_KEY),
                     hmacKeyContents
                 );
+                assert hmacKey != null : "Null HMAC key should not happen here";
                 jwksHmac = List.of(hmacKey);
+                hmacStringContentsSha256 = JwtUtil.sha256(hmacKeyContents.toString());
             }
+
             // Filter JWK(s) vs signature algorithms. Only keep JWKs with a matching alg. Only keep algs with a matching JWK.
-            final List<String> algs = super.config.getSetting(JwtRealmSettings.ALLOWED_SIGNATURE_ALGORITHMS);
-            final List<String> algsHmac = algs.stream().filter(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC::contains).toList();
-            jwksAlgsHmac = JwkValidateUtil.filterJwksAndAlgorithms(jwksHmac, algsHmac);
+            jwksAlgsHmac = JwkValidateUtil.filterJwksAndAlgorithms(jwksHmac, this.allowedJwksAlgsHmac);
         }
         LOGGER.info("Usable HMAC: JWKs [{}]. Algorithms [{}].", jwksAlgsHmac.jwks.size(), String.join(",", jwksAlgsHmac.algs()));
-        return jwksAlgsHmac;
+        return new ContentAndJwksAlgs(hmacStringContentsSha256, jwksAlgsHmac);
     }
 
-    private JwtRealm.JwksAlgs parseJwksAlgsPkc() {
-        final JwtRealm.JwksAlgs jwksAlgsPkc;
-        if (Strings.hasText(this.jwkSetPath) == false) {
-            jwksAlgsPkc = new JwtRealm.JwksAlgs(Collections.emptyList(), Collections.emptyList());
+    private ContentAndJwksAlgs parseJwksAlgsPkc() {
+        if (this.isConfiguredJwkSetPkc == false) {
+            return new ContentAndJwksAlgs(null, new JwksAlgs(Collections.emptyList(), Collections.emptyList()));
         } else {
-            // PKC JWKSet get contents from local file or remote HTTPS URL
-            final byte[] jwkSetContentBytesPkc;
-            if (this.httpClient == null) {
-                jwkSetContentBytesPkc = JwtUtil.readFileContents(
-                    RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.PKC_JWKSET_PATH),
-                    this.jwkSetPath,
-                    super.config.env()
-                );
-            } else {
-                final URI jwkSetPathPkcUri = JwtUtil.parseHttpsUri(this.jwkSetPath);
-                jwkSetContentBytesPkc = JwtUtil.readUriContents(
-                    RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.PKC_JWKSET_PATH),
-                    jwkSetPathPkcUri,
-                    this.httpClient
-                );
-            }
-            final String jwkSetContentsPkc = new String(jwkSetContentBytesPkc, StandardCharsets.UTF_8);
-
-            // PKC JWKSet parse contents
-            final List<JWK> jwksPkc = JwkValidateUtil.loadJwksFromJwkSetString(
-                RealmSettings.getFullSettingKey(super.config, JwtRealmSettings.PKC_JWKSET_PATH),
-                jwkSetContentsPkc
-            );
+            // ASSUME: Blocking read operations are OK during startup
+            final PlainActionFuture<ContentAndJwksAlgs> future = new PlainActionFuture<>();
+            this.jwkSetLoader.load(future);
+            return future.actionGet();
+        }
+    }
 
-            // PKC JWKSet filter contents
-            final List<String> algs = super.config.getSetting(JwtRealmSettings.ALLOWED_SIGNATURE_ALGORITHMS);
-            final List<String> algsPkc = algs.stream().filter(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_PKC::contains).toList();
-            jwksAlgsPkc = JwkValidateUtil.filterJwksAndAlgorithms(jwksPkc, algsPkc);
+    private Cache<BytesArray, ExpiringUser> buildJwtCache() {
+        final TimeValue jwtCacheTtl = super.config.getSetting(JwtRealmSettings.JWT_CACHE_TTL);
+        final int jwtCacheSize = super.config.getSetting(JwtRealmSettings.JWT_CACHE_SIZE);
+        if ((jwtCacheTtl.getNanos() > 0) && (jwtCacheSize > 0)) {
+            return CacheBuilder.<BytesArray, ExpiringUser>builder().setExpireAfterWrite(jwtCacheTtl).setMaximumWeight(jwtCacheSize).build();
         }
-        LOGGER.info("Usable PKC: JWKs [{}]. Algorithms [{}].", jwksAlgsPkc.jwks().size(), String.join(",", jwksAlgsPkc.algs()));
-        return jwksAlgsPkc;
+        return null;
     }
 
     private void verifyAnyAvailableJwkAndAlgPair() {
-        assert this.jwksAlgsHmac != null : "HMAC not initialized";
-        assert this.jwksAlgsPkc != null : "PKC not initialized";
-        if (this.jwksAlgsHmac.isEmpty() && this.jwksAlgsPkc.isEmpty()) {
+        assert this.contentAndJwksAlgsHmac != null : "HMAC not initialized";
+        assert this.contentAndJwksAlgsPkc != null : "PKC not initialized";
+        if (this.contentAndJwksAlgsHmac.jwksAlgs.isEmpty() && this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) {
             final String msg = "No available JWK and algorithm for HMAC or PKC. Realm authentication expected to fail until this is fixed.";
             throw new SettingsException(msg);
         }
@@ -289,13 +311,31 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
      */
     @Override
     public void close() {
-        if (this.jwtCache != null) {
+        this.invalidateJwtCache();
+        this.closeHttpClient();
+    }
+
+    /**
+     * Clean up JWT cache (if enabled).
+     */
+    private void invalidateJwtCache() {
+        if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
             try {
-                this.jwtCache.invalidateAll();
+                LOGGER.trace("Invalidating JWT cache for realm [{}]", super.name());
+                try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
+                    this.jwtCache.invalidateAll();
+                }
+                LOGGER.debug("Invalidated JWT cache for realm [{}]", super.name());
             } catch (Exception e) {
                 LOGGER.warn("Exception invalidating JWT cache for realm [" + super.name() + "]", e);
             }
         }
+    }
+
+    /**
+     * Clean up HTTPS client cache (if enabled).
+     */
+    private void closeHttpClient() {
         if (this.httpClient != null) {
             try {
                 this.httpClient.close();
@@ -314,21 +354,17 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     @Override
     public void expire(final String username) {
         this.ensureInitialized();
-        LOGGER.trace("Expiring JWT cache entries for realm [" + super.name() + "] principal=[" + username + "]");
         if (this.jwtCacheHelper != null) {
+            LOGGER.trace("Expiring JWT cache entries for realm [{}] principal=[{}]", super.name(), username);
             this.jwtCacheHelper.removeValuesIf(expiringUser -> expiringUser.user.principal().equals(username));
+            LOGGER.trace("Expired JWT cache entries for realm [{}] principal=[{}]", super.name(), username);
         }
     }
 
     @Override
     public void expireAll() {
         this.ensureInitialized();
-        if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
-            LOGGER.trace("Invalidating JWT cache for realm [" + super.name() + "]");
-            try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
-                this.jwtCache.invalidateAll();
-            }
-        }
+        this.invalidateJwtCache();
     }
 
     @Override
@@ -365,7 +401,7 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
 
             // JWT cache
             final SecureString serializedJwt = jwtAuthenticationToken.getEndUserSignedJwt();
-            final BytesKey jwtCacheKey = (this.jwtCache == null) ? null : computeBytesKey(serializedJwt);
+            final BytesArray jwtCacheKey = (this.jwtCache == null) ? null : new BytesArray(JwtUtil.sha256(serializedJwt));
             if (jwtCacheKey != null) {
                 final ExpiringUser expiringUser = this.jwtCache.get(jwtCacheKey);
                 if (expiringUser == null) {
@@ -417,99 +453,201 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
             }
 
             // Validate JWT: Extract JWT and claims set, and validate JWT.
-            final SignedJWT jwt;
-            final JWTClaimsSet claimsSet;
-            try {
-                jwt = SignedJWT.parse(serializedJwt.toString());
-                final String jwtAlg = jwt.getHeader().getAlgorithm().getName();
-                final boolean isJwtAlgHmac = JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(jwtAlg);
-                final JwtRealm.JwksAlgs jwksAndAlgs = isJwtAlgHmac ? this.jwksAlgsHmac : this.jwksAlgsPkc;
-                JwtValidateUtil.validate(
-                    jwt,
-                    this.allowedIssuer,
-                    this.allowedAudiences,
-                    this.allowedClockSkew.seconds(),
-                    jwksAndAlgs.algs,
-                    jwksAndAlgs.jwks
+            validateJwt(
+                serializedJwt,
+                tokenPrincipal,
+                ActionListener.wrap(claimsSet -> processValidatedJwt(tokenPrincipal, jwtCacheKey, claimsSet, listener), ex -> {
+                    final String msg = "Realm [" + super.name() + "] JWT validation failed for token=[" + tokenPrincipal + "].";
+                    LOGGER.debug(msg, ex);
+                    listener.onResponse(AuthenticationResult.unsuccessful(msg, ex));
+                })
+            );
+        } else {
+            final String className = (authenticationToken == null) ? "null" : authenticationToken.getClass().getCanonicalName();
+            final String msg = "Realm [" + super.name() + "] does not support AuthenticationToken [" + className + "].";
+            LOGGER.trace(msg);
+            listener.onResponse(AuthenticationResult.unsuccessful(msg, null));
+        }
+    }
+
+    private void validateJwt(SecureString serializedJwt, String tokenPrincipal, ActionListener<JWTClaimsSet> listener) {
+        final SignedJWT jwt;
+        final JWSHeader header;
+        final JWTClaimsSet claimsSet;
+        final String alg;
+        try {
+            jwt = SignedJWT.parse(serializedJwt.toString());
+            header = jwt.getHeader();
+            alg = header.getAlgorithm().getName();
+            claimsSet = jwt.getJWTClaimsSet();
+            final Date now = new Date();
+            if (LOGGER.isDebugEnabled()) {
+                LOGGER.debug(
+                    "Realm [{}] JWT parse succeeded for token=[{}]."
+                        + "Validating JWT, now [{}], alg [{}], issuer [{}], audiences [{}], kty [{}],"
+                        + " auth_time [{}], iat [{}], nbf [{}], exp [{}], kid [{}], jti [{}]",
+                    super.name(),
+                    tokenPrincipal,
+                    now,
+                    alg,
+                    claimsSet.getIssuer(),
+                    claimsSet.getAudience(),
+                    header.getType(),
+                    claimsSet.getDateClaim("auth_time"),
+                    claimsSet.getIssueTime(),
+                    claimsSet.getNotBeforeTime(),
+                    claimsSet.getExpirationTime(),
+                    header.getKeyID(),
+                    claimsSet.getJWTID()
                 );
-                claimsSet = jwt.getJWTClaimsSet();
-                LOGGER.trace("Realm [{}] JWT validation succeeded for token=[{}].", super.name(), tokenPrincipal);
-            } catch (Exception e) {
-                final String msg = "Realm [" + super.name() + "] JWT validation failed for token=[" + tokenPrincipal + "].";
-                final AuthenticationResult<User> failure = AuthenticationResult.unsuccessful(msg, e);
-                LOGGER.debug(msg, e);
-                listener.onResponse(failure);
-                return;
             }
+            // Validate all else before signature, because these checks are more helpful diagnostics than rejected signatures.
+            final boolean isJwtAlgHmac = JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(alg);
+            JwtValidateUtil.validateType(jwt);
+            JwtValidateUtil.validateIssuer(jwt, allowedIssuer);
+            JwtValidateUtil.validateAudiences(jwt, allowedAudiences);
+            JwtValidateUtil.validateSignatureAlgorithm(jwt, isJwtAlgHmac ? this.allowedJwksAlgsHmac : this.allowedJwksAlgsPkc);
+            JwtValidateUtil.validateAuthTime(jwt, now, this.allowedClockSkew.seconds());
+            JwtValidateUtil.validateIssuedAtTime(jwt, now, this.allowedClockSkew.seconds());
+            JwtValidateUtil.validateNotBeforeTime(jwt, now, this.allowedClockSkew.seconds());
+            JwtValidateUtil.validateExpiredTime(jwt, now, this.allowedClockSkew.seconds());
+
+            // At this point, client authc and JWT kty+alg+iss+aud+time filters passed. Do sig last, in case JWK reload is expensive.
+            validateSignature(jwt, isJwtAlgHmac, tokenPrincipal, listener.map(ignored -> claimsSet));
+
+        } catch (Exception e) {
+            listener.onFailure(e);
+        }
+    }
 
-            // At this point, JWT is validated. Parse the JWT claims using realm settings.
-
-            final String principal = this.claimParserPrincipal.getClaimValue(claimsSet);
-            if (Strings.hasText(principal) == false) {
-                final String msg = "Realm ["
-                    + super.name()
-                    + "] no principal for token=["
-                    + tokenPrincipal
-                    + "] parser=["
-                    + this.claimParserPrincipal
-                    + "] claims=["
-                    + claimsSet
-                    + "].";
-                LOGGER.debug(msg);
-                listener.onResponse(AuthenticationResult.unsuccessful(msg, null));
-                return;
-            }
+    private void processValidatedJwt(
+        String tokenPrincipal,
+        BytesArray jwtCacheKey,
+        JWTClaimsSet claimsSet,
+        ActionListener<AuthenticationResult<User>> listener
+    ) {
+        // At this point, JWT is validated. Parse the JWT claims using realm settings.
+        final String principal = this.claimParserPrincipal.getClaimValue(claimsSet);
+        if (Strings.hasText(principal) == false) {
+            final String msg = "Realm ["
+                + super.name()
+                + "] no principal for token=["
+                + tokenPrincipal
+                + "] parser=["
+                + this.claimParserPrincipal
+                + "] claims=["
+                + claimsSet
+                + "].";
+            LOGGER.debug(msg);
+            listener.onResponse(AuthenticationResult.unsuccessful(msg, null));
+            return;
+        }
 
-            // Roles listener: Log roles from delegated authz lookup or role mapping, and cache User if JWT cache is enabled.
-            final ActionListener<AuthenticationResult<User>> logAndCacheListener = ActionListener.wrap(result -> {
-                if (result.isAuthenticated()) {
-                    final User user = result.getValue();
-                    LOGGER.debug(
-                        () -> format("Realm [%s] roles [%s] for principal=[%s].", super.name(), join(",", user.roles()), principal)
-                    );
-                    if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
-                        try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
-                            final long expWallClockMillis = claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis();
-                            this.jwtCache.put(jwtCacheKey, new ExpiringUser(result.getValue(), new Date(expWallClockMillis)));
-                        }
+        // Roles listener: Log roles from delegated authz lookup or role mapping, and cache User if JWT cache is enabled.
+        final ActionListener<AuthenticationResult<User>> logAndCacheListener = ActionListener.wrap(result -> {
+            if (result.isAuthenticated()) {
+                final User user = result.getValue();
+                LOGGER.debug(() -> format("Realm [%s] roles [%s] for principal=[%s].", super.name(), join(",", user.roles()), principal));
+                if ((this.jwtCache != null) && (this.jwtCacheHelper != null)) {
+                    try (ReleasableLock ignored = this.jwtCacheHelper.acquireUpdateLock()) {
+                        final long expWallClockMillis = claimsSet.getExpirationTime().getTime() + this.allowedClockSkew.getMillis();
+                        this.jwtCache.put(jwtCacheKey, new ExpiringUser(result.getValue(), new Date(expWallClockMillis)));
                     }
                 }
-                listener.onResponse(result);
-            }, listener::onFailure);
-
-            // Delegated role lookup or Role mapping: Use the above listener to log roles and cache User.
-            if (this.delegatedAuthorizationSupport.hasDelegation()) {
-                this.delegatedAuthorizationSupport.resolve(principal, logAndCacheListener);
-                return;
             }
+            listener.onResponse(result);
+        }, listener::onFailure);
 
-            // User metadata: If enabled, extract metadata from JWT claims set. Use it in UserRoleMapper.UserData and User constructors.
-            final Map<String, Object> userMetadata;
-            try {
-                userMetadata = this.populateUserMetadata ? JwtUtil.toUserMetadata(jwt) : Map.of();
-            } catch (Exception e) {
-                final String msg = "Realm [" + super.name() + "] parse metadata failed for principal=[" + principal + "].";
-                final AuthenticationResult<User> unsuccessful = AuthenticationResult.unsuccessful(msg, e);
-                LOGGER.debug(msg, e);
-                listener.onResponse(unsuccessful);
+        // Delegated role lookup or Role mapping: Use the above listener to log roles and cache User.
+        if (this.delegatedAuthorizationSupport.hasDelegation()) {
+            this.delegatedAuthorizationSupport.resolve(principal, logAndCacheListener);
+            return;
+        }
+
+        // User metadata: If enabled, extract metadata from JWT claims set. Use it in UserRoleMapper.UserData and User constructors.
+        final Map<String, Object> userMetadata;
+        try {
+            userMetadata = this.populateUserMetadata ? JwtUtil.toUserMetadata(claimsSet) : Map.of();
+        } catch (Exception e) {
+            final String msg = "Realm [" + super.name() + "] parse metadata failed for principal=[" + principal + "].";
+            LOGGER.debug(msg, e);
+            listener.onResponse(AuthenticationResult.unsuccessful(msg, e));
+            return;
+        }
+
+        // Role resolution: Handle role mapping in JWT Realm.
+        final List<String> groups = this.claimParserGroups.getClaimValues(claimsSet);
+        final String dn = this.claimParserDn.getClaimValue(claimsSet);
+        final String mail = this.claimParserMail.getClaimValue(claimsSet);
+        final String name = this.claimParserName.getClaimValue(claimsSet);
+        final UserRoleMapper.UserData userData = new UserRoleMapper.UserData(principal, dn, groups, userMetadata, super.config);
+        this.userRoleMapper.resolveRoles(userData, ActionListener.wrap(rolesSet -> {
+            final User user = new User(principal, rolesSet.toArray(Strings.EMPTY_ARRAY), name, mail, userData.getMetadata(), true);
+            logAndCacheListener.onResponse(AuthenticationResult.success(user));
+        }, logAndCacheListener::onFailure));
+    }
+
+    private void validateSignature(
+        final SignedJWT jwt,
+        final boolean isJwtAlgHmac,
+        final String tokenPrincipal,
+        final ActionListener<Void> listener
+    ) throws Exception {
+        try {
+            JwtValidateUtil.validateSignature(
+                jwt,
+                isJwtAlgHmac ? this.contentAndJwksAlgsHmac.jwksAlgs.jwks : this.contentAndJwksAlgsPkc.jwksAlgs.jwks
+            );
+            listener.onResponse(null);
+        } catch (Exception primaryException) {
+            if (isJwtAlgHmac || this.jwkSetLoader == null) {
+                listener.onFailure(primaryException); // HMAC reload not supported at this time
                 return;
             }
 
-            // Role resolution: Handle role mapping in JWT Realm.
-            final List<String> groups = this.claimParserGroups.getClaimValues(claimsSet);
-            final String dn = this.claimParserDn.getClaimValue(claimsSet);
-            final String mail = this.claimParserMail.getClaimValue(claimsSet);
-            final String name = this.claimParserName.getClaimValue(claimsSet);
-            final UserRoleMapper.UserData userData = new UserRoleMapper.UserData(principal, dn, groups, userMetadata, super.config);
-            this.userRoleMapper.resolveRoles(userData, ActionListener.wrap(rolesSet -> {
-                final User user = new User(principal, rolesSet.toArray(Strings.EMPTY_ARRAY), name, mail, userData.getMetadata(), true);
-                logAndCacheListener.onResponse(AuthenticationResult.success(user));
-            }, logAndCacheListener::onFailure));
-        } else {
-            final String className = (authenticationToken == null) ? "null" : authenticationToken.getClass().getCanonicalName();
-            final String msg = "Realm [" + super.name() + "] does not support AuthenticationToken [" + className + "].";
-            LOGGER.trace(msg);
-            listener.onResponse(AuthenticationResult.unsuccessful(msg, null));
+            LOGGER.debug(
+                () -> org.elasticsearch.core.Strings.format(
+                    "Signature verification failed for [%s] reloading JWKSet (was: #[%s] JWKs, #[%s] algs, sha256=[%s])",
+                    tokenPrincipal,
+                    this.contentAndJwksAlgsPkc.jwksAlgs.jwks().size(),
+                    this.contentAndJwksAlgsPkc.jwksAlgs.algs().size(),
+                    MessageDigests.toHexString(this.contentAndJwksAlgsPkc.sha256())
+                ),
+                primaryException
+            );
+
+            this.jwkSetLoader.load(ActionListener.wrap(newContentAndJwksAlgs -> {
+                if (Arrays.equals(this.contentAndJwksAlgsPkc.sha256, newContentAndJwksAlgs.sha256)) {
+                    // No change in JWKSet
+                    logger.debug("Reloaded same PKC JWKs, can't retry verify JWT token=[{}]", tokenPrincipal);
+                    listener.onFailure(primaryException);
+                    return;
+                }
+                this.contentAndJwksAlgsPkc = newContentAndJwksAlgs;
+                // If all PKC JWKs were replaced, all PKC JWT cache entries need to be invalidated.
+                // Enhancement idea: Use separate caches for PKC vs HMAC JWKs, so only PKC entries get invalidated.
+                // Enhancement idea: When some JWKs are retained (ex: rotation), only invalidate for removed JWKs.
+                this.invalidateJwtCache();
+
+                if (this.contentAndJwksAlgsPkc.jwksAlgs.isEmpty()) {
+                    logger.debug("Reloaded empty PKC JWKs, verification of JWT token will fail [{}]", tokenPrincipal);
+                    // fall through and let try/catch below handle empty JWKs failure log and response
+                }
+
+                try {
+                    JwtValidateUtil.validateSignature(jwt, this.contentAndJwksAlgsPkc.jwksAlgs.jwks);
+                    listener.onResponse(null);
+                } catch (Exception secondaryException) {
+                    logger.debug(
+                        "Verification of JWT token for [{}] failed - original failure=[{}], failure after reload=[{}]",
+                        tokenPrincipal,
+                        primaryException.getMessage(),
+                        secondaryException.getMessage()
+                    );
+                    secondaryException.addSuppressed(primaryException);
+                    listener.onFailure(secondaryException);
+                }
+            }, listener::onFailure));
         }
     }
 
@@ -522,9 +660,75 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
         }, listener::onFailure));
     }
 
-    static BytesKey computeBytesKey(final CharSequence charSequence) {
-        final MessageDigest messageDigest = MessageDigests.sha256();
-        messageDigest.update(charSequence.toString().getBytes(StandardCharsets.UTF_8));
-        return new BytesKey(messageDigest.digest());
+    private class JwkSetLoader {
+        private final AtomicReference<ListenableFuture<ContentAndJwksAlgs>> reloadFutureRef = new AtomicReference<>();
+
+        void load(final ActionListener<ContentAndJwksAlgs> listener) {
+            final ListenableFuture<ContentAndJwksAlgs> future = this.getFuture();
+            future.addListener(listener);
+        }
+
+        private ListenableFuture<ContentAndJwksAlgs> getFuture() {
+            for (;;) {
+                final ListenableFuture<ContentAndJwksAlgs> existingFuture = this.reloadFutureRef.get();
+                if (existingFuture != null) {
+                    return existingFuture;
+                }
+
+                final ListenableFuture<ContentAndJwksAlgs> newFuture = new ListenableFuture<>();
+                if (this.reloadFutureRef.compareAndSet(null, newFuture)) {
+                    loadInternal(ActionListener.runAfter(newFuture, () -> this.reloadFutureRef.compareAndSet(newFuture, null)));
+                    return newFuture;
+                }
+                // else, Another thread set the future-ref before us, just try it all again
+            }
+        }
+
+        private void loadInternal(final ActionListener<ContentAndJwksAlgs> listener) {
+            // PKC JWKSet get contents from local file or remote HTTPS URL
+            if (JwtRealm.this.httpClient == null) {
+                LOGGER.trace("Loading PKC JWKs from path [{}]", JwtRealm.this.jwkSetPath);
+                listener.onResponse(
+                    this.parseContent(
+                        JwtUtil.readFileContents(
+                            RealmSettings.getFullSettingKey(JwtRealm.this.config, JwtRealmSettings.PKC_JWKSET_PATH),
+                            JwtRealm.this.jwkSetPath,
+                            JwtRealm.this.config.env()
+                        )
+                    )
+                );
+            } else {
+                LOGGER.trace("Loading PKC JWKs from https URI [{}]", JwtRealm.this.jwkSetPathUri);
+                JwtUtil.readUriContents(
+                    RealmSettings.getFullSettingKey(JwtRealm.this.config, JwtRealmSettings.PKC_JWKSET_PATH),
+                    JwtRealm.this.jwkSetPathUri,
+                    JwtRealm.this.httpClient,
+                    listener.map(bytes -> {
+                        LOGGER.trace("Loaded bytes [{}] from [{}]", bytes.length, JwtRealm.this.jwkSetPathUri);
+                        return this.parseContent(bytes);
+                    })
+                );
+            }
+        }
+
+        private ContentAndJwksAlgs parseContent(final byte[] jwkSetContentBytesPkc) {
+            final String jwkSetContentsPkc = new String(jwkSetContentBytesPkc, StandardCharsets.UTF_8);
+            final byte[] jwkSetContentsPkcSha256 = JwtUtil.sha256(jwkSetContentsPkc);
+
+            // PKC JWKSet parse contents
+            final List<JWK> jwksPkc = JwkValidateUtil.loadJwksFromJwkSetString(
+                RealmSettings.getFullSettingKey(JwtRealm.this.config, JwtRealmSettings.PKC_JWKSET_PATH),
+                jwkSetContentsPkc
+            );
+            // Filter JWK(s) vs signature algorithms. Only keep JWKs with a matching alg. Only keep algs with a matching JWK.
+            final JwksAlgs jwksAlgsPkc = JwkValidateUtil.filterJwksAndAlgorithms(jwksPkc, JwtRealm.this.allowedJwksAlgsPkc);
+            LOGGER.info(
+                "Usable PKC: JWKs=[{}] algorithms=[{}] sha256=[{}]",
+                jwksAlgsPkc.jwks().size(),
+                String.join(",", jwksAlgsPkc.algs()),
+                MessageDigests.toHexString(jwkSetContentsPkcSha256)
+            );
+            return new ContentAndJwksAlgs(jwkSetContentsPkcSha256, jwksAlgsPkc);
+        }
     }
 }

+ 35 - 26
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java

@@ -11,7 +11,6 @@ import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.util.JSONObjectUtils;
 import com.nimbusds.jwt.JWTClaimsSet;
-import com.nimbusds.jwt.SignedJWT;
 
 import org.apache.http.HttpEntity;
 import org.apache.http.HttpResponse;
@@ -33,8 +32,9 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.SpecialPermission;
-import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.hash.MessageDigests;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.SettingsException;
 import org.elasticsearch.common.ssl.SslConfiguration;
@@ -51,6 +51,7 @@ import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.security.AccessController;
+import java.security.MessageDigest;
 import java.security.PrivilegedAction;
 import java.security.PrivilegedActionException;
 import java.security.PrivilegedExceptionAction;
@@ -185,16 +186,25 @@ public class JwtUtil {
         return null;
     }
 
-    public static byte[] readUriContents(
+    public static void readUriContents(
         final String jwkSetConfigKeyPkc,
         final URI jwkSetPathPkcUri,
-        final CloseableHttpAsyncClient httpClient
-    ) throws SettingsException {
-        try {
-            return JwtUtil.readBytes(httpClient, jwkSetPathPkcUri);
-        } catch (Exception e) {
-            throw new SettingsException("Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].", e);
-        }
+        final CloseableHttpAsyncClient httpClient,
+        final ActionListener<byte[]> listener
+    ) {
+        JwtUtil.readBytes(
+            httpClient,
+            jwkSetPathPkcUri,
+            ActionListener.wrap(
+                listener::onResponse,
+                ex -> listener.onFailure(
+                    new SettingsException(
+                        "Can't get contents for setting [" + jwkSetConfigKeyPkc + "] value [" + jwkSetPathPkcUri + "].",
+                        ex
+                    )
+                )
+            )
+        );
     }
 
     public static byte[] readFileContents(final String jwkSetConfigKeyPkc, final String jwkSetPathPkc, final Environment environment)
@@ -211,7 +221,7 @@ public class JwtUtil {
     }
 
     public static String serializeJwkSet(final JWKSet jwkSet, final boolean publicKeysOnly) {
-        if ((jwkSet == null) || (jwkSet.getKeys().isEmpty())) {
+        if (jwkSet == null) {
             return null;
         }
         return JSONObjectUtils.toJSONString(jwkSet.toJSONObject(publicKeysOnly));
@@ -262,13 +272,11 @@ public class JwtUtil {
     }
 
     /**
-     * Use the HTTP Client to get URL content bytes up to N max bytes.
+     * Use the HTTP Client to get URL content bytes.
      * @param httpClient Configured HTTP/HTTPS client.
      * @param uri URI to download.
-     * @return Byte array of the URI contents up to N max bytes.
      */
-    public static byte[] readBytes(final CloseableHttpAsyncClient httpClient, final URI uri) {
-        final PlainActionFuture<byte[]> plainActionFuture = PlainActionFuture.newFuture();
+    public static void readBytes(final CloseableHttpAsyncClient httpClient, final URI uri, ActionListener<byte[]> listener) {
         AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
             httpClient.execute(new HttpGet(uri), new FutureCallback<>() {
                 @Override
@@ -278,12 +286,12 @@ public class JwtUtil {
                     if (statusCode == 200) {
                         final HttpEntity entity = result.getEntity();
                         try (InputStream inputStream = entity.getContent()) {
-                            plainActionFuture.onResponse(inputStream.readAllBytes());
+                            listener.onResponse(inputStream.readAllBytes());
                         } catch (Exception e) {
-                            plainActionFuture.onFailure(e);
+                            listener.onFailure(e);
                         }
                     } else {
-                        plainActionFuture.onFailure(
+                        listener.onFailure(
                             new ElasticsearchSecurityException(
                                 "Get [" + uri + "] failed, status [" + statusCode + "], reason [" + statusLine.getReasonPhrase() + "]."
                             )
@@ -293,17 +301,16 @@ public class JwtUtil {
 
                 @Override
                 public void failed(Exception e) {
-                    plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e));
+                    listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] failed.", e));
                 }
 
                 @Override
                 public void cancelled() {
-                    plainActionFuture.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled."));
+                    listener.onFailure(new ElasticsearchSecurityException("Get [" + uri + "] was cancelled."));
                 }
             });
             return null;
         });
-        return plainActionFuture.actionGet();
     }
 
     public static Path resolvePath(final Environment environment, final String jwkSetPath) {
@@ -335,14 +342,10 @@ public class JwtUtil {
      *   JWSHeader: Header are not support.
      *   JWTClaimsSet: Claims are supported. Claim keys are prefixed by "jwt_claim_".
      *   Base64URL: Signature is not supported.
-     * @param jwt SignedJWT object.
      * @return Map of formatted and filtered values to be used as user metadata.
-     * @throws Exception Parse error.
      */
-    //
     // Values will be filtered by type using isAllowedTypeForClaim().
-    public static Map<String, Object> toUserMetadata(final SignedJWT jwt) throws Exception {
-        final JWTClaimsSet claimsSet = jwt.getJWTClaimsSet();
+    public static Map<String, Object> toUserMetadata(JWTClaimsSet claimsSet) {
         return claimsSet.getClaims()
             .entrySet()
             .stream()
@@ -366,4 +369,10 @@ public class JwtUtil {
             || (value instanceof Collection
                 && ((Collection<?>) value).stream().allMatch(e -> e instanceof String || e instanceof Boolean || e instanceof Number)));
     }
+
+    public static byte[] sha256(final CharSequence charSequence) {
+        final MessageDigest messageDigest = MessageDigests.sha256();
+        messageDigest.update(charSequence.toString().getBytes(StandardCharsets.UTF_8));
+        return messageDigest.digest();
+    }
 }

+ 30 - 52
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtValidateUtil.java

@@ -32,6 +32,7 @@ import com.nimbusds.jwt.SignedJWT;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.settings.SecureString;
 
 import java.util.Date;
@@ -48,55 +49,6 @@ public class JwtValidateUtil {
         null
     );
 
-    /**
-     * Validate a SignedJWT. Use iss/aud/alg filters for those claims, JWKSet for signature, and skew seconds for time claims.
-     * @param jwt Signed JWT to be validated.
-     * @param allowedIssuer Filter for the "iss" claim.
-     * @param allowedAudiences Filter for the "aud" claim.
-     * @param allowedClockSkewSeconds Skew tolerance for the "auth_time", "iat", "nbf", and "exp" claims.
-     * @param allowedSignatureAlgorithms Filter for the "aud" header.
-     * @param jwks JWKs of HMAC secret keys or RSA/EC public keys.
-     * @throws Exception Error for the first validation to fail.
-     */
-    public static void validate(
-        final SignedJWT jwt,
-        final String allowedIssuer,
-        final List<String> allowedAudiences,
-        final long allowedClockSkewSeconds,
-        final List<String> allowedSignatureAlgorithms,
-        final List<JWK> jwks
-    ) throws Exception {
-        final Date now = new Date();
-
-        if (LOGGER.isDebugEnabled()) {
-            LOGGER.debug(
-                "Validating JWT, now [{}], alg [{}], issuer [{}], audiences [{}], typ [{}],"
-                    + " auth_time [{}], iat [{}], nbf [{}], exp [{}], kid [{}], jti [{}]",
-                now,
-                jwt.getHeader().getAlgorithm(),
-                jwt.getJWTClaimsSet().getIssuer(),
-                jwt.getJWTClaimsSet().getAudience(),
-                jwt.getHeader().getType(),
-                jwt.getJWTClaimsSet().getDateClaim("auth_time"),
-                jwt.getJWTClaimsSet().getIssueTime(),
-                jwt.getJWTClaimsSet().getNotBeforeTime(),
-                jwt.getJWTClaimsSet().getExpirationTime(),
-                jwt.getHeader().getKeyID(),
-                jwt.getJWTClaimsSet().getJWTID()
-            );
-        }
-        // validate claims before signature, because log messages about rejected claims can be more helpful than rejected signatures
-        JwtValidateUtil.validateType(jwt);
-        JwtValidateUtil.validateIssuer(jwt, allowedIssuer);
-        JwtValidateUtil.validateAudiences(jwt, allowedAudiences);
-        JwtValidateUtil.validateSignatureAlgorithm(jwt, allowedSignatureAlgorithms);
-        JwtValidateUtil.validateAuthTime(jwt, now, allowedClockSkewSeconds);
-        JwtValidateUtil.validateIssuedAtTime(jwt, now, allowedClockSkewSeconds);
-        JwtValidateUtil.validateNotBeforeTime(jwt, now, allowedClockSkewSeconds);
-        JwtValidateUtil.validateExpiredTime(jwt, now, allowedClockSkewSeconds);
-        JwtValidateUtil.validateSignature(jwt, jwks);
-    }
-
     public static void validateType(final SignedJWT jwt) throws Exception {
         final JOSEObjectType jwtHeaderType = jwt.getHeader().getType();
         try {
@@ -277,7 +229,10 @@ public class JwtValidateUtil {
      * @throws Exception Error if JWKs fail to validate the Signed JWT.
      */
     public static void validateSignature(final SignedJWT jwt, final List<JWK> jwks) throws Exception {
-        assert jwks != null && jwks.isEmpty() == false : "Caller must provide a non-empty JWK list";
+        assert jwks != null : "Verify requires a non-null JWK list";
+        if (jwks.isEmpty()) {
+            throw new ElasticsearchException("Verify requires a non-empty JWK list");
+        }
         final String id = jwt.getHeader().getKeyID();
         final JWSAlgorithm alg = jwt.getHeader().getAlgorithm();
         LOGGER.trace("JWKs [{}], JWT KID [{}], and JWT Algorithm [{}] before filters.", jwks.size(), id, alg.getName());
@@ -305,12 +260,35 @@ public class JwtValidateUtil {
         final List<JWK> jwksStrength = jwksAlg.stream().filter(j -> JwkValidateUtil.isMatch(j, alg.getName())).toList();
         LOGGER.debug("JWKs [{}] after Algorithm [{}] match filter.", jwksStrength.size(), alg);
 
+        // No JWKs passed the kid, alg, and strength checks, so nothing left to use in verifying the JWT signature
+        if (jwksStrength.isEmpty()) {
+            throw new ElasticsearchException("Verify failed because all " + jwks.size() + " provided JWKs were filtered.");
+        }
+
         for (final JWK jwk : jwksStrength) {
             if (jwt.verify(JwtValidateUtil.createJwsVerifier(jwk))) {
-                return; // VERIFY SUCCEEDED
+                LOGGER.trace(
+                    "JWT signature validation succeeded with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops=[{}]",
+                    jwk.getKeyType(),
+                    jwk.getAlgorithm(),
+                    jwk.getKeyID(),
+                    jwk.getKeyUse(),
+                    jwk.getKeyOperations()
+                );
+                return;
+            } else {
+                LOGGER.trace(
+                    "JWT signature validation failed with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops={}",
+                    jwk.getKeyType(),
+                    jwk.getAlgorithm(),
+                    jwk.getKeyID(),
+                    jwk.getKeyUse(),
+                    jwk.getKeyOperations() == null ? "[null]" : jwk.getKeyOperations()
+                );
             }
         }
-        throw new Exception("Verify failed using " + jwksStrength.size() + " of " + jwks.size() + " provided JWKs.");
+
+        throw new ElasticsearchException("Verify failed using " + jwksStrength.size() + " of " + jwks.size() + " provided JWKs.");
     }
 
     public static JWSVerifier createJwsVerifier(final JWK jwk) throws JOSEException {

+ 17 - 35
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtilTests.java

@@ -10,13 +10,14 @@ import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.OctetSequenceKey;
-import com.nimbusds.jose.util.Base64URL;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
 
 import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 
 import static org.hamcrest.Matchers.anyOf;
@@ -27,46 +28,27 @@ public class JwkValidateUtilTests extends JwtTestCase {
 
     private static final Logger LOGGER = LogManager.getLogger(JwkValidateUtilTests.class);
 
-    // HMAC JWKSet setting can use keys from randomJwkHmac()
-    // HMAC key setting cannot use randomJwkHmac(), it must use randomJwkHmacString()
-    public void testConvertHmacJwkToStringToJwk() throws Exception {
-        final JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC));
-
-        // Use HMAC random bytes for OIDC JWKSet setting only. Demonstrate encode/decode fails if used in OIDC HMAC key setting.
-        final OctetSequenceKey hmacKeyRandomBytes = JwtTestCase.randomJwkHmac(jwsAlgorithm);
-        assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyRandomBytes), is(false));
-
-        // Convert HMAC random bytes to UTF8 bytes. This makes it usable as an OIDC HMAC key setting.
-        final OctetSequenceKey hmacKeyString1 = JwtTestCase.conditionJwkHmacForOidc(hmacKeyRandomBytes);
-        assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyString1), is(true));
-
-        // Generate HMAC UTF8 bytes. This is usable as an OIDC HMAC key setting.
-        final OctetSequenceKey hmacKeyString2 = JwtTestCase.randomJwkHmacOidc(jwsAlgorithm);
-        assertThat(this.hmacEncodeDecodeAsPasswordTestHelper(hmacKeyString2), is(true));
+    // Test decode bytes as UTF8 to String, encode back to UTF8, and compare to original bytes. If same, it is safe for OIDC JWK encode.
+    static boolean isJwkHmacOidcSafe(final JWK jwk) {
+        if (jwk instanceof OctetSequenceKey jwkHmac) {
+            final byte[] rawKeyBytes = jwkHmac.getKeyValue().decode();
+            return Arrays.equals(rawKeyBytes, new String(rawKeyBytes, StandardCharsets.UTF_8).getBytes(StandardCharsets.UTF_8));
+        }
+        return true;
     }
 
-    private boolean hmacEncodeDecodeAsPasswordTestHelper(final OctetSequenceKey hmacKey) {
-        final OctetSequenceKey hmacKeyNoAttributes = JwtTestCase.jwkHmacRemoveAttributes(hmacKey);
-        // Encode input key as Base64(keyBytes) and Utf8String(keyBytes)
-        final String keyBytesToBase64 = hmacKey.getKeyValue().toString();
-        final String keyBytesAsUtf8 = hmacKey.getKeyValue().decodeToString();
-
-        // Decode Base64(keyBytes) into new key and compare to original. This always works.
-        final OctetSequenceKey decodeFromBase64 = new OctetSequenceKey.Builder(new Base64URL(keyBytesToBase64)).build();
-        LOGGER.info("Base64 enc/dec test:\ngen: [" + hmacKey + "]\nenc: [" + keyBytesToBase64 + "]\ndec: [" + decodeFromBase64 + "]\n");
-        if (decodeFromBase64.equals(hmacKeyNoAttributes) == false) {
-            return false;
+    static boolean areJwkHmacOidcSafe(final Collection<JWK> jwks) {
+        for (final JWK jwk : jwks) {
+            if (JwkValidateUtilTests.isJwkHmacOidcSafe(jwk) == false) {
+                return false;
+            }
         }
-
-        // Decode Utf8String(keyBytes) into new key and compare to original. Only works for randomJwkHmacString, fails for randomJwkHmac.
-        final OctetSequenceKey decodeFromUtf8 = new OctetSequenceKey.Builder(keyBytesAsUtf8.getBytes(StandardCharsets.UTF_8)).build();
-        LOGGER.info("UTF8 enc/dec test:\ngen: [" + hmacKey + "]\nenc: [" + keyBytesAsUtf8 + "]\ndec: [" + decodeFromUtf8 + "]\n");
-        return decodeFromUtf8.equals(hmacKeyNoAttributes);
+        return true;
     }
 
     public void testComputeBitLengthRsa() throws Exception {
         for (final String signatureAlgorithmRsa : JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_RSA) {
-            final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithmRsa);
+            final JWK jwk = JwtTestCase.randomJwkRsa(JWSAlgorithm.parse(signatureAlgorithmRsa));
             final int minLength = JwkValidateUtil.computeBitLengthRsa(jwk.toRSAKey().toPublicKey());
             assertThat(minLength, is(anyOf(equalTo(2048), equalTo(3072))));
         }
@@ -86,7 +68,7 @@ public class JwkValidateUtilTests extends JwtTestCase {
 
     private void filterJwksAndAlgorithmsTestHelper(final List<String> candidateAlgs) throws JOSEException {
         final List<String> algsRandom = randomOfMinUnique(2, candidateAlgs); // duplicates allowed
-        final List<JwtIssuer.AlgJwkPair> algJwkPairsAll = JwtTestCase.randomJwks(algsRandom);
+        final List<JwtIssuer.AlgJwkPair> algJwkPairsAll = JwtTestCase.randomJwks(algsRandom, randomBoolean());
         final List<JWK> jwks = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::jwk).toList();
         final List<String> algsAll = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::alg).toList();
         final List<JWK> jwksAll = algJwkPairsAll.stream().map(JwtIssuer.AlgJwkPair::jwk).toList();

+ 12 - 10
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationTokenTests.java

@@ -6,6 +6,7 @@
  */
 package org.elasticsearch.xpack.security.authc.jwt;
 
+import com.nimbusds.jose.JOSEObjectType;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jwt.SignedJWT;
 
@@ -27,7 +28,7 @@ public class JwtAuthenticationTokenTests extends JwtTestCase {
 
     public void testJwtAuthenticationTokenParse() throws Exception {
         final String signatureAlgorithm = randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS);
-        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm);
+        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm, randomBoolean());
 
         final SecureString jwt = JwtTestCase.randomBespokeJwt(jwk, signatureAlgorithm); // bespoke JWT, not tied to any JWT realm
         final SecureString clientSharedSecret = randomBoolean() ? null : new SecureString(randomAlphaOfLengthBetween(10, 20).toCharArray());
@@ -65,24 +66,25 @@ public class JwtAuthenticationTokenTests extends JwtTestCase {
         final String principalClaimValue = randomAlphaOfLengthBetween(8, 32);
 
         final String signatureAlgorithm = randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS);
-        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm);
+        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm, randomBoolean());
 
         final Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
         final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            null, // type
-            signatureAlgorithm,
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : jwk.getKeyID(), // kid
+            signatureAlgorithm, // alg
             null, // jwtID
-            issuer,
-            List.of(audience),
+            issuer, // iss
+            List.of(audience), // aud
             null, // sub claim value
             principalClaimName, // principal claim name
             principalClaimValue, // principal claim value
             null, // groups claim
             List.of(), // groups
-            Date.from(now.minusSeconds(randomLongBetween(10, 20))), // auth_time
-            Date.from(now), // iat
-            Date.from(now.minusSeconds(randomLongBetween(5, 10))), // nbf
-            Date.from(now.plusSeconds(randomLongBetween(3600, 7200))), // exp
+            Date.from(now.minusSeconds(60 * randomLongBetween(10, 20))), // auth_time
+            Date.from(now.minusSeconds(randomBoolean() ? 0 : 60 * randomLongBetween(5, 10))), // iat
+            Date.from(now), // nbf
+            Date.from(now.plusSeconds(60 * randomLongBetween(3600, 7200))), // exp
             null, // nonce
             Map.of() // other claims
         );

+ 44 - 36
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java

@@ -7,22 +7,23 @@
 
 package org.elasticsearch.xpack.security.authc.jwt;
 
+import com.nimbusds.jose.JOSEException;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKSet;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.Strings;
+import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
 import org.elasticsearch.xpack.core.security.user.User;
 
 import java.io.Closeable;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
+
+import static org.elasticsearch.test.ESTestCase.randomBoolean;
 
 /**
  * Test class with settings for a JWT issuer to sign JWTs for users.
@@ -38,59 +39,66 @@ public class JwtIssuer implements Closeable {
     final List<String> audiencesClaimValue; // claim name is hard-coded to `aud` for OIDC ID Token compatibility
     final String principalClaimName; // claim name is configurable, EX: Users (sub, oid, email, dn, uid), Clients (azp, appid, client_id)
     final Map<String, User> principals; // principals with roles, for sending encoded JWTs into JWT realms for authc/authz verification
-    final List<AlgJwkPair> algAndJwksPkc;
-    final List<AlgJwkPair> algAndJwksHmac;
-    final AlgJwkPair algAndJwkHmacOidc;
+    final JwtIssuerHttpsServer httpsServer;
+
+    List<String> algorithmsAll;
 
     // Computed values
-    final List<AlgJwkPair> algAndJwksAll;
-    final Set<String> algorithmsAll;
-    final String encodedJwkSetPkcPrivate;
-    final String encodedJwkSetPkcPublic;
-    final String encodedJwkSetHmac;
-    final String encodedKeyHmacOidc;
-    final JwtIssuerHttpsServer httpsServer;
+    List<AlgJwkPair> algAndJwksPkc;
+    List<AlgJwkPair> algAndJwksHmac;
+    AlgJwkPair algAndJwkHmacOidc;
+    List<AlgJwkPair> algAndJwksAll;
+    String encodedJwkSetPkcPublicPrivate;
+    String encodedJwkSetPkcPublic;
+    String encodedJwkSetHmac;
+    String encodedKeyHmacOidc;
 
     JwtIssuer(
         final String issuerClaimValue,
         final List<String> audiencesClaimValue,
         final String principalClaimName,
         final Map<String, User> principals,
-        final List<AlgJwkPair> algAndJwksPkc,
-        final List<AlgJwkPair> algAndJwksHmac,
-        final AlgJwkPair algAndJwkHmacOidc,
         final boolean createHttpsServer
     ) throws Exception {
         this.issuerClaimValue = issuerClaimValue;
         this.audiencesClaimValue = audiencesClaimValue;
         this.principalClaimName = principalClaimName;
         this.principals = principals;
-        this.algAndJwksPkc = algAndJwksPkc;
-        this.algAndJwksHmac = algAndJwksHmac;
-        this.algAndJwkHmacOidc = algAndJwkHmacOidc;
-
-        this.algAndJwksAll = new ArrayList<>(this.algAndJwksPkc.size() + this.algAndJwksHmac.size() + 1);
-        this.algAndJwksAll.addAll(this.algAndJwksPkc);
-        this.algAndJwksAll.addAll(this.algAndJwksHmac);
-        if (this.algAndJwkHmacOidc != null) {
-            this.algAndJwksAll.add(this.algAndJwkHmacOidc);
-        }
+        this.httpsServer = createHttpsServer ? new JwtIssuerHttpsServer(null) : null;
+    }
 
-        this.algorithmsAll = this.algAndJwksAll.stream().map(p -> p.alg).collect(Collectors.toSet());
+    // The flag areHmacJwksOidcSafe indicates if all provided HMAC JWKs are UTF8, for HMAC OIDC JWK encoding compatibility.
+    void setJwks(final List<AlgJwkPair> algAndJwks, final boolean areHmacJwksOidcSafe) throws JOSEException {
+        this.algorithmsAll = algAndJwks.stream().map(e -> e.alg).toList();
+        LOGGER.info("Setting JWKs: algorithms=[{}], areHmacJwksOidcSafe=[{}]", String.join(",", this.algorithmsAll), areHmacJwksOidcSafe);
+        this.algAndJwksAll = algAndJwks;
+        this.algAndJwksPkc = this.algAndJwksAll.stream()
+            .filter(e -> JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_PKC.contains(e.alg))
+            .toList();
+        this.algAndJwksHmac = this.algAndJwksAll.stream()
+            .filter(e -> JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(e.alg))
+            .toList();
+        if ((this.algAndJwksHmac.size() == 1) && (areHmacJwksOidcSafe) && (randomBoolean())) {
+            this.algAndJwkHmacOidc = this.algAndJwksHmac.get(0);
+            this.algAndJwksHmac = Collections.emptyList();
+        } else {
+            this.algAndJwkHmacOidc = null;
+        }
 
+        // Encode PKC JWKSet (key material bytes are wrapped in Base64URL, and then wraps in JSON)
         final JWKSet jwkSetPkc = new JWKSet(this.algAndJwksPkc.stream().map(p -> p.jwk).toList());
+        this.encodedJwkSetPkcPublicPrivate = JwtUtil.serializeJwkSet(jwkSetPkc, false);
+        this.encodedJwkSetPkcPublic = JwtUtil.serializeJwkSet(jwkSetPkc, true);
+
+        // Encode HMAC JWKSet (key material bytes are wrapped in Base64URL, and then wraps in JSON)
         final JWKSet jwkSetHmac = new JWKSet(this.algAndJwksHmac.stream().map(p -> p.jwk).toList());
+        this.encodedJwkSetHmac = JwtUtil.serializeJwkSet(jwkSetHmac, false);
 
-        this.encodedJwkSetPkcPrivate = jwkSetPkc.getKeys().isEmpty() ? null : JwtUtil.serializeJwkSet(jwkSetPkc, false);
-        this.encodedJwkSetPkcPublic = jwkSetPkc.getKeys().isEmpty() ? null : JwtUtil.serializeJwkSet(jwkSetPkc, true);
-        this.encodedJwkSetHmac = jwkSetHmac.getKeys().isEmpty() ? null : JwtUtil.serializeJwkSet(jwkSetHmac, false);
+        // Encode HMAC OIDC JWK (key material bytes are decoded from UTF8 to UNICODE String)
         this.encodedKeyHmacOidc = (algAndJwkHmacOidc == null) ? null : JwtUtil.serializeJwkHmacOidc(this.algAndJwkHmacOidc.jwk);
 
-        if ((Strings.hasText(this.encodedJwkSetPkcPublic) == false) || (createHttpsServer == false)) {
-            this.httpsServer = null; // no PKC JWKSet, or skip HTTPS server because caller will use local file instead
-        } else {
-            final byte[] encodedJwkSetPkcPublicBytes = this.encodedJwkSetPkcPublic.getBytes(StandardCharsets.UTF_8);
-            this.httpsServer = new JwtIssuerHttpsServer(encodedJwkSetPkcPublicBytes);
+        if (this.httpsServer != null) {
+            this.httpsServer.updateJwkSetPkcContents(this.encodedJwkSetPkcPublic.getBytes(StandardCharsets.UTF_8));
         }
     }
 

+ 11 - 2
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuerHttpsServer.java

@@ -68,6 +68,11 @@ public class JwtIssuerHttpsServer implements Closeable {
         LOGGER.debug("Started [{}]", this.url);
     }
 
+    public void updateJwkSetPkcContents(final byte[] encodedJwkSetPkcPublicBytes) {
+        this.httpsServer.removeContext(PATH);
+        this.httpsServer.createContext(PATH, new JwtIssuerHttpHandler(encodedJwkSetPkcPublicBytes));
+    }
+
     @Override
     public void close() throws IOException {
         if (this.httpsServer != null) {
@@ -92,8 +97,12 @@ public class JwtIssuerHttpsServer implements Closeable {
                 final String path = httpExchange.getRequestURI().getPath(); // EX: "/", "/valid/", "/valid/pkc_jwkset.json"
                 LOGGER.trace("Request: [{}]", path);
                 try (OutputStream os = httpExchange.getResponseBody()) {
-                    httpExchange.sendResponseHeaders(HttpURLConnection.HTTP_OK, this.encodedJwkSetPkcPublicBytes.length);
-                    os.write(this.encodedJwkSetPkcPublicBytes);
+                    if (encodedJwkSetPkcPublicBytes == null) {
+                        httpExchange.sendResponseHeaders(HttpURLConnection.HTTP_NOT_FOUND, 0);
+                    } else {
+                        httpExchange.sendResponseHeaders(HttpURLConnection.HTTP_OK, this.encodedJwkSetPkcPublicBytes.length);
+                        os.write(this.encodedJwkSetPkcPublicBytes);
+                    }
                 }
                 LOGGER.trace("Response: [{}]", path); // Confirm client didn't disconnect before flush
             } catch (Throwable t) {

+ 170 - 9
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java

@@ -13,6 +13,8 @@ import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.PlainJWT;
 import com.nimbusds.jwt.SignedJWT;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.MockSecureSettings;
@@ -26,9 +28,9 @@ import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
 import org.elasticsearch.xpack.core.security.user.User;
 
-import java.nio.charset.StandardCharsets;
 import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Date;
 import java.util.List;
 
@@ -39,6 +41,7 @@ import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
 
 public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
+    private static final Logger LOGGER = LogManager.getLogger(JwtRealmAuthenticateTests.class);
 
     /**
      * Test with empty roles.
@@ -90,6 +93,162 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcRange);
     }
 
+    /**
+     * Test with updated/removed/restored JWKs.
+     * @throws Exception Unexpected test failure
+     */
+    public void testJwkSetUpdates() throws Exception {
+        this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
+            this.createJwtRealmsSettingsBuilder(),
+            new MinMax(1, 3), // realmsRange
+            new MinMax(0, 0), // authzRange
+            new MinMax(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
+            new MinMax(1, 3), // audiencesRange
+            new MinMax(1, 3), // usersRange
+            new MinMax(0, 3), // rolesRange
+            new MinMax(0, 1), // jwtCacheSizeRange
+            randomBoolean() // createHttpsServer
+        );
+        final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
+        assertThat(jwtIssuerAndRealm.realm().delegatedAuthorizationSupport.hasDelegation(), is(false));
+
+        final User user = this.randomUser(jwtIssuerAndRealm.issuer());
+        final SecureString jwtJwks1 = this.randomJwt(jwtIssuerAndRealm, user);
+        final SecureString clientSecret = jwtIssuerAndRealm.realm().clientAuthenticationSharedSecret;
+        final MinMax jwtAuthcRange = new MinMax(2, 3);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+
+        // Details about first JWT using the JWT issuer original JWKs
+        final String jwt1JwksAlg = SignedJWT.parse(jwtJwks1.toString()).getHeader().getAlgorithm().getName();
+        final boolean isPkcJwtJwks1 = JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_PKC.contains(jwt1JwksAlg);
+        LOGGER.debug("JWT alg=[{}]", jwt1JwksAlg);
+
+        // Backup JWKs 1
+        final List<JwtIssuer.AlgJwkPair> jwtIssuerJwks1Backup = jwtIssuerAndRealm.issuer().algAndJwksAll;
+        final boolean jwtIssuerJwks1OidcSafe = JwkValidateUtilTests.areJwkHmacOidcSafe(
+            jwtIssuerJwks1Backup.stream().map(e -> e.jwk()).toList()
+        );
+        LOGGER.debug("JWKs 1, algs=[{}]", String.join(",", jwtIssuerAndRealm.issuer().algorithmsAll));
+
+        // Empty all JWT issuer JWKs.
+        LOGGER.debug("JWKs 1 backed up, algs=[{}]", String.join(",", jwtIssuerAndRealm.issuer().algorithmsAll));
+        jwtIssuerAndRealm.issuer().setJwks(Collections.emptyList(), jwtIssuerJwks1OidcSafe);
+        super.printJwtIssuer(jwtIssuerAndRealm.issuer());
+        super.copyIssuerJwksToRealmConfig(jwtIssuerAndRealm);
+        LOGGER.debug("JWKs 1 emptied, algs=[{}]", String.join(",", jwtIssuerAndRealm.issuer().algorithmsAll));
+
+        // Original JWT continues working, because JWT realm cached old JWKs in memory.
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        LOGGER.debug("JWT 1 still worked, because JWT realm has old JWKs cached in memory");
+
+        // Restore original JWKs 1 into the JWT issuer.
+        jwtIssuerAndRealm.issuer().setJwks(jwtIssuerJwks1Backup, jwtIssuerJwks1OidcSafe);
+        super.printJwtIssuer(jwtIssuerAndRealm.issuer());
+        super.copyIssuerJwksToRealmConfig(jwtIssuerAndRealm);
+        LOGGER.debug("JWKs 1 restored, algs=[{}]", String.join(",", jwtIssuerAndRealm.issuer().algorithmsAll));
+
+        // Original JWT continues working, because JWT realm cached old JWKs in memory.
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        LOGGER.debug("JWT 1 still worked, because JWT realm has old JWKs cached in memory");
+
+        // Generate a replacement set of JWKs 2 for the JWT issuer.
+        final List<JwtIssuer.AlgJwkPair> jwtIssuerJwks2Backup = JwtRealmTestCase.randomJwks(
+            jwtIssuerJwks1Backup.stream().map(e -> e.alg()).toList(),
+            jwtIssuerJwks1OidcSafe
+        );
+        jwtIssuerAndRealm.issuer().setJwks(jwtIssuerJwks2Backup, jwtIssuerJwks1OidcSafe);
+        super.printJwtIssuer(jwtIssuerAndRealm.issuer());
+        super.copyIssuerJwksToRealmConfig(jwtIssuerAndRealm);
+        LOGGER.debug("JWKs 2 created, algs=[{}]", String.join(",", jwtIssuerAndRealm.issuer().algorithmsAll));
+
+        // Original JWT continues working, because JWT realm still has original JWKs cached in memory.
+        // - jwtJwks1(PKC): Pass (Original PKC JWKs are still in the realm)
+        // - jwtJwks1(HMAC): Pass (Original HMAC JWKs are still in the realm)
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        LOGGER.debug("JWT 1 still worked, because JWT realm has old JWKs cached in memory");
+
+        // Create a JWT using the new JWKs.
+        final SecureString jwtJwks2 = this.randomJwt(jwtIssuerAndRealm, user);
+        final String jwtJwks2Alg = SignedJWT.parse(jwtJwks2.toString()).getHeader().getAlgorithm().getName();
+        final boolean isPkcJwtJwks2 = JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_PKC.contains(jwtJwks2Alg);
+        LOGGER.debug("Created JWT 2: oidcSafe=[{}], algs=[{}, {}]", jwtIssuerJwks1OidcSafe, jwt1JwksAlg, jwtJwks2Alg);
+
+        // Try new JWT.
+        // - jwtJwks2(PKC): PKC reload triggered and loaded new JWKs, so PASS
+        // - jwtJwks2(HMAC): HMAC reload triggered but it is a no-op, so FAIL
+        if (isPkcJwtJwks2) {
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+            LOGGER.debug("PKC JWT 2 worked with JWKs 2");
+        } else {
+            this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
+            LOGGER.debug("HMAC JWT 2 failed with JWKs 1");
+        }
+
+        // Try old JWT.
+        // - jwtJwks2(PKC): PKC reload triggered and loaded new JWKs, jwtJwks1(PKC): PKC reload triggered and loaded new JWKs, so FAIL
+        // - jwtJwks2(PKC): PKC reload triggered and loaded new JWKs, jwtJwks1(HMAC): HMAC reload not triggered, so PASS
+        // - jwtJwks2(HMAC): HMAC reload triggered but it is a no-op, jwtJwks1(PKC): PKC reload not triggered, so PASS
+        // - jwtJwks2(HMAC): HMAC reload triggered but it is a no-op, jwtJwks1(HMAC): HMAC reload not triggered, so PASS
+        if (isPkcJwtJwks1 == false || isPkcJwtJwks2 == false) {
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        } else {
+            this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks1, clientSecret);
+        }
+
+        // Empty all JWT issuer JWKs.
+        jwtIssuerAndRealm.issuer().setJwks(Collections.emptyList(), jwtIssuerJwks1OidcSafe);
+        super.printJwtIssuer(jwtIssuerAndRealm.issuer());
+        super.copyIssuerJwksToRealmConfig(jwtIssuerAndRealm);
+
+        // New JWT continues working because JWT realm will end up with PKC JWKs 2 and HMAC JWKs 1 in memory
+        if (isPkcJwtJwks2) {
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+        } else {
+            this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
+        }
+
+        // Trigger JWT realm to reload JWKs and go into a degraded state
+        // - jwtJwks1(HMAC): HMAC reload not triggered, so PASS
+        // - jwtJwks1(PKC): PKC reload triggered and loaded new JWKs, so FAIL
+        if (isPkcJwtJwks1 == false || isPkcJwtJwks2 == false) {
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        } else {
+            this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks1, clientSecret);
+        }
+
+        // Try new JWT and verify degraded state caused by empty PKC JWKs
+        // - jwtJwks1(PKC) + jwtJwks2(PKC): If second JWT is PKC, and first JWT is PKC, degraded state can be tested.
+        // - jwtJwks1(HMAC) + jwtJwks2(PKC): If second JWT is PKC, but first JWT is HMAC, HMAC JWT 1 above didn't trigger PKC reload.
+        // - jwtJwks1(PKC) + jwtJwks2(HMAC): If second JWT is HMAC, it always fails because HMAC reload not supported.
+        // - jwtJwks1(HMAC) + jwtJwks2(HMAC): If second JWT is HMAC, it always fails because HMAC reload not supported.
+        if (isPkcJwtJwks1 == false && isPkcJwtJwks2) {
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+        } else {
+            this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
+        }
+
+        // Restore JWKs 2 to the realm
+        jwtIssuerAndRealm.issuer().setJwks(jwtIssuerJwks2Backup, jwtIssuerJwks1OidcSafe);
+        super.copyIssuerJwksToRealmConfig(jwtIssuerAndRealm);
+        super.printJwtIssuer(jwtIssuerAndRealm.issuer());
+
+        // Trigger JWT realm to reload JWKs and go into a recovered state
+        // - jwtJwks2(PKC): Pass (Triggers PKC reload, gets newer PKC JWKs), jwtJwks1(PKC): Fail (Triggers PKC reload, gets new PKC JWKs)
+        // - jwtJwks2(PKC): Pass (Triggers PKC reload, gets newer PKC JWKs), jwtJwks1(HMAC): Pass (HMAC reload was a no-op)
+        // - jwtJwks2(HMAC): Fail (Triggers HMAC reload, but it is a no-op), jwtJwks1(PKC): Fail (Triggers PKC reload, gets new PKC JWKs)
+        // - jwtJwks2(HMAC): Fail (Triggers HMAC reload, but it is a no-op), jwtJwks1(HMAC): Pass (HMAC reload was a no-op)
+        if (isPkcJwtJwks2) {
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+        } else {
+            this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
+        }
+        if (isPkcJwtJwks1 == false || isPkcJwtJwks2 == false) {
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        } else {
+            this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks1, clientSecret);
+        }
+    }
+
     /**
      * Test with authz realms.
      * @throws Exception Unexpected test failure
@@ -253,16 +412,16 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         {   // Verify rejection of a tampered header (flip HMAC=>RSA or RSA/EC=>HMAC)
             final String mixupAlg; // Check if there are any algorithms available in the realm for attempting a flip test
             if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(validHeader.getAlgorithm().getName())) {
-                if (jwtIssuerAndRealm.realm().jwksAlgsPkc.algs().isEmpty()) {
+                if (jwtIssuerAndRealm.realm().contentAndJwksAlgsPkc.jwksAlgs().algs().isEmpty()) {
                     mixupAlg = null; // cannot flip HMAC to PKC (no PKC algs available)
                 } else {
-                    mixupAlg = randomFrom(jwtIssuerAndRealm.realm().jwksAlgsPkc.algs()); // flip HMAC to PKC
+                    mixupAlg = randomFrom(jwtIssuerAndRealm.realm().contentAndJwksAlgsPkc.jwksAlgs().algs()); // flip HMAC to PKC
                 }
             } else {
-                if (jwtIssuerAndRealm.realm().jwksAlgsHmac.algs().isEmpty()) {
+                if (jwtIssuerAndRealm.realm().contentAndJwksAlgsHmac.jwksAlgs().algs().isEmpty()) {
                     mixupAlg = null; // cannot flip PKC to HMAC (no HMAC algs available)
                 } else {
-                    mixupAlg = randomFrom(jwtIssuerAndRealm.realm().jwksAlgsHmac.algs()); // flip HMAC to PKC
+                    mixupAlg = randomFrom(jwtIssuerAndRealm.realm().contentAndJwksAlgsHmac.jwksAlgs().algs()); // flip HMAC to PKC
                 }
             }
             // This check can only be executed if there is a flip algorithm available in the realm
@@ -328,6 +487,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         final int realmsCount = 2;
         final List<Realm> allRealms = new ArrayList<>(realmsCount); // two identical realms for same issuer, except different client secret
         final JwtIssuer jwtIssuer = this.createJwtIssuer(0, principalClaimName, 12, 1, 1, 1, false);
+        super.printJwtIssuer(jwtIssuer);
         this.jwtIssuerAndRealms = new ArrayList<>(realmsCount);
         for (int i = 0; i < realmsCount; i++) {
             final String realmName = "realm_" + jwtIssuer.issuerClaimValue + "_" + i;
@@ -346,21 +506,21 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
                     RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLIENT_AUTHENTICATION_TYPE),
                     JwtRealmSettings.ClientAuthenticationType.SHARED_SECRET.value()
                 );
-            if (Strings.hasText(jwtIssuer.encodedJwkSetPkcPublic)) {
+            if (jwtIssuer.encodedJwkSetPkcPublic.isEmpty() == false) {
                 authcSettings.put(
                     RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.PKC_JWKSET_PATH),
-                    super.saveToTempFile("jwkset.", ".json", jwtIssuer.encodedJwkSetPkcPublic.getBytes(StandardCharsets.UTF_8))
+                    super.saveToTempFile("jwkset.", ".json", jwtIssuer.encodedJwkSetPkcPublic)
                 );
             }
             // JWT authc realm secure settings
             final MockSecureSettings secureSettings = new MockSecureSettings();
-            if (Strings.hasText(jwtIssuer.encodedJwkSetHmac)) {
+            if (jwtIssuer.algAndJwksHmac.isEmpty() == false) {
                 secureSettings.setString(
                     RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.HMAC_JWKSET),
                     jwtIssuer.encodedJwkSetHmac
                 );
             }
-            if (Strings.hasText(jwtIssuer.encodedKeyHmacOidc)) {
+            if (jwtIssuer.encodedKeyHmacOidc != null) {
                 secureSettings.setString(
                     RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.HMAC_KEY),
                     jwtIssuer.encodedKeyHmacOidc
@@ -376,6 +536,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             jwtRealm.initialize(allRealms, super.licenseState);
             final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
             this.jwtIssuerAndRealms.add(jwtIssuerAndRealm); // add them so the test will clean them up
+            super.printJwtRealm(jwtRealm);
         }
 
         // pick 2nd realm and use its secret, verify 2nd realm does authc, which implies 1st realm rejects the secret

+ 21 - 20
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.security.authc.jwt;
 
 import com.nimbusds.jose.JOSEObjectType;
+import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.OctetSequenceKey;
 import com.nimbusds.jose.jwk.RSAKey;
@@ -77,11 +78,9 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
             List.of("aud8"), // aud
             principalClaimName, // sub
             Collections.singletonMap("security_test_user", new User("security_test_user", "security_test_role")), // users
-            Collections.emptyList(), // algJwkPairsPkc
-            Collections.emptyList(), // algJwkPairsHmac
-            algJwkPairHmac, // algJwkPairHmac
             false // createHttpsServer
         );
+        jwtIssuer.setJwks(List.of(algJwkPairHmac), true);
 
         // Create realm settings
         final String realmName = "jwt8";
@@ -121,11 +120,13 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
         super.jwtIssuerAndRealms = Collections.singletonList(jwtIssuerAndRealm); // super.shutdown() closes issuer+realm if necessary
+        super.printJwtRealmAndIssuer(jwtIssuerAndRealm);
 
         // Create JWT
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            randomBoolean() ? null : JOSEObjectType.JWT.toString(),
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : algJwkPairHmac.jwk().getKeyID(), // kid
             algJwkPairHmac.alg(), // alg
             null, // jwtID
             jwtIssuerAndRealm.realm().allowedIssuer, // iss
@@ -156,7 +157,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
      */
     public void testCreateJwtIntegrationTestRealm1() throws Exception {
         // Create RSA key for algorithm RS256
-        final JWK jwk = new RSAKey.Builder((RSAKey) JwtTestCase.randomJwk("RS256")).keyID("test-rsa-key").build();
+        final JWK jwk = new RSAKey.Builder(JwtTestCase.randomJwkRsa(JWSAlgorithm.RS256)).keyID("test-rsa-key").build();
         final JwtIssuer.AlgJwkPair algJwkPairPkc = new JwtIssuer.AlgJwkPair("RS256", jwk);
 
         final String principalClaimName = "sub";
@@ -171,11 +172,9 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
             List.of("https://audience.example.com/"), // aud claim value
             principalClaimName, // principal claim name
             Collections.singletonMap("user1", new User("user1", "role1")), // users
-            List.of(algJwkPairPkc), // algJwkPairsPkc
-            Collections.emptyList(), // algJwkPairsHmac
-            null, // algJwkPairHmac
             false // createHttpsServer
         );
+        jwtIssuer.setJwks(List.of(algJwkPairPkc), false);
 
         // Create realm settings (no secure settings)
         final String realmName = "jwt1";
@@ -201,7 +200,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
             )
             .put(
                 RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.PKC_JWKSET_PATH),
-                super.saveToTempFile("jwkset.", ".json", jwtIssuer.encodedJwkSetPkcPublic.getBytes(StandardCharsets.UTF_8))
+                super.saveToTempFile("jwkset.", ".json", jwtIssuer.encodedJwkSetPkcPublic)
             );
 
         // Create realm
@@ -213,11 +212,13 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
         super.jwtIssuerAndRealms = Collections.singletonList(jwtIssuerAndRealm); // super.shutdown() closes issuer+realm if necessary
+        super.printJwtRealmAndIssuer(jwtIssuerAndRealm);
 
         // Create JWT
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            randomBoolean() ? null : JOSEObjectType.JWT.toString(),
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : algJwkPairPkc.jwk().getKeyID(), // kid
             algJwkPairPkc.alg(), // alg
             null, // jwtID
             jwtIssuerAndRealm.realm().allowedIssuer, // iss
@@ -266,11 +267,9 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
             List.of("es01", "es02", "es03"), // aud claim value
             principalClaimName, // principal claim name
             Collections.singletonMap("user2", new User("user2", "role2")), // users
-            Collections.emptyList(), // algJwkPairsPkc
-            Collections.emptyList(), // algJwkPairsHmac
-            algJwkPairHmac, // algJwkPairHmac
             false // createHttpsServer
         );
+        jwtIssuer.setJwks(List.of(algJwkPairHmac), true);
 
         // Create realm settings
         final String realmName = "jwt2";
@@ -319,11 +318,13 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
         super.jwtIssuerAndRealms = Collections.singletonList(jwtIssuerAndRealm); // super.shutdown() closes issuer+realm if necessary
+        super.printJwtRealmAndIssuer(jwtIssuerAndRealm);
 
         // Create JWT
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            JOSEObjectType.JWT.toString(),
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : algJwkPairHmac.jwk().getKeyID(), // kid
             algJwkPairHmac.alg(), // alg
             null, // jwtID
             jwtIssuerAndRealm.realm().allowedIssuer, // iss
@@ -358,7 +359,6 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
             new JwtIssuer.AlgJwkPair("HS384", new OctetSequenceKey.Builder(randomByteArrayOfLength(48)).keyID("test-hmac-384").build()),
             new JwtIssuer.AlgJwkPair("HS512", new OctetSequenceKey.Builder(randomByteArrayOfLength(64)).keyID("test-hmac-512").build())
         );
-        var selectedHmac = randomFrom(hmacKeys);
 
         final String principalClaimName = "sub";
         final Settings.Builder jwtRealmsServiceSettings = Settings.builder()
@@ -372,11 +372,9 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
             List.of("jwt3-audience"), // aud claim value
             principalClaimName, // principal claim name
             Collections.singletonMap("user3", new User("user3", "role3")), // users
-            Collections.emptyList(), // algJwkPairsPkc
-            hmacKeys, // algJwkPairsHmac
-            null, // algJwkPairHmac
             false // createHttpsServer
         );
+        jwtIssuer.setJwks(hmacKeys, false);
 
         // Create realm settings
         final String realmName = "jwt3";
@@ -416,11 +414,14 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
         super.jwtIssuerAndRealms = Collections.singletonList(jwtIssuerAndRealm); // super.shutdown() closes issuer+realm if necessary
+        super.printJwtRealmAndIssuer(jwtIssuerAndRealm);
 
         // Create JWT
+        final JwtIssuer.AlgJwkPair selectedHmac = randomFrom(hmacKeys);
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            JOSEObjectType.JWT.toString(),
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : selectedHmac.jwk().getKeyID(), // kid
             selectedHmac.alg(), // alg
             null, // jwtID
             jwtIssuerAndRealm.realm().allowedIssuer, // iss
@@ -483,7 +484,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         sb.append("Audiences: ").append(String.join(",", jwtIssuer.audiencesClaimValue)).append('\n');
         sb.append("Algorithms: ").append(String.join(",", jwtIssuer.algorithmsAll)).append("\n");
         if (jwtIssuer.algAndJwksPkc.isEmpty() == false) {
-            sb.append("PKC JWKSet (Private): ").append(jwtIssuer.encodedJwkSetPkcPrivate).append("\n");
+            sb.append("PKC JWKSet (Private): ").append(jwtIssuer.encodedJwkSetPkcPublicPrivate).append("\n");
             sb.append("PKC JWKSet (Public): ").append(jwtIssuer.encodedJwkSetPkcPublic).append("\n");
         }
         if (jwtIssuer.algAndJwksHmac.isEmpty() == false) {

+ 146 - 108
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java

@@ -7,18 +7,18 @@
 package org.elasticsearch.xpack.security.authc.jwt;
 
 import com.nimbusds.jose.JOSEObjectType;
-import com.nimbusds.jose.jwk.OctetSequenceKey;
+import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jwt.SignedJWT;
 import com.nimbusds.openid.connect.sdk.Nonce;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.MockSecureSettings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.PathUtils;
 import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.license.MockLicenseState;
 import org.elasticsearch.threadpool.TestThreadPool;
@@ -42,7 +42,9 @@ import org.elasticsearch.xpack.security.authc.support.MockLookupRealm;
 import org.junit.After;
 import org.junit.Before;
 
-import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.text.ParseException;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
@@ -199,45 +201,25 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         final boolean createHttpsServer
     ) throws Exception {
         final String issuer = "iss" + (i + 1) + "_" + randomIntBetween(0, 9999);
-
-        // Allow algorithm repeats, to cover testing of multiple JWKs for same algorithm
-        final List<String> algs = randomOfMinMaxNonUnique(algsCount, algsCount, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS);
-        final List<String> algsPkc = algs.stream().filter(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_PKC::contains).toList();
-        final List<String> algsHmac = algs.stream().filter(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC::contains).toList();
-        final List<JwtIssuer.AlgJwkPair> algJwkPairsPkc = JwtTestCase.randomJwks(algsPkc);
-        // Key setting vs JWKSet setting are mutually exclusive, do not populate both
-        final List<JwtIssuer.AlgJwkPair> algJwkPairsHmac = new ArrayList<>(JwtTestCase.randomJwks(algsHmac)); // allow remove/add below
-        final JwtIssuer.AlgJwkPair algJwkPairHmacOidc;
-        if ((algJwkPairsHmac.size() == 0) || (randomBoolean())) {
-            algJwkPairHmacOidc = null; // list(0||1||N) => Key=null and JWKSet(N)
-        } else {
-            // Change one of the HMAC random bytes keys to an OIDC UTF8 key. Put it in either the Key setting or JWKSet setting.
-            final JwtIssuer.AlgJwkPair algJwkPairRandomBytes = algJwkPairsHmac.get(0);
-            final OctetSequenceKey jwkHmacRandomBytes = JwtTestCase.conditionJwkHmacForOidc((OctetSequenceKey) algJwkPairRandomBytes.jwk());
-            final JwtIssuer.AlgJwkPair algJwkPairUtf8Bytes = new JwtIssuer.AlgJwkPair(algJwkPairRandomBytes.alg(), jwkHmacRandomBytes);
-            if ((algJwkPairsHmac.size() == 1) && (randomBoolean())) {
-                algJwkPairHmacOidc = algJwkPairUtf8Bytes; // list(1) => Key=OIDC and JWKSet(0)
-                algJwkPairsHmac.remove(0);
-            } else {
-                algJwkPairHmacOidc = null; // list(N) => Key=null and JWKSet(OIDC+N-1)
-                algJwkPairsHmac.set(0, algJwkPairUtf8Bytes);
-            }
-        }
-
         final List<String> audiences = IntStream.range(0, audiencesCount).mapToObj(j -> issuer + "_aud" + (j + 1)).toList();
         final Map<String, User> users = JwtTestCase.generateTestUsersWithRoles(userCount, roleCount);
+        // Allow algorithm repeats, to cover testing of multiple JWKs for same algorithm
+        final JwtIssuer jwtIssuer = new JwtIssuer(issuer, audiences, principalClaimName, users, createHttpsServer);
+        final List<String> algorithms = randomOfMinMaxNonUnique(algsCount, algsCount, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS);
+        final boolean areHmacJwksOidcSafe = randomBoolean();
+        final List<JwtIssuer.AlgJwkPair> algAndJwks = JwtRealmTestCase.randomJwks(algorithms, areHmacJwksOidcSafe);
+        jwtIssuer.setJwks(algAndJwks, areHmacJwksOidcSafe);
+        return jwtIssuer;
+    }
 
-        // Decide if public PKC JWKSet will be hosted in a local file or an HTTPS URL. If HTTPS URL, tell issuer to set up an HTTPS server.
-        return new JwtIssuer(
-            issuer,
-            audiences,
-            principalClaimName,
-            users,
-            algJwkPairsPkc,
-            algJwkPairsHmac,
-            algJwkPairHmacOidc,
-            createHttpsServer
-        );
+    protected void copyIssuerJwksToRealmConfig(final JwtIssuerAndRealm jwtIssuerAndRealm) throws Exception {
+        if ((jwtIssuerAndRealm.realm.isConfiguredJwkSetPkc) && (jwtIssuerAndRealm.realm.jwkSetPathUri == null)) {
+            LOGGER.trace("Updating JwtRealm PKC public JWKSet local file");
+            final Path path = PathUtils.get(jwtIssuerAndRealm.realm.jwkSetPath);
+            Files.writeString(path, jwtIssuerAndRealm.issuer.encodedJwkSetPkcPublic);
+        }
+
+        // TODO If x-pack Security plug-in add supports for reloadable settings, update HMAC JWKSet and HMAC OIDC JWK in ES Keystore
     }
 
     protected JwtRealmsServiceSettingsBuilder createJwtRealmsSettingsBuilder() throws Exception {
@@ -290,10 +272,10 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                 randomBoolean() ? "-1" : randomBoolean() ? "0" : randomIntBetween(1, 5) + randomFrom("s", "m", "h")
             );
         }
-        if (Strings.hasText(jwtIssuer.encodedJwkSetPkcPublic)) {
+        if (jwtIssuer.encodedJwkSetPkcPublic.isEmpty() == false) {
             final String jwkSetPath; // file or HTTPS URL
             if (jwtIssuer.httpsServer == null) {
-                jwkSetPath = super.saveToTempFile("jwkset.", ".json", jwtIssuer.encodedJwkSetPkcPublic.getBytes(StandardCharsets.UTF_8));
+                jwkSetPath = super.saveToTempFile("jwkset.", ".json", jwtIssuer.encodedJwkSetPkcPublic);
             } else {
                 authcSettings.putList(
                     RealmSettings.getFullSettingKey(
@@ -377,13 +359,13 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
 
         // JWT authc realm secure settings
         final MockSecureSettings secureSettings = new MockSecureSettings();
-        if (Strings.hasText(jwtIssuer.encodedJwkSetHmac)) {
+        if (jwtIssuer.algAndJwksHmac.isEmpty() == false) {
             secureSettings.setString(
                 RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.HMAC_JWKSET),
                 jwtIssuer.encodedJwkSetHmac
             );
         }
-        if (Strings.hasText(jwtIssuer.encodedKeyHmacOidc)) {
+        if (jwtIssuer.encodedKeyHmacOidc != null) {
             secureSettings.setString(
                 RealmSettings.getFullSettingKey(authcRealmName, JwtRealmSettings.HMAC_KEY),
                 jwtIssuer.encodedKeyHmacOidc
@@ -435,7 +417,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         return jwtRealm;
     }
 
-    protected JwtIssuerAndRealm randomJwtIssuerRealmPair() {
+    protected JwtIssuerAndRealm randomJwtIssuerRealmPair() throws ParseException {
         // Select random JWT issuer and JWT realm pair, and log the realm settings
         assertThat(this.jwtIssuerAndRealms, is(notNullValue()));
         assertThat(this.jwtIssuerAndRealms, is(not(empty())));
@@ -444,41 +426,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         assertThat(jwtRealm, is(notNullValue()));
         assertThat(jwtRealm.allowedIssuer, is(equalTo(jwtIssuerAndRealm.issuer.issuerClaimValue))); // assert equal, don't print both
         assertThat(jwtIssuerAndRealm.issuer.audiencesClaimValue.stream().anyMatch(jwtRealm.allowedAudiences::contains), is(true));
-        LOGGER.info(
-            "REALM["
-                + jwtRealm.name()
-                + ","
-                + jwtRealm.order()
-                + "/"
-                + this.jwtIssuerAndRealms.size()
-                + "], iss=["
-                + jwtIssuerAndRealm.issuer
-                + "], iss.aud="
-                + jwtIssuerAndRealm.issuer.audiencesClaimValue
-                + ", realm.aud="
-                + jwtRealm.allowedAudiences
-                + ", HMAC alg="
-                + jwtRealm.jwksAlgsHmac.algs()
-                + ", PKC alg="
-                + jwtRealm.jwksAlgsPkc.algs()
-                + ", client=["
-                + jwtRealm.clientAuthenticationType
-                + "], meta=["
-                + jwtRealm.populateUserMetadata
-                + "], authz=["
-                + jwtRealm.delegatedAuthorizationSupport.hasDelegation()
-                + "], jwkSetPath=["
-                + jwtRealm.jwkSetPath
-                + "], claimPrincipal=["
-                + jwtRealm.claimParserPrincipal.getClaimName()
-                + "], claimGroups=["
-                + jwtRealm.claimParserGroups.getClaimName()
-                + "], clientAuthenticationSharedSecret=["
-                + jwtRealm.clientAuthenticationSharedSecret
-                + "], authz=["
-                + jwtRealm.delegatedAuthorizationSupport
-                + "]"
-        );
+        this.printJwtRealmAndIssuer(jwtIssuerAndRealm);
         return jwtIssuerAndRealm;
     }
 
@@ -498,9 +446,8 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         final int jwtAuthcRepeats = randomIntBetween(jwtAuthcRange.min(), jwtAuthcRange.max());
         for (int authcRun = 1; authcRun <= jwtAuthcRepeats; authcRun++) {
             // Create request with headers set
-            LOGGER.info("RUN[" + authcRun + "/" + jwtAuthcRepeats + "], jwt=[" + jwt + "], secret=[" + sharedSecret + "].");
             final ThreadContext requestThreadContext = super.createThreadContext(jwt, sharedSecret);
-            LOGGER.info(requestThreadContext.getHeaders().toString()); // TODO Remove debug log
+            LOGGER.info("REQ[" + authcRun + "/" + jwtAuthcRepeats + "] HEADERS=" + requestThreadContext.getHeaders());
 
             // Loop through all authc/authz realms. Confirm a JWT authc realm recognizes and extracts the request headers.
             JwtAuthenticationToken jwtAuthenticationToken = null;
@@ -527,15 +474,16 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                 }
                 assertThat(tokenSecret, is(equalTo(sharedSecret)));
             }
-            LOGGER.info("TOKEN[" + tokenPrincipal + "]: jwt=[" + tokenJwt + "], secret=[" + tokenSecret + "].");
+            LOGGER.info("GOT TOKEN: principal=[" + tokenPrincipal + "], jwt=[" + tokenJwt + "], secret=[" + tokenSecret + "].");
 
-            // Loop through all authc/authz realms. Confirm authenticatedUser is returned with expected principal and roles.
+            // Loop through all authc/authz realms. Confirm user is returned with expected principal and roles.
             User authenticatedUser = null;
             final List<String> realmAuthenticationResults = new ArrayList<>();
             final List<String> realmUsageStats = new ArrayList<>();
             final List<Exception> realmFailureExceptions = new ArrayList<>(jwtRealmsList.size());
             try {
                 for (final JwtRealm candidateJwtRealm : jwtRealmsList) {
+                    LOGGER.info("TRY AUTHC: expected=[" + jwtRealm.name() + "], candidate[" + candidateJwtRealm.name() + "].");
                     final PlainActionFuture<AuthenticationResult<User>> authenticateFuture = PlainActionFuture.newFuture();
                     try {
                         candidateJwtRealm.authenticate(jwtAuthenticationToken, authenticateFuture);
@@ -566,21 +514,21 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                         realmFailureExceptions.add(new Exception(realmResult, authenticationResultException));
                         switch (authenticationResult.getStatus()) {
                             case SUCCESS:
-                                assertThat(candidateJwtRealm.name(), is(equalTo(jwtRealm.name())));
-                                assertThat(authenticationResult.isAuthenticated(), is(equalTo(true)));
-                                assertThat(authenticationResult.getException(), is(nullValue()));
-                                assertThat(authenticationResult.getMessage(), is(nullValue()));
-                                assertThat(authenticationResult.getMetadata(), is(anEmptyMap()));
+                                assertThat("Unexpected realm SUCCESS status", candidateJwtRealm.name(), is(equalTo(jwtRealm.name())));
+                                assertThat("Expected realm authc false", authenticationResult.isAuthenticated(), is(equalTo(true)));
+                                assertThat("Expected realm exception thrown", authenticationResult.getException(), is(nullValue()));
+                                assertThat("Expected realm message null", authenticationResult.getMessage(), is(nullValue()));
+                                assertThat("Expected realm metadata empty", authenticationResult.getMetadata(), is(anEmptyMap()));
                                 authenticatedUser = authenticationResult.getValue();
-                                assertThat(authenticatedUser, is(notNullValue()));
+                                assertThat("Expected realm user null", authenticatedUser, is(notNullValue()));
                                 break;
                             case CONTINUE:
-                                assertThat(candidateJwtRealm.name(), is(not(equalTo(jwtRealm.name()))));
-                                assertThat(authenticationResult.isAuthenticated(), is(equalTo(false)));
+                                assertThat("Expected realm CONTINUE status", candidateJwtRealm.name(), is(not(equalTo(jwtRealm.name()))));
+                                assertThat("Unexpected realm authc success", authenticationResult.isAuthenticated(), is(equalTo(false)));
                                 continue;
                             case TERMINATE:
-                                assertThat(candidateJwtRealm.name(), is(not(equalTo(jwtRealm.name()))));
-                                assertThat(authenticationResult.isAuthenticated(), is(equalTo(false)));
+                                assertThat("Expected realm TERMINATE status", candidateJwtRealm.name(), is(not(equalTo(jwtRealm.name()))));
+                                assertThat("Unexpected realm authc success", authenticationResult.isAuthenticated(), is(equalTo(false)));
                                 break;
                             default:
                                 fail("Unexpected AuthenticationResult.Status=[" + authenticationResult.getStatus() + "]");
@@ -588,7 +536,8 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                         }
                         break; // Only SUCCESS falls through to here, break out of the loop
                     } catch (Exception e) {
-                        realmFailureExceptions.add(new Exception("Caught Exception.", e));
+                        realmFailureExceptions.add(e);
+                        throw e;
                     } finally {
                         final PlainActionFuture<Map<String, Object>> usageStatsFuture = PlainActionFuture.newFuture();
                         candidateJwtRealm.usageStats(usageStatsFuture);
@@ -605,7 +554,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                         );
                     }
                 }
-                // Loop ended. Confirm authenticatedUser is returned with expected principal and roles.
+                // Loop ended. Confirm user is returned with expected principal and roles.
                 assertThat("Expected realm " + jwtRealm.name() + " to authenticate.", authenticatedUser, is(notNullValue()));
                 assertThat(user.principal(), equalTo(authenticatedUser.principal()));
                 assertThat(new TreeSet<>(Arrays.asList(user.roles())), equalTo(new TreeSet<>(Arrays.asList(authenticatedUser.roles()))));
@@ -617,11 +566,9 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                     assertThat(authenticatedUser.metadata(), is(anEmptyMap())); // role mapping with flag false returns empty
                 }
             } catch (Throwable t) {
-                final Exception authcFailed = new Exception("Authentication test failed.");
-                realmFailureExceptions.forEach(authcFailed::addSuppressed); // realm exceptions
-                authcFailed.addSuppressed(t); // final throwable (ex: assertThat)
-                LOGGER.error("Unexpected exception.", authcFailed);
-                throw authcFailed;
+                realmFailureExceptions.forEach(t::addSuppressed); // all previous realm exceptions
+                // LOGGER.error("Unexpected exception.", t);
+                throw t;
             } finally {
                 LOGGER.info("STATS: expected=[" + jwtRealm.name() + "]\n" + String.join("\n", realmUsageStats));
                 if (authenticatedUser != null) {
@@ -640,11 +587,29 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
 
     protected SecureString randomJwt(final JwtIssuerAndRealm jwtIssuerAndRealm, User user) throws Exception {
         final JwtIssuer.AlgJwkPair algJwkPair = randomFrom(jwtIssuerAndRealm.issuer.algAndJwksAll);
-        LOGGER.info("JWK=[" + algJwkPair.jwk().getKeyType() + "/" + algJwkPair.jwk().size() + "], alg=[" + algJwkPair.alg() + "].");
+        final JWK jwk = algJwkPair.jwk();
+        LOGGER.info(
+            "ALG["
+                + algJwkPair.alg()
+                + "]. JWK: kty=["
+                + jwk.getKeyType()
+                + "], len=["
+                + jwk.size()
+                + "], alg=["
+                + jwk.getAlgorithm()
+                + "], use=["
+                + jwk.getKeyUse()
+                + "], ops=["
+                + jwk.getKeyOperations()
+                + "], kid=["
+                + jwk.getKeyID()
+                + "]."
+        );
 
         final Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
         final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            randomBoolean() ? null : JOSEObjectType.JWT.toString(),
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : jwk.getKeyID(), // kid
             algJwkPair.alg(), // alg
             randomAlphaOfLengthBetween(10, 20), // jwtID
             jwtIssuerAndRealm.realm.allowedIssuer, // iss
@@ -654,16 +619,89 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
             user.principal(), // principal claim value
             jwtIssuerAndRealm.realm.claimParserGroups.getClaimName(), // group claim name
             List.of(user.roles()), // group claim value
-            Date.from(now.minusSeconds(randomLongBetween(10, 20))), // auth_time
-            Date.from(now), // iat
-            Date.from(now.minusSeconds(randomLongBetween(5, 10))), // nbf
-            Date.from(now.plusSeconds(randomLongBetween(3600, 7200))), // exp
+            Date.from(now.minusSeconds(60 * randomLongBetween(10, 20))), // auth_time
+            Date.from(now.minusSeconds(randomBoolean() ? 0 : 60 * randomLongBetween(5, 10))), // iat
+            Date.from(now), // nbf
+            Date.from(now.plusSeconds(60 * randomLongBetween(3600, 7200))), // exp
             randomBoolean() ? null : new Nonce(32).toString(),
             randomBoolean() ? null : Map.of("other1", randomAlphaOfLength(10), "other2", randomAlphaOfLength(10))
         );
-        final SecureString signedJWT = JwtValidateUtil.signJwt(algJwkPair.jwk(), unsignedJwt);
-        assertThat(JwtValidateUtil.verifyJwt(algJwkPair.jwk(), SignedJWT.parse(signedJWT.toString())), is(equalTo(true)));
+        final SecureString signedJWT = JwtValidateUtil.signJwt(jwk, unsignedJwt);
+        assertThat(JwtValidateUtil.verifyJwt(jwk, SignedJWT.parse(signedJWT.toString())), is(equalTo(true)));
         return signedJWT;
     }
 
+    protected void printJwtRealmAndIssuer(JwtIssuerAndRealm jwtIssuerAndRealm) throws ParseException {
+        this.printJwtIssuer(jwtIssuerAndRealm.issuer());
+        this.printJwtRealm(jwtIssuerAndRealm.realm());
+    }
+
+    protected void printJwtRealm(final JwtRealm jwtRealm) {
+        LOGGER.info(
+            "REALM["
+                + jwtRealm.name()
+                + ","
+                + jwtRealm.order()
+                + "/"
+                + this.jwtIssuerAndRealms.size()
+                + "]: clientType=["
+                + jwtRealm.clientAuthenticationType
+                + "], clientSecret=["
+                + jwtRealm.clientAuthenticationSharedSecret
+                + "], iss=["
+                + jwtRealm.allowedIssuer
+                + "], aud="
+                + jwtRealm.allowedAudiences
+                + ", algsHmac="
+                + jwtRealm.allowedJwksAlgsHmac
+                + ", filteredHmac="
+                + jwtRealm.contentAndJwksAlgsHmac.jwksAlgs().algs()
+                + ", algsPkc="
+                + jwtRealm.allowedJwksAlgsPkc
+                + ", filteredPkc="
+                + jwtRealm.contentAndJwksAlgsPkc.jwksAlgs().algs()
+                + ", claimPrincipal=["
+                + jwtRealm.claimParserPrincipal.getClaimName()
+                + "], claimGroups=["
+                + jwtRealm.claimParserGroups.getClaimName()
+                + "], authz=["
+                + jwtRealm.delegatedAuthorizationSupport.hasDelegation()
+                + "], meta=["
+                + jwtRealm.populateUserMetadata
+                + "], jwkSetPath=["
+                + jwtRealm.jwkSetPath
+                + "]."
+        );
+        for (final JWK jwk : jwtRealm.contentAndJwksAlgsHmac.jwksAlgs().jwks()) {
+            LOGGER.info("REALM HMAC: jwk=[{}]", jwk);
+        }
+        for (final JWK jwk : jwtRealm.contentAndJwksAlgsPkc.jwksAlgs().jwks()) {
+            LOGGER.info("REALM PKC: jwk=[{}]", jwk);
+        }
+    }
+
+    protected void printJwtIssuer(final JwtIssuer jwtIssuer) {
+        LOGGER.info(
+            "ISSUER: iss=["
+                + jwtIssuer.issuerClaimValue
+                + "], aud=["
+                + String.join(",", jwtIssuer.audiencesClaimValue)
+                + "], principal=["
+                + jwtIssuer.principalClaimName
+                + "], algorithms=["
+                + String.join(",", jwtIssuer.algorithmsAll)
+                + "], httpServer=["
+                + (jwtIssuer.httpsServer != null)
+                + "]."
+        );
+        if (jwtIssuer.algAndJwkHmacOidc != null) {
+            LOGGER.info("ISSUER HMAC OIDC: alg=[{}] jwk=[{}]", jwtIssuer.algAndJwkHmacOidc.alg(), jwtIssuer.encodedKeyHmacOidc);
+        }
+        for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksHmac) {
+            LOGGER.info("ISSUER HMAC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
+        }
+        for (final JwtIssuer.AlgJwkPair pair : jwtIssuer.algAndJwksPkc) {
+            LOGGER.info("ISSUER PKC: alg=[{}] jwk=[{}]", pair.alg(), pair.jwk());
+        }
+    }
 }

+ 76 - 84
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTestCase.java

@@ -49,7 +49,6 @@ import org.elasticsearch.xpack.core.security.user.User;
 import org.elasticsearch.xpack.core.ssl.SSLConfigurationSettings;
 import org.junit.Before;
 import org.mockito.Mockito;
-import org.mockito.stubbing.Answer;
 
 import java.io.IOException;
 import java.net.URL;
@@ -59,17 +58,14 @@ import java.nio.file.Path;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Date;
-import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
-import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.IntStream;
 
 import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
@@ -241,16 +237,6 @@ public abstract class JwtTestCase extends ESTestCase {
         return new RealmConfig(realmIdentifier, settings, this.env, this.threadContext);
     }
 
-    protected Answer<Class<Void>> getAnswer(AtomicReference<UserRoleMapper.UserData> userData) {
-        return invocation -> {
-            userData.set((UserRoleMapper.UserData) invocation.getArguments()[0]);
-            @SuppressWarnings("unchecked")
-            ActionListener<Set<String>> listener = (ActionListener<Set<String>>) invocation.getArguments()[1];
-            listener.onResponse(new HashSet<>(Arrays.asList("kibana_user", "role1")));
-            return null;
-        };
-    }
-
     protected UserRoleMapper buildRoleMapper(final Map<String, User> registeredUsers) {
         final UserRoleMapper roleMapper = mock(UserRoleMapper.class);
         Mockito.doAnswer(invocation -> {
@@ -268,18 +254,19 @@ public abstract class JwtTestCase extends ESTestCase {
         return roleMapper;
     }
 
-    public static List<JwtIssuer.AlgJwkPair> randomJwks(final List<String> signatureAlgorithms) throws JOSEException {
+    public static List<JwtIssuer.AlgJwkPair> randomJwks(final List<String> signatureAlgorithms, final boolean requireOidcSafe)
+        throws JOSEException {
         final List<JwtIssuer.AlgJwkPair> algAndJwks = new ArrayList<>();
         for (final String signatureAlgorithm : signatureAlgorithms) {
-            algAndJwks.add(new JwtIssuer.AlgJwkPair(signatureAlgorithm, JwtTestCase.randomJwk(signatureAlgorithm)));
+            algAndJwks.add(new JwtIssuer.AlgJwkPair(signatureAlgorithm, JwtTestCase.randomJwk(signatureAlgorithm, requireOidcSafe)));
         }
         return algAndJwks;
     }
 
-    public static JWK randomJwk(final String signatureAlgorithm) throws JOSEException {
+    public static JWK randomJwk(final String signatureAlgorithm, final boolean requireOidcSafe) throws JOSEException {
         final JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm);
         if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC.contains(signatureAlgorithm)) {
-            return JwtTestCase.randomJwkHmac(jwsAlgorithm);
+            return JwtTestCase.randomJwkHmac(jwsAlgorithm, requireOidcSafe);
         } else if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_RSA.contains(signatureAlgorithm)) {
             return JwtTestCase.randomJwkRsa(jwsAlgorithm);
         } else if (JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_EC.contains(signatureAlgorithm)) {
@@ -294,14 +281,31 @@ public abstract class JwtTestCase extends ESTestCase {
         );
     }
 
-    // Generate using random bytes
-    // - random byte => 2^8 => search space 8-bit per byte
-    public static OctetSequenceKey randomJwkHmac(final JWSAlgorithm jwsAlgorithm) throws JOSEException {
+    public static OctetSequenceKey randomJwkHmac(final JWSAlgorithm jwsAlgorithm, final boolean requireOidcSafe) throws JOSEException {
         final int minHmacLengthBytes = MACSigner.getMinRequiredSecretLength(jwsAlgorithm) / 8;
-        final int hmacLengthBits = scaledRandomIntBetween(minHmacLengthBytes, minHmacLengthBytes * 2) * 8; // Double it: Nice to have
-        final OctetSequenceKeyGenerator jwkGenerator = new OctetSequenceKeyGenerator(hmacLengthBits);
-        JwtTestCase.randomSettingsForJwkGenerator(jwkGenerator, jwsAlgorithm); // options: kid, alg, use, ops
-        return jwkGenerator.generate();
+        final int hmacLengthBytes = scaledRandomIntBetween(minHmacLengthBytes, minHmacLengthBytes * 2); // Double it: Nice to have
+        if (requireOidcSafe == false && randomBoolean()) {
+            // random byte => 2^8 search space per 1 byte => 8 bits per byte
+            final OctetSequenceKeyGenerator jwkGenerator = new OctetSequenceKeyGenerator(hmacLengthBytes * 8);
+            return JwtTestCase.randomSettingsForJwkGenerator(jwkGenerator, jwsAlgorithm).generate().toOctetSequenceKey(); // kid,alg,use,ops
+        }
+        final String passwordKey;
+        if (randomBoolean()) {
+            // Base 64 byte => 2^6 search space per 1 byte => 6 bits per byte
+            passwordKey = Base64URL.encode(randomByteArrayOfLength(hmacLengthBytes)).toString();
+        } else {
+            // UTF8 1 byte => 2^7 search space per 1 byte => 7 bits per byte
+            // UTF8 2 byte => 2^11 search space per 2 byte => 5.5 bits per byte
+            // UTF8 3 byte => 2^16 search space per 3 byte => 5.333 bits per byte
+            // UTF8 4 byte => 2^21 search space per 4 byte => 5.25 bits per byte (theoretical, UNICODE currently only allocates 1.1M of 2M)
+            passwordKey = randomAlphaOfLength(hmacLengthBytes);
+        }
+        final OctetSequenceKey.Builder hmacKeyBuilder = new OctetSequenceKey.Builder(passwordKey.getBytes(StandardCharsets.UTF_8));
+        return JwtTestCase.randomSettingsForHmacJwkBuilder(hmacKeyBuilder, jwsAlgorithm).build(); // kid,alg,use,ops
+    }
+
+    public static OctetSequenceKey randomJwkHmacOidcSafe(final JWSAlgorithm jwsAlgorithm) throws JOSEException {
+        return JwtTestCase.randomJwkHmac(jwsAlgorithm, true);
     }
 
     public static RSAKey randomJwkRsa(final JWSAlgorithm jwsAlgorithm) throws JOSEException {
@@ -318,46 +322,6 @@ public abstract class JwtTestCase extends ESTestCase {
         return jwkGenerator.generate();
     }
 
-    public static OctetSequenceKey randomJwkHmacOidc(final JWSAlgorithm jwsAlgorithm) throws JOSEException {
-        return JwtTestCase.conditionJwkHmacForOidc(JwtTestCase.randomJwkHmac(jwsAlgorithm));
-    }
-
-    /**
-     *  Input HMAC key is assumed random bytes. Generating random bytes is useful to guarantee min search space (aka strength, entropy).
-     *
-     *  OIDC HMAC key must be UTF8 bytes (aka password). Encoding random bytes as UTF8 doesn't work, and UTF8 search space is smaller.
-     *
-     *  To satisfy min search space and OIDC UTF8 encoding, Base64(randomBytes) is used as the bytes of a new HMAC OIDC key.
-     *
-     *  Search space comparisons of random bytes, base 64, and UTF-8.
-     *  - random byte => 2^8 search space per 1 byte => 8 bits per byte
-     *  - Base 64 byte => 2^6 search space per 1 byte => 6 bits per byte
-     *  - UTF8 1 byte => 2^7 search space per 1 byte => 7 bits per byte
-     *  - UTF8 2 byte => 2^11 search space per 2 byte => 5.5 bits per byte
-     *  - UTF8 3 byte => 2^16 search space per 3 byte => 5.333 bits per byte
-     *  - UTF8 4 byte => 2^21 search space per 4 byte => 5.25 bits per byte (theoretical, UNICODE currently only allocates 1.1M of 2M)
-     *
-     * @param hmacKey HMAC key with random bytes.
-     * @return HMAC key with UTF-8 bytes, making the key bytes compatible with OIDC UTF-8 string encoding.
-     */
-    public static OctetSequenceKey conditionJwkHmacForOidc(final OctetSequenceKey hmacKey) {
-        final String passwordKey;
-        if (randomBoolean()) {
-            final Base64URL hmacKeyBytesBase64 = hmacKey.getKeyValue(); // Random bytes => 8 bits/byte search space
-            passwordKey = hmacKeyBytesBase64.toString(); // Use Base64(randomBytes) as UTF8 bytes for a new password with same search space
-        } else {
-            final int numLetters = hmacKey.toByteArray().length * 8; // Random [A-Za-z] => 5.7 bits/byte
-            passwordKey = randomAlphaOfLength(numLetters); // Use length * ceil(2.3) to avoid reducing search space below 8 bits/byte
-        }
-        final OctetSequenceKey.Builder hmacKeyBuilder = new OctetSequenceKey.Builder(passwordKey.getBytes(StandardCharsets.UTF_8));
-        hmacKeyBuilder.keyID(hmacKey.getKeyID()); // Copy null attribute is OK (no-op)
-        hmacKeyBuilder.algorithm(hmacKey.getAlgorithm());
-        hmacKeyBuilder.keyUse(hmacKey.getKeyUse());
-        hmacKeyBuilder.keyOperations(hmacKey.getKeyOperations());
-        hmacKeyBuilder.keyStore(hmacKey.getKeyStore());
-        return hmacKeyBuilder.build();
-    }
-
     public static OctetSequenceKey jwkHmacRemoveAttributes(final OctetSequenceKey hmacKey) {
         final String keyBytesAsUtf8 = hmacKey.getKeyValue().decodeToString();
         return new OctetSequenceKey.Builder(keyBytesAsUtf8.getBytes(StandardCharsets.UTF_8)).build();
@@ -382,9 +346,29 @@ public abstract class JwtTestCase extends ESTestCase {
         return jwkGenerator;
     }
 
+    public static OctetSequenceKey.Builder randomSettingsForHmacJwkBuilder(
+        final OctetSequenceKey.Builder jwkGenerator,
+        final JWSAlgorithm jwsAlgorithm
+    ) {
+        if (randomBoolean()) {
+            jwkGenerator.keyID(UUID.randomUUID().toString());
+        }
+        if (randomBoolean()) {
+            jwkGenerator.algorithm(jwsAlgorithm);
+        }
+        if (randomBoolean()) {
+            jwkGenerator.keyUse(KeyUse.SIGNATURE);
+        }
+        if (randomBoolean()) {
+            jwkGenerator.keyOperations(Set.of(KeyOperation.SIGN, KeyOperation.VERIFY));
+        }
+        return jwkGenerator;
+    }
+
     public static SignedJWT buildUnsignedJwt(
         final String type,
-        final String signatureAlgorithm,
+        final String kid,
+        final String alg,
         final String jwtId,
         final String issuer,
         final List<String> audiences,
@@ -400,7 +384,10 @@ public abstract class JwtTestCase extends ESTestCase {
         final String nonce,
         final Map<String, Object> otherClaims
     ) {
-        final JWSHeader.Builder jwsHeaderBuilder = new JWSHeader.Builder(JWSAlgorithm.parse(signatureAlgorithm));
+        final JWSHeader.Builder jwsHeaderBuilder = new JWSHeader.Builder(JWSAlgorithm.parse(alg));
+        if (kid != null) {
+            jwsHeaderBuilder.keyID(kid);
+        }
         if (type != null) {
             jwsHeaderBuilder.type(new JOSEObjectType(type));
         }
@@ -447,20 +434,22 @@ public abstract class JwtTestCase extends ESTestCase {
         if (otherClaims != null) {
             for (final Map.Entry<String, Object> entry : otherClaims.entrySet()) {
                 if (Strings.hasText(entry.getKey()) == false) {
-                    throw new IllegalArgumentException("Null or blank other claim key allowed.");
+                    throw new IllegalArgumentException("Null or blank other claim key not allowed.");
                 } else if (entry.getValue() == null) {
-                    throw new IllegalArgumentException("Null other claim value allowed.");
+                    throw new IllegalArgumentException("Null other claim value not allowed.");
                 }
                 jwtClaimsSetBuilder.claim(entry.getKey(), entry.getValue());
             }
         }
         final JWTClaimsSet jwtClaimsSet = jwtClaimsSetBuilder.build();
         LOGGER.info(
-            "CLAIMS: alg=["
-                + jwtHeader.getAlgorithm().getName()
-                + "], jwtId=["
-                + jwtClaimsSet.getJWTID()
-                + "], iss=["
+            "JWT: HEADER{alg=["
+                + jwtHeader.getAlgorithm()
+                + "], kid=["
+                + jwtHeader.getKeyID()
+                + "], kty=["
+                + jwtHeader.getType()
+                + "]}. CLAIMS: {iss=["
                 + jwtClaimsSet.getIssuer()
                 + "], aud="
                 + jwtClaimsSet.getAudience()
@@ -474,19 +463,21 @@ public abstract class JwtTestCase extends ESTestCase {
                 + groupsClaimName
                 + "="
                 + jwtClaimsSet.getClaim(groupsClaimName)
-                + "], nbf=["
-                + jwtClaimsSet.getNotBeforeTime()
                 + "], auth_time=["
                 + jwtClaimsSet.getClaim("auth_time")
                 + "], iat=["
                 + jwtClaimsSet.getIssueTime()
+                + "], nbf=["
+                + jwtClaimsSet.getNotBeforeTime()
                 + "], exp=["
                 + jwtClaimsSet.getExpirationTime()
                 + "], nonce=["
                 + jwtClaimsSet.getClaim("nonce")
+                + "], jid=["
+                + jwtClaimsSet.getJWTID()
                 + "], other=["
                 + otherClaims
-                + "]"
+                + "]}."
         );
         return JwtValidateUtil.buildUnsignedJwt(jwtHeader, jwtClaimsSet);
     }
@@ -494,7 +485,8 @@ public abstract class JwtTestCase extends ESTestCase {
     public static SecureString randomBespokeJwt(final JWK jwk, final String signatureAlgorithm) throws Exception {
         final Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
         final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            randomBoolean() ? null : JOSEObjectType.JWT.toString(),
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : jwk.getKeyID(), // kid
             signatureAlgorithm, // alg
             randomAlphaOfLengthBetween(10, 20), // jwtID
             randomFrom("https://www.example.com/", "") + "iss1" + randomIntBetween(0, 99),
@@ -504,10 +496,10 @@ public abstract class JwtTestCase extends ESTestCase {
             "principal1", // principal claim value
             randomBoolean() ? null : randomFrom("groups", "roles", "other"),
             randomFrom(List.of(""), List.of("grp1"), List.of("rol1", "rol2", "rol3"), List.of("per1")),
-            Date.from(now.minusSeconds(randomLongBetween(10, 20))), // auth_time
-            Date.from(now), // iat
-            Date.from(now.minusSeconds(randomLongBetween(5, 10))), // nbf
-            Date.from(now.plusSeconds(randomLongBetween(3600, 7200))), // exp
+            Date.from(now.minusSeconds(60 * randomLongBetween(10, 20))), // auth_time
+            Date.from(now.minusSeconds(randomBoolean() ? 0 : 60 * randomLongBetween(5, 10))), // iat
+            Date.from(now), // nbf
+            Date.from(now.plusSeconds(60 * randomLongBetween(3600, 7200))), // exp
             randomBoolean() ? null : new Nonce(32).toString(),
             randomBoolean() ? null : Map.of("other1", randomAlphaOfLength(10), "other2", randomAlphaOfLength(10))
         );
@@ -573,9 +565,9 @@ public abstract class JwtTestCase extends ESTestCase {
         return IntStream.rangeClosed(1, minToMaxInclusive).mapToObj(i -> randomFrom(collection)).toList(); // 1..N inclusive
     }
 
-    public String saveToTempFile(final String prefix, final String suffix, final byte[] content) throws IOException {
+    public String saveToTempFile(final String prefix, final String suffix, final String content) throws IOException {
         final Path path = Files.createTempFile(PathUtils.get(this.pathHome), prefix, suffix);
-        Files.write(path, content);
+        Files.writeString(path, content);
         return path.toString();
     }
 

+ 8 - 6
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtValidateUtilTests.java

@@ -26,10 +26,9 @@ public class JwtValidateUtilTests extends JwtTestCase {
 
     private static final Logger LOGGER = LogManager.getLogger(JwtValidateUtilTests.class);
 
-    private boolean helpTestSignatureAlgorithm(final String signatureAlgorithm) throws Exception {
-        LOGGER.info("Testing signature algorithm " + signatureAlgorithm);
-        // randomSecretOrSecretKeyOrKeyPair() randomizes which JwtUtil methods to call, so it indirectly covers most JwtUtil code
-        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm);
+    private boolean helpTestSignatureAlgorithm(final String signatureAlgorithm, final boolean requireOidcSafe) throws Exception {
+        LOGGER.trace("Testing signature algorithm " + signatureAlgorithm);
+        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm, requireOidcSafe);
         final SecureString serializedJWTOriginal = JwtTestCase.randomBespokeJwt(jwk, signatureAlgorithm);
         final SignedJWT parsedSignedJWT = SignedJWT.parse(serializedJWTOriginal.toString());
         return JwtValidateUtil.verifyJwt(jwk, parsedSignedJWT);
@@ -38,10 +37,13 @@ public class JwtValidateUtilTests extends JwtTestCase {
     public void testJwtSignVerifyPassedForAllSupportedAlgorithms() throws Exception {
         // Pass: "ES256", "ES384", "ES512", RS256", "RS384", "RS512", "PS256", "PS384", "PS512, "HS256", "HS384", "HS512"
         for (final String signatureAlgorithm : JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS) {
-            assertThat(this.helpTestSignatureAlgorithm(signatureAlgorithm), is(true));
+            assertThat(this.helpTestSignatureAlgorithm(signatureAlgorithm, false), is(true));
         }
         // Fail: "ES256K"
-        final Exception exp1 = expectThrows(JOSEException.class, () -> this.helpTestSignatureAlgorithm(JWSAlgorithm.ES256K.getName()));
+        final Exception exp1 = expectThrows(
+            JOSEException.class,
+            () -> this.helpTestSignatureAlgorithm(JWSAlgorithm.ES256K.getName(), false)
+        );
         final String msg1 = "Unsupported signature algorithm ["
             + JWSAlgorithm.ES256K
             + "]. Supported signature algorithms are "