|  | @@ -17,6 +17,7 @@ import com.nimbusds.jose.jwk.ECKey;
 | 
	
		
			
				|  |  |  import com.nimbusds.jose.jwk.JWK;
 | 
	
		
			
				|  |  |  import com.nimbusds.jose.jwk.OctetSequenceKey;
 | 
	
		
			
				|  |  |  import com.nimbusds.jose.jwk.RSAKey;
 | 
	
		
			
				|  |  | +import com.nimbusds.jose.util.Base64URL;
 | 
	
		
			
				|  |  |  import com.nimbusds.jwt.SignedJWT;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import org.apache.logging.log4j.LogManager;
 | 
	
	
		
			
				|  | @@ -39,6 +40,7 @@ import org.elasticsearch.xpack.core.ssl.SSLService;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.util.Arrays;
 | 
	
		
			
				|  |  |  import java.util.List;
 | 
	
		
			
				|  |  | +import java.util.function.Supplier;
 | 
	
		
			
				|  |  |  import java.util.stream.Stream;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  public interface JwtSignatureValidator extends Releasable {
 | 
	
	
		
			
				|  | @@ -274,26 +276,33 @@ public interface JwtSignatureValidator extends Releasable {
 | 
	
		
			
				|  |  |                  validateSignature(signedJWT, jwksAlgs.jwks());
 | 
	
		
			
				|  |  |                  listener.onResponse(null);
 | 
	
		
			
				|  |  |              } catch (Exception primaryException) {
 | 
	
		
			
				|  |  | -                logger.debug(
 | 
	
		
			
				|  |  | -                    () -> org.elasticsearch.core.Strings.format(
 | 
	
		
			
				|  |  | -                        "Signature verification failed for JWT [%s] reloading JWKSet (was: #[%s] JWKs, #[%s] algs, sha256=[%s])",
 | 
	
		
			
				|  |  | -                        tokenPrincipal,
 | 
	
		
			
				|  |  | -                        jwksAlgs.jwks().size(),
 | 
	
		
			
				|  |  | -                        jwksAlgs.algs().size(),
 | 
	
		
			
				|  |  | -                        MessageDigests.toHexString(contentAndJwksAlgs.sha256())
 | 
	
		
			
				|  |  | -                    ),
 | 
	
		
			
				|  |  | -                    primaryException
 | 
	
		
			
				|  |  | +                String message = org.elasticsearch.core.Strings.format(
 | 
	
		
			
				|  |  | +                    "Signature verification failed for JWT token [%s] against JWK set with sha256=[%s].",
 | 
	
		
			
				|  |  | +                    tokenPrincipal,
 | 
	
		
			
				|  |  | +                    MessageDigests.toHexString(contentAndJwksAlgs.sha256())
 | 
	
		
			
				|  |  |                  );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +                if (logger.isTraceEnabled()) {
 | 
	
		
			
				|  |  | +                    logger.trace(message, primaryException);
 | 
	
		
			
				|  |  | +                } else {
 | 
	
		
			
				|  |  | +                    logger.debug(message + " Cause: " + primaryException.getMessage());
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                logger.debug("Attempting to reload JWK set with sha256=[{}]", MessageDigests.toHexString(contentAndJwksAlgs.sha256()));
 | 
	
		
			
				|  |  |                  jwkSetLoader.reload(ActionListener.wrap(ignore -> {
 | 
	
		
			
				|  |  |                      final JwkSetLoader.ContentAndJwksAlgs maybeUpdatedContentAndJwksAlgs = jwkSetLoader.getContentAndJwksAlgs();
 | 
	
		
			
				|  |  |                      if (Arrays.equals(maybeUpdatedContentAndJwksAlgs.sha256(), initialJwksVersion)) {
 | 
	
		
			
				|  |  |                          logger.debug(
 | 
	
		
			
				|  |  |                              "No change in reloaded JWK set with sha256=[{}] will not retry signature verification",
 | 
	
		
			
				|  |  | -                            MessageDigests.toHexString(jwkSetLoader.getContentAndJwksAlgs().sha256())
 | 
	
		
			
				|  |  | +                            MessageDigests.toHexString(maybeUpdatedContentAndJwksAlgs.sha256())
 | 
	
		
			
				|  |  |                          );
 | 
	
		
			
				|  |  |                          listener.onFailure(primaryException);
 | 
	
		
			
				|  |  |                          return;
 | 
	
		
			
				|  |  | +                    } else {
 | 
	
		
			
				|  |  | +                        logger.debug(
 | 
	
		
			
				|  |  | +                            "Successful reload of JWK set. Now with sha256=[{}]",
 | 
	
		
			
				|  |  | +                            MessageDigests.toHexString(maybeUpdatedContentAndJwksAlgs.sha256())
 | 
	
		
			
				|  |  | +                        );
 | 
	
		
			
				|  |  |                      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |                      // If all PKC JWKs were replaced, all PKC JWT cache entries need to be invalidated.
 | 
	
	
		
			
				|  | @@ -342,66 +351,86 @@ public interface JwtSignatureValidator extends Releasable {
 | 
	
		
			
				|  |  |       * @throws Exception Error if JWKs fail to validate the Signed JWT.
 | 
	
		
			
				|  |  |       */
 | 
	
		
			
				|  |  |      default void validateSignature(final SignedJWT jwt, final List<JWK> jwks) throws Exception {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          assert jwks != null : "Verify requires a non-null JWK list";
 | 
	
		
			
				|  |  |          if (jwks.isEmpty()) {
 | 
	
		
			
				|  |  | -            throw new ElasticsearchException("Verify requires a non-empty JWK list");
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -        final String id = jwt.getHeader().getKeyID();
 | 
	
		
			
				|  |  | -        final JWSAlgorithm alg = jwt.getHeader().getAlgorithm();
 | 
	
		
			
				|  |  | -        logger.trace("JWKs [{}], JWT KID [{}], and JWT Algorithm [{}] before filters.", jwks.size(), id, alg.getName());
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // If JWT has optional kid header, and realm JWKs have optional kid attribute, any mismatches JWT.kid vs JWK.kid can be ignored.
 | 
	
		
			
				|  |  | -        // Keep any JWKs if JWK optional kid attribute is missing. Keep all JWKs if JWT optional kid header is missing.
 | 
	
		
			
				|  |  | -        final List<JWK> jwksKid = jwks.stream().filter(j -> ((id == null) || (j.getKeyID() == null) || (id.equals(j.getKeyID())))).toList();
 | 
	
		
			
				|  |  | -        logger.trace("JWKs [{}] after KID [{}](|null) filter.", jwksKid.size(), id);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // JWT has mandatory alg header. If realm JWKs have optional alg attribute, any mismatches JWT.alg vs JWK.alg can be ignored.
 | 
	
		
			
				|  |  | -        // Keep any JWKs if JWK optional alg attribute is missing.
 | 
	
		
			
				|  |  | -        final List<JWK> jwksAlg = jwksKid.stream().filter(j -> (j.getAlgorithm() == null) || (alg.equals(j.getAlgorithm()))).toList();
 | 
	
		
			
				|  |  | -        logger.trace("JWKs [{}] after Algorithm [{}](|null) filter.", jwksAlg.size(), alg.getName());
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // PKC Example: Realm has five PKC JWKs RSA-2048, RSA-3072, EC-P256, EC-P384, and EC-P512. JWT alg allows ignoring some.
 | 
	
		
			
				|  |  | -        // - If JWT alg is RS256, only RSA-2048 and RSA-3072 are valid for a JWT RS256 signature. Ignore three EC JWKs.
 | 
	
		
			
				|  |  | -        // - If JWT alg is ES512, only EC-P512 is valid for a JWT ES512 signature. Ignore four JWKs (two RSA, two EC).
 | 
	
		
			
				|  |  | -        // - If JWT alg is ES384, only EC-P384 is valid for a JWT ES384 signature. Ignore four JWKs (two RSA, two EC).
 | 
	
		
			
				|  |  | -        // - If JWT alg is ES256, only EC-P256 is valid for a JWT ES256 signature. Ignore four JWKs (two RSA, two EC).
 | 
	
		
			
				|  |  | -        //
 | 
	
		
			
				|  |  | -        // HMAC Example: Realm has six HMAC JWKs of bit lengths 256, 320, 384, 400, 512, and 1000. JWT alg allows ignoring some.
 | 
	
		
			
				|  |  | -        // - If JWT alg is HS256, all are valid for a JWT HS256 signature. Don't ignore any HMAC JWKs.
 | 
	
		
			
				|  |  | -        // - If JWT alg is HS384, only 384, 400, 512, and 1000 are valid for a JWT HS384 signature. Ignore two HMAC JWKs.
 | 
	
		
			
				|  |  | -        // - If JWT alg is HS512, only 512 and 1000 are valid for a JWT HS512 signature. Ignore four HMAC JWKs.
 | 
	
		
			
				|  |  | -        final List<JWK> jwksStrength = jwksAlg.stream().filter(j -> JwkValidateUtil.isMatch(j, alg.getName())).toList();
 | 
	
		
			
				|  |  | -        logger.debug("JWKs [{}] after Algorithm [{}] match filter.", jwksStrength.size(), alg);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // No JWKs passed the kid, alg, and strength checks, so nothing left to use in verifying the JWT signature
 | 
	
		
			
				|  |  | -        if (jwksStrength.isEmpty()) {
 | 
	
		
			
				|  |  | -            throw new ElasticsearchException("Verify failed because all " + jwks.size() + " provided JWKs were filtered.");
 | 
	
		
			
				|  |  | +            throw new ElasticsearchException("Signature verification was not attempted since there are not any JWKs available.");
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        for (final JWK jwk : jwksStrength) {
 | 
	
		
			
				|  |  | -            if (jwt.verify(createJwsVerifier(jwk))) {
 | 
	
		
			
				|  |  | -                logger.trace(
 | 
	
		
			
				|  |  | -                    "JWT signature validation succeeded with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops=[{}]",
 | 
	
		
			
				|  |  | -                    jwk.getKeyType(),
 | 
	
		
			
				|  |  | -                    jwk.getAlgorithm(),
 | 
	
		
			
				|  |  | -                    jwk.getKeyID(),
 | 
	
		
			
				|  |  | -                    jwk.getKeyUse(),
 | 
	
		
			
				|  |  | -                    jwk.getKeyOperations()
 | 
	
		
			
				|  |  | -                );
 | 
	
		
			
				|  |  | -                return;
 | 
	
		
			
				|  |  | -            } else {
 | 
	
		
			
				|  |  | -                logger.trace(
 | 
	
		
			
				|  |  | -                    "JWT signature validation failed with JWK kty=[{}], jwtAlg=[{}], jwtKid=[{}], use=[{}], ops={}",
 | 
	
		
			
				|  |  | -                    jwk.getKeyType(),
 | 
	
		
			
				|  |  | -                    jwk.getAlgorithm(),
 | 
	
		
			
				|  |  | -                    jwk.getKeyID(),
 | 
	
		
			
				|  |  | -                    jwk.getKeyUse(),
 | 
	
		
			
				|  |  | -                    jwk.getKeyOperations() == null ? "[null]" : jwk.getKeyOperations()
 | 
	
		
			
				|  |  | +        try (JwtUtil.TraceBuffer tracer = new JwtUtil.TraceBuffer(logger)) {
 | 
	
		
			
				|  |  | +            final String id = jwt.getHeader().getKeyID();
 | 
	
		
			
				|  |  | +            final JWSAlgorithm alg = jwt.getHeader().getAlgorithm();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            tracer.append("Filtering [{}] possible JWKs to verifying signature for JWT [{}].", jwks.size(), getSafePrintableJWT(jwt));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            // If JWT has optional kid header, and realm JWKs have optional kid attribute, any mismatches JWT.kid vs JWK.kid can be ignored.
 | 
	
		
			
				|  |  | +            // Keep any JWKs if JWK optional kid attribute is missing. Keep all JWKs if JWT optional kid header is missing.
 | 
	
		
			
				|  |  | +            final List<JWK> jwksKid = jwks.stream()
 | 
	
		
			
				|  |  | +                .filter(j -> ((id == null) || (j.getKeyID() == null) || (id.equals(j.getKeyID()))))
 | 
	
		
			
				|  |  | +                .toList();
 | 
	
		
			
				|  |  | +            tracer.append("[{}] JWKs remain after filtering for KID [{}].", jwksKid.size(), id);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            // JWT has mandatory alg header. If realm JWKs have optional alg attribute, any mismatches JWT.alg vs JWK.alg can be ignored.
 | 
	
		
			
				|  |  | +            // Keep any JWKs if JWK optional alg attribute is missing.
 | 
	
		
			
				|  |  | +            final List<JWK> jwksAlg = jwksKid.stream().filter(j -> (j.getAlgorithm() == null) || (alg.equals(j.getAlgorithm()))).toList();
 | 
	
		
			
				|  |  | +            tracer.append("[{}] algorithms remain after filtering for algorithm name [{}].", jwksAlg.size(), alg.getName());
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            // PKC Example: Realm has five PKC JWKs RSA-2048, RSA-3072, EC-P256, EC-P384, and EC-P512. JWT alg allows ignoring some.
 | 
	
		
			
				|  |  | +            // - If JWT alg is RS256, only RSA-2048 and RSA-3072 are valid for a JWT RS256 signature. Ignore three EC JWKs.
 | 
	
		
			
				|  |  | +            // - If JWT alg is ES512, only EC-P512 is valid for a JWT ES512 signature. Ignore four JWKs (two RSA, two EC).
 | 
	
		
			
				|  |  | +            // - If JWT alg is ES384, only EC-P384 is valid for a JWT ES384 signature. Ignore four JWKs (two RSA, two EC).
 | 
	
		
			
				|  |  | +            // - If JWT alg is ES256, only EC-P256 is valid for a JWT ES256 signature. Ignore four JWKs (two RSA, two EC).
 | 
	
		
			
				|  |  | +            //
 | 
	
		
			
				|  |  | +            // HMAC Example: Realm has six HMAC JWKs of bit lengths 256, 320, 384, 400, 512, and 1000. JWT alg allows ignoring some.
 | 
	
		
			
				|  |  | +            // - If JWT alg is HS256, all are valid for a JWT HS256 signature. Don't ignore any HMAC JWKs.
 | 
	
		
			
				|  |  | +            // - If JWT alg is HS384, only 384, 400, 512, and 1000 are valid for a JWT HS384 signature. Ignore two HMAC JWKs.
 | 
	
		
			
				|  |  | +            // - If JWT alg is HS512, only 512 and 1000 are valid for a JWT HS512 signature. Ignore four HMAC JWKs.
 | 
	
		
			
				|  |  | +            final List<JWK> jwksConfigured = jwksAlg.stream().filter(j -> JwkValidateUtil.isMatch(j, alg.getName(), tracer)).toList();
 | 
	
		
			
				|  |  | +            tracer.append("[{}] JWKs remain after filtering for configured algorithms.", jwksConfigured.size());
 | 
	
		
			
				|  |  | +            tracer.flush();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            // No JWKs passed the kid, alg, and strength checks, so nothing left to use in verifying the JWT signature
 | 
	
		
			
				|  |  | +            if (jwksConfigured.isEmpty()) {
 | 
	
		
			
				|  |  | +                throw new ElasticsearchException(
 | 
	
		
			
				|  |  | +                    "Signature verification was not attempted since there are not any JWKs "
 | 
	
		
			
				|  |  | +                        + "available after filtering for incompatible keys."
 | 
	
		
			
				|  |  |                  );
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        throw new ElasticsearchException("Verify failed using " + jwksStrength.size() + " of " + jwks.size() + " provided JWKs.");
 | 
	
		
			
				|  |  | +            int attempt = 0;
 | 
	
		
			
				|  |  | +            int maxAttempts = jwksConfigured.size();
 | 
	
		
			
				|  |  | +            tracer.append("Attempting to verify signature for JWT [{}] against [{}] possible JWKs.", getSafePrintableJWT(jwt), maxAttempts);
 | 
	
		
			
				|  |  | +            for (final JWK jwk : jwksConfigured) {
 | 
	
		
			
				|  |  | +                attempt++;
 | 
	
		
			
				|  |  | +                if (jwt.verify(createJwsVerifier(jwk))) {
 | 
	
		
			
				|  |  | +                    tracer.append(
 | 
	
		
			
				|  |  | +                        "Attempt [{}/{}] -> JWT signature verification succeeded with jwk/kid=[{}], jwk/alg=[{}], jwk/kty=[{}], "
 | 
	
		
			
				|  |  | +                            + "jwk/use=[{}], jwk/key_ops=[{}]",
 | 
	
		
			
				|  |  | +                        attempt,
 | 
	
		
			
				|  |  | +                        maxAttempts,
 | 
	
		
			
				|  |  | +                        jwk.getKeyID(),
 | 
	
		
			
				|  |  | +                        jwk.getAlgorithm(),
 | 
	
		
			
				|  |  | +                        jwk.getKeyType(),
 | 
	
		
			
				|  |  | +                        jwk.getKeyUse(),
 | 
	
		
			
				|  |  | +                        jwk.getKeyOperations()
 | 
	
		
			
				|  |  | +                    );
 | 
	
		
			
				|  |  | +                    return;
 | 
	
		
			
				|  |  | +                } else {
 | 
	
		
			
				|  |  | +                    tracer.append(
 | 
	
		
			
				|  |  | +                        "Attempt [{}/{}] -> JWT signature verification failed with jwk/kid=[{}], jwk/alg=[{}], jwk/kty=[{}], jwk/use=[{}], "
 | 
	
		
			
				|  |  | +                            + "jwk/key_ops=[{}]",
 | 
	
		
			
				|  |  | +                        attempt,
 | 
	
		
			
				|  |  | +                        maxAttempts,
 | 
	
		
			
				|  |  | +                        jwk.getKeyID(),
 | 
	
		
			
				|  |  | +                        jwk.getAlgorithm(),
 | 
	
		
			
				|  |  | +                        jwk.getKeyType(),
 | 
	
		
			
				|  |  | +                        jwk.getKeyUse(),
 | 
	
		
			
				|  |  | +                        jwk.getKeyOperations()
 | 
	
		
			
				|  |  | +                    );
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            throw new ElasticsearchException("JWT [" + getSafePrintableJWT(jwt).get() + "] signature verification failed.");
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      default JWSVerifier createJwsVerifier(final JWK jwk) throws JOSEException {
 | 
	
	
		
			
				|  | @@ -428,4 +457,16 @@ public interface JwtSignatureValidator extends Releasable {
 | 
	
		
			
				|  |  |      interface PkcJwkSetReloadNotifier {
 | 
	
		
			
				|  |  |          void reloaded();
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /**
 | 
	
		
			
				|  |  | +     * @param jwt The signed JWT
 | 
	
		
			
				|  |  | +     * @return A print safe supplier to describe a JWT that redacts the signature. While the signature is not generally sensitive,
 | 
	
		
			
				|  |  | +     * we don't want to leak the entire JWT to the log to avoid a possible replay.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    private Supplier<String> getSafePrintableJWT(SignedJWT jwt) {
 | 
	
		
			
				|  |  | +        Base64URL[] parts = jwt.getParsedParts();
 | 
	
		
			
				|  |  | +        assert parts.length == 3;
 | 
	
		
			
				|  |  | +        return () -> parts[0].toString() + "." + parts[1].toString() + ".<redacted>";
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  }
 |