Browse Source

JWT realm - Initial support for access tokens (#91781)

This PR adds initial support for access tokens to JWT realm. Unlike ID
tokens, access tokens are not well defined and rather arbitrary.
Therefore the JWT realm needs to enforce a set of criteria for an access
token to be supported: * An access token must be a JWT * An access token
must be identifiable by (contain) three coordinates: issuer, subject and
audiences * The issuer corresponds to a JWT's `iss` claim. The subject
defaults to the `sub` claim but can fallback to another claim if
configured. Similarly, audiences default to the `aud` claim but can
fallback to another claim. * An access token must have both `iat` and
`exp` claims of the same semantics as defined by ID token spec

An access token meets all above requirements is processed further to run
through additional checks similar to ID tokens, such as
`allowed_issuer`, `allowed_audiences`, signature validation etc. 

A fallback claim takes effect only when the main one does _not_ exist.
When it is effective, it is honored for attribute mapping as well. For
example, if we have `fallback_claims.sub: email` and `claims.principal:
sub`, the principal attribute will be mapped from the `email` claim if
`sub` does not exist.

Since we now have dedicated support for access tokens which allow
fallback claims. The ID token support can be more strict and hence it
now always mandates the `sub` claim (per ID token spec).

NOTE: Fallback of the `aud` claim does not fully work in this PR because
`JwtAuthenticationToken` currently
[mandates](https://github.com/elastic/elasticsearch/blob/cf0b1af418cad79dfa9c8193c3be51e3c644d8d9/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java#L74)
it. This needs to be fixed. The plan is to remove the entire
JwtRealmServices and let each JwtRealm directly extract the token. It
will be done in a separate PR to keep the size under control.
Yang Wang 2 years ago
parent
commit
6179e9ffd8
27 changed files with 1074 additions and 257 deletions
  1. 5 0
      docs/changelog/91781.yaml
  2. 101 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java
  3. 17 5
      x-pack/plugin/security/qa/jwt-realm/build.gradle
  4. 50 23
      x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRestIT.java
  5. 82 0
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/FallbackableClaim.java
  6. 9 8
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAlgorithmValidator.java
  7. 39 4
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java
  8. 6 10
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtDateClaimValidator.java
  9. 15 5
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java
  10. 30 26
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtStringClaimValidator.java
  11. 1 4
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtTypeValidator.java
  12. 34 15
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/ClaimParser.java
  13. 53 0
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/FallbackableClaimTests.java
  14. 2 3
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAlgorithmValidatorTests.java
  15. 58 0
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorAccessTokenTypeTests.java
  16. 42 0
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorIdTokenTypeTests.java
  17. 59 13
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java
  18. 8 9
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtDateClaimValidatorTests.java
  19. 167 0
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateAccessTokenTypeTests.java
  20. 57 57
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java
  21. 5 5
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java
  22. 4 0
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java
  23. 107 0
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSettingsTests.java
  24. 14 35
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java
  25. 105 30
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtStringClaimValidatorTests.java
  26. 2 3
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTypeValidatorTests.java
  27. 2 2
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectRealmTests.java

+ 5 - 0
docs/changelog/91781.yaml

@@ -0,0 +1,5 @@
+pr: 91781
+summary: JWT realm - Initial support for access tokens
+area: Authentication
+type: enhancement
+issues: []

+ 101 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java

@@ -19,7 +19,9 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.EnumSet;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -155,6 +157,9 @@ public class JwtRealmSettings {
         // JWT End-user settings
         set.addAll(
             List.of(
+                ALLOWED_SUBJECTS,
+                FALLBACK_SUB_CLAIM,
+                FALLBACK_AUD_CLAIM,
                 CLAIMS_PRINCIPAL.getClaim(),
                 CLAIMS_PRINCIPAL.getPattern(),
                 CLAIMS_GROUPS.getClaim(),
@@ -245,6 +250,58 @@ public class JwtRealmSettings {
 
     // JWT end-user settings
 
+    public static final Setting.AffixSetting<List<String>> ALLOWED_SUBJECTS = Setting.affixKeySetting(
+        RealmSettings.realmSettingPrefix(TYPE),
+        "allowed_subjects",
+        key -> Setting.stringListSetting(key, values -> verifyNonNullNotEmpty(key, values, null), Setting.Property.NodeScope)
+    );
+
+    // Registered claim names from the JWT spec https://www.rfc-editor.org/rfc/rfc7519#section-4.1.
+    // Being registered means they have prescribed meanings when they present in a JWT.
+    public static final List<String> REGISTERED_CLAIM_NAMES = List.of("iss", "sub", "aud", "exp", "nbf", "iat", "jti");
+
+    public static final Setting.AffixSetting<String> FALLBACK_SUB_CLAIM = Setting.affixKeySetting(
+        RealmSettings.realmSettingPrefix(TYPE),
+        "fallback_claims.sub",
+        key -> Setting.simpleString(key, "sub", new Setting.Validator<>() {
+            @Override
+            public void validate(String value) {}
+
+            @Override
+            public void validate(String value, Map<Setting<?>, Object> settings, boolean isPresent) {
+                validateFallbackClaimSetting(FALLBACK_SUB_CLAIM, key, value, settings, isPresent);
+            }
+
+            @Override
+            public Iterator<Setting<?>> settings() {
+                final String namespace = FALLBACK_SUB_CLAIM.getNamespace(FALLBACK_SUB_CLAIM.getConcreteSetting(key));
+                final List<Setting<?>> settings = List.of(TOKEN_TYPE.getConcreteSettingForNamespace(namespace));
+                return settings.iterator();
+            }
+        }, Setting.Property.NodeScope)
+    );
+
+    public static final Setting.AffixSetting<String> FALLBACK_AUD_CLAIM = Setting.affixKeySetting(
+        RealmSettings.realmSettingPrefix(TYPE),
+        "fallback_claims.aud",
+        key -> Setting.simpleString(key, "aud", new Setting.Validator<>() {
+            @Override
+            public void validate(String value) {}
+
+            @Override
+            public void validate(String value, Map<Setting<?>, Object> settings, boolean isPresent) {
+                validateFallbackClaimSetting(FALLBACK_AUD_CLAIM, key, value, settings, isPresent);
+            }
+
+            @Override
+            public Iterator<Setting<?>> settings() {
+                final String namespace = FALLBACK_AUD_CLAIM.getNamespace(FALLBACK_AUD_CLAIM.getConcreteSetting(key));
+                final List<Setting<?>> settings = List.of(TOKEN_TYPE.getConcreteSettingForNamespace(namespace));
+                return settings.iterator();
+            }
+        }, Setting.Property.NodeScope)
+    );
+
     // Note: ClaimSetting is a wrapper for two individual settings: getClaim(), getPattern()
     public static final ClaimSetting CLAIMS_PRINCIPAL = new ClaimSetting(TYPE, "principal");
     public static final ClaimSetting CLAIMS_GROUPS = new ClaimSetting(TYPE, "groups");
@@ -357,4 +414,48 @@ public class JwtRealmSettings {
         }
     }
 
+    private static void validateFallbackClaimSetting(
+        Setting.AffixSetting<String> setting,
+        String key,
+        String value,
+        Map<Setting<?>, Object> settings,
+        boolean isPresent
+    ) {
+        if (false == isPresent) {
+            return;
+        }
+        final String namespace = setting.getNamespace(setting.getConcreteSetting(key));
+        final TokenType tokenType = (TokenType) settings.get(TOKEN_TYPE.getConcreteSettingForNamespace(namespace));
+        if (tokenType == TokenType.ID_TOKEN) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "fallback claim setting [%s] is not allowed when JWT realm [%s] is [%s] type",
+                    key,
+                    namespace,
+                    JwtRealmSettings.TokenType.ID_TOKEN.value()
+                )
+            );
+        }
+        verifyFallbackClaimName(key, value);
+    }
+
+    private static void verifyFallbackClaimName(String key, String fallbackClaimName) {
+        final String claimName = key.substring(key.lastIndexOf(".") + 1);
+        verifyNonNullNotEmpty(key, fallbackClaimName, null);
+        if (claimName.equals(fallbackClaimName)) {
+            return;
+        }
+        // Registered claims have prescribed meanings and should not be used for something else.
+        if (REGISTERED_CLAIM_NAMES.contains(fallbackClaimName)) {
+            throw new IllegalArgumentException(
+                Strings.format(
+                    "Invalid fallback claims setting [%s]. Claim [%s] cannot fallback to a registered claim [%s]",
+                    key,
+                    claimName,
+                    fallbackClaimName
+                )
+            );
+        }
+    }
+
 }

+ 17 - 5
x-pack/plugin/security/qa/jwt-realm/build.gradle

@@ -10,7 +10,13 @@ dependencies {
   javaRestTestImplementation project(":client:rest")
 }
 
-boolean explicitIdTokenType = BuildParams.random.nextBoolean()
+def random = BuildParams.random
+boolean explicitIdTokenType = random.nextBoolean()
+def serviceSubject = 'service_' + random.nextInt(1, 9) + '@app' + random.nextInt(1, 9) + '.example.com'
+
+tasks.named("javaRestTest").configure {
+  systemProperty 'jwt2.service_subject', serviceSubject
+}
 
 testClusters.matching { it.name == 'javaRestTest' }.configureEach {
   testDistribution = 'DEFAULT'
@@ -64,13 +70,19 @@ testClusters.matching { it.name == 'javaRestTest' }.configureEach {
   setting 'xpack.security.authc.realms.native.lookup_native.order', '2'
 
   setting 'xpack.security.authc.realms.jwt.jwt2.order', '3'
-  if (explicitIdTokenType) {
-    setting 'xpack.security.authc.realms.jwt.jwt2.token_type', 'id_token'
-  }
+  setting 'xpack.security.authc.realms.jwt.jwt2.token_type', 'access_token'
+  setting 'xpack.security.authc.realms.jwt.jwt2.fallback_claims.sub', 'email'
+  setting 'xpack.security.authc.realms.jwt.jwt2.fallback_claims.aud', 'scope'
   setting 'xpack.security.authc.realms.jwt.jwt2.allowed_issuer', 'my-issuer'
+  setting 'xpack.security.authc.realms.jwt.jwt2.allowed_subjects', serviceSubject
   setting 'xpack.security.authc.realms.jwt.jwt2.allowed_audiences', 'es01,es02,es03'
   setting 'xpack.security.authc.realms.jwt.jwt2.allowed_signature_algorithms', 'HS256,HS384'
-  setting 'xpack.security.authc.realms.jwt.jwt2.claims.principal', 'email'
+  // Both email or sub works because of fallback
+  if (random.nextBoolean()) {
+    setting 'xpack.security.authc.realms.jwt.jwt2.claims.principal', 'email'
+  } else {
+    setting 'xpack.security.authc.realms.jwt.jwt2.claims.principal', 'sub'
+  }
   setting 'xpack.security.authc.realms.jwt.jwt2.claim_patterns.principal', '^(.*)@[^.]*[.]example[.]com$'
   setting 'xpack.security.authc.realms.jwt.jwt2.authorization_realms', 'lookup_native'
   setting 'xpack.security.authc.realms.jwt.jwt2.client_authentication.type', 'shared_secret'

+ 50 - 23
x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRestIT.java

@@ -281,10 +281,11 @@ public class JwtRestIT extends ESRestTestCase {
      * - uses a shared-secret for client authentication
      */
     public void testAuthenticateWithHmacSignedJWTAndDelegatedAuthorization() throws Exception {
-        final String principal = randomPrincipal();
+        final String principal = System.getProperty("jwt2.service_subject");
+        final String username = getUsernameFromPrincipal(principal);
         final List<String> roles = randomRoles();
         final String randomMetadata = randomAlphaOfLengthBetween(6, 18);
-        createUser(principal, roles, Map.of("test_key", randomMetadata));
+        createUser(username, roles, Map.of("test_key", randomMetadata));
 
         try {
             final SignedJWT jwt = buildAndSignJwtForRealm2(principal);
@@ -292,7 +293,7 @@ public class JwtRestIT extends ESRestTestCase {
 
             final Map<String, Object> response = client.authenticate();
 
-            assertThat(response.get(User.Fields.USERNAME.getPreferredName()), is(principal));
+            assertThat(response.get(User.Fields.USERNAME.getPreferredName()), is(username));
             assertThat(assertMap(response, User.Fields.AUTHENTICATION_REALM), hasEntry(User.Fields.REALM_NAME.getPreferredName(), "jwt2"));
             assertThat(assertList(response, User.Fields.ROLES), Matchers.containsInAnyOrder(roles.toArray(String[]::new)));
             assertThat(assertMap(response, User.Fields.METADATA), hasEntry("test_key", randomMetadata));
@@ -304,14 +305,15 @@ public class JwtRestIT extends ESRestTestCase {
             );
             assertThat(exception.getResponse(), hasStatusCode(RestStatus.FORBIDDEN));
         } finally {
-            deleteUser(principal);
+            deleteUser(username);
         }
     }
 
     public void testFailureOnInvalidHMACSignature() throws Exception {
-        final String principal = randomPrincipal();
+        final String principal = System.getProperty("jwt2.service_subject");
+        final String username = getUsernameFromPrincipal(principal);
         final List<String> roles = randomRoles();
-        createUser(principal, roles, Map.of());
+        createUser(username, roles, Map.of());
 
         try {
             final JWTClaimsSet claimsSet = buildJwtForRealm2(principal, Instant.now());
@@ -320,7 +322,7 @@ public class JwtRestIT extends ESRestTestCase {
                 // This is the correct HMAC passphrase (from build.gradle)
                 final SignedJWT jwt = signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value");
                 final TestSecurityClient client = getSecurityClient(jwt, VALID_SHARED_SECRET);
-                assertThat(client.authenticate(), hasEntry(User.Fields.USERNAME.getPreferredName(), principal));
+                assertThat(client.authenticate(), hasEntry(User.Fields.USERNAME.getPreferredName(), username));
             }
             {
                 // This is not the correct HMAC passphrase
@@ -331,13 +333,14 @@ public class JwtRestIT extends ESRestTestCase {
                 assertThat(exception.getResponse(), hasStatusCode(RestStatus.UNAUTHORIZED));
             }
         } finally {
-            deleteUser(principal);
+            deleteUser(username);
         }
 
     }
 
     public void testAuthenticationFailureIfDelegatedAuthorizationFails() throws Exception {
-        final String principal = randomPrincipal();
+        final String principal = System.getProperty("jwt2.service_subject");
+        final String username = getUsernameFromPrincipal(principal);
         final SignedJWT jwt = buildAndSignJwtForRealm2(principal);
         final TestSecurityClient client = getSecurityClient(jwt, VALID_SHARED_SECRET);
 
@@ -345,19 +348,20 @@ public class JwtRestIT extends ESRestTestCase {
         final ResponseException exception = expectThrows(ResponseException.class, client::authenticate);
         assertThat(exception.getResponse(), hasStatusCode(RestStatus.UNAUTHORIZED));
 
-        createUser(principal, List.of(), Map.of());
+        createUser(username, List.of(), Map.of());
         try {
             // Now it works
-            assertThat(client.authenticate(), hasEntry(User.Fields.USERNAME.getPreferredName(), principal));
+            assertThat(client.authenticate(), hasEntry(User.Fields.USERNAME.getPreferredName(), username));
         } finally {
-            deleteUser(principal);
+            deleteUser(username);
         }
     }
 
     public void testFailureOnInvalidClientAuthentication() throws Exception {
-        final String principal = randomPrincipal();
+        final String principal = System.getProperty("jwt2.service_subject");
+        final String username = getUsernameFromPrincipal(principal);
         final List<String> roles = randomRoles();
-        createUser(principal, roles, Map.of());
+        createUser(username, roles, Map.of());
 
         try {
             final SignedJWT jwt = buildAndSignJwtForRealm2(principal);
@@ -368,7 +372,7 @@ public class JwtRestIT extends ESRestTestCase {
             assertThat(exception.getResponse(), hasStatusCode(RestStatus.UNAUTHORIZED));
 
         } finally {
-            deleteUser(principal);
+            deleteUser(username);
         }
     }
 
@@ -499,14 +503,15 @@ public class JwtRestIT extends ESRestTestCase {
     }
 
     private JWTClaimsSet buildJwtForRealm2(String principal, Instant issueTime) {
-        final String emailAddress = principal + "@" + randomAlphaOfLengthBetween(3, 6) + ".example.com";
         // The "jwt2" realm, supports 3 audiences (es01/02/03)
         final String audience = "es0" + randomIntBetween(1, 3);
-        final JWTClaimsSet claimsSet = buildJwt(
-            Map.ofEntries(Map.entry("iss", "my-issuer"), Map.entry("aud", audience), Map.entry("email", emailAddress)),
-            issueTime,
-            false
-        );
+        final Map<String, Object> data = new HashMap<>(Map.of("iss", "my-issuer", "aud", audience, "email", principal));
+        // scope (fallback audience) is ignored since aud exists
+        if (randomBoolean()) {
+            data.put("scope", randomAlphaOfLength(20));
+        }
+
+        final JWTClaimsSet claimsSet = buildJwt(data, issueTime, false);
         return claimsSet;
     }
 
@@ -645,16 +650,38 @@ public class JwtRestIT extends ESRestTestCase {
         createUser(principal, new SecureString(randomAlphaOfLength(12).toCharArray()), roles, metadata);
     }
 
-    private void createUser(String username, SecureString password, List<String> roles, Map<String, Object> metadata) throws IOException {
+    private void createUser(String principal, SecureString password, List<String> roles, Map<String, Object> metadata) throws IOException {
+        final String username;
+        if (principal.contains("@")) {
+            username = principal.substring(0, principal.indexOf("@"));
+        } else {
+            username = principal;
+        }
         final String realName = randomAlphaOfLengthBetween(6, 18);
         final User user = new User(username, roles.toArray(String[]::new), realName, null, metadata, true);
         getAdminSecurityClient().putUser(user, password);
     }
 
-    private void deleteUser(String username) throws IOException {
+    private void deleteUser(String principal) throws IOException {
+        final String username;
+        if (principal.contains("@")) {
+            username = principal.substring(0, principal.indexOf("@"));
+        } else {
+            username = principal;
+        }
         getAdminSecurityClient().deleteUser(username);
     }
 
+    private String getUsernameFromPrincipal(String principal) {
+        final String username;
+        if (principal.contains("@")) {
+            username = principal.substring(0, principal.indexOf("@"));
+        } else {
+            username = principal;
+        }
+        return username;
+    }
+
     private String createRoleMapping(List<String> roles, String rules) throws IOException {
         Map<String, Object> mapping = new HashMap<>();
         mapping.put("enabled", true);

+ 82 - 0
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/FallbackableClaim.java

@@ -0,0 +1,82 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.authc.jwt;
+
+import com.nimbusds.jwt.JWTClaimsSet;
+
+import org.elasticsearch.core.Nullable;
+
+import java.text.ParseException;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.core.Strings.format;
+
+/**
+ * A JWT claim that can optionally fallback to another claim (if configured) for retrieving the associated value
+ * from a {@link JWTClaimsSet}. The fallback behaviour happens only when:
+ * 1. The fallback is configured (it can be null)
+ * 2. The original claim does not exist in the {@link JWTClaimsSet}
+ * In any other cases, the original claim will be used for retrieving the value.
+ */
+public class FallbackableClaim {
+    private final String name;
+    private final JWTClaimsSet claimsSet;
+    private final String actualName;
+
+    public FallbackableClaim(String name, @Nullable Map<String, String> fallbackClaimNames, JWTClaimsSet claimsSet) {
+        this.name = Objects.requireNonNull(name);
+        this.claimsSet = Objects.requireNonNull(claimsSet);
+        final String fallbackName;
+        if (fallbackClaimNames != null) {
+            fallbackName = fallbackClaimNames.getOrDefault(name, name);
+        } else {
+            fallbackName = null;
+        }
+        if (fallbackName == null) {
+            this.actualName = name;
+        } else {
+            this.actualName = claimsSet.getClaim(name) != null ? name : fallbackName;
+        }
+    }
+
+    public String getActualName() {
+        return actualName;
+    }
+
+    public String getStringClaimValue() {
+        try {
+            return claimsSet.getStringClaim(actualName);
+        } catch (ParseException e) {
+            throw new IllegalArgumentException(format("cannot parse string claim [%s] as string", this), e);
+        }
+    }
+
+    public List<String> getStringListClaimValue() {
+        final Object claimValue = claimsSet.getClaim(actualName);
+        if (claimValue instanceof String) {
+            return List.of((String) claimValue);
+        } else {
+            try {
+                return claimsSet.getStringListClaim(actualName);
+            } catch (ParseException e) {
+                throw new IllegalArgumentException(format("cannot parse string claim [%s] as string array", this), e);
+            }
+        }
+    }
+
+    @Override
+    public String toString() {
+        if (name.equals(actualName)) {
+            return name;
+        } else {
+            return format("%s (fallback of %s)", actualName, name);
+        }
+    }
+}

+ 9 - 8
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAlgorithmValidator.java

@@ -11,12 +11,12 @@ import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.rest.RestStatus;
 
 import java.util.List;
 
+import static org.elasticsearch.core.Strings.format;
+
 public class JwtAlgorithmValidator implements JwtFieldValidator {
 
     private final List<String> allowedAlgorithms;
@@ -28,15 +28,16 @@ public class JwtAlgorithmValidator implements JwtFieldValidator {
     public void validate(JWSHeader jwsHeader, JWTClaimsSet jwtClaimsSet) {
         final JWSAlgorithm algorithm = jwsHeader.getAlgorithm();
         if (algorithm == null) {
-            throw new ElasticsearchSecurityException("missing JWT algorithm header", RestStatus.BAD_REQUEST);
+            throw new IllegalArgumentException("missing JWT algorithm header");
         }
 
         if (false == allowedAlgorithms.contains(algorithm.getName())) {
-            throw new ElasticsearchSecurityException(
-                "invalid JWT algorithm [{}], allowed algorithms are [{}]",
-                RestStatus.BAD_REQUEST,
-                algorithm,
-                Strings.collectionToCommaDelimitedString(allowedAlgorithms)
+            throw new IllegalArgumentException(
+                format(
+                    "invalid JWT algorithm [%s], allowed algorithms are [%s]",
+                    algorithm,
+                    Strings.collectionToCommaDelimitedString(allowedAlgorithms)
+                )
             );
         }
     }

+ 39 - 4
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java

@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ssl.SSLService;
 import java.text.ParseException;
 import java.time.Clock;
 import java.util.List;
+import java.util.Map;
 
 /**
  * This class performs validations of header, claims and signatures against the incoming {@link JwtAuthenticationToken}.
@@ -37,6 +38,7 @@ public class JwtAuthenticator implements Releasable {
     private final List<JwtFieldValidator> jwtFieldValidators;
     private final JwtSignatureValidator jwtSignatureValidator;
     private final JwtRealmSettings.TokenType tokenType;
+    private final Map<String, String> fallbackClaimNames;
 
     public JwtAuthenticator(
         final RealmConfig realmConfig,
@@ -46,9 +48,14 @@ public class JwtAuthenticator implements Releasable {
         this.realmConfig = realmConfig;
         this.tokenType = realmConfig.getSetting(JwtRealmSettings.TOKEN_TYPE);
         if (tokenType == JwtRealmSettings.TokenType.ID_TOKEN) {
+            this.fallbackClaimNames = Map.of();
             this.jwtFieldValidators = configureFieldValidatorsForIdToken(realmConfig);
         } else {
-            this.jwtFieldValidators = configureFieldValidatorsForAccessToken(realmConfig);
+            this.fallbackClaimNames = Map.ofEntries(
+                Map.entry("sub", realmConfig.getSetting(JwtRealmSettings.FALLBACK_SUB_CLAIM)),
+                Map.entry("aud", realmConfig.getSetting(JwtRealmSettings.FALLBACK_AUD_CLAIM))
+            );
+            this.jwtFieldValidators = configureFieldValidatorsForAccessToken(realmConfig, fallbackClaimNames);
         }
         this.jwtSignatureValidator = new JwtSignatureValidator.DelegatingJwtSignatureValidator(realmConfig, sslService, reloadNotifier);
     }
@@ -112,6 +119,10 @@ public class JwtAuthenticator implements Releasable {
         return tokenType;
     }
 
+    public Map<String, String> getFallbackClaimNames() {
+        return fallbackClaimNames;
+    }
+
     // Package private for testing
     JwtSignatureValidator.DelegatingJwtSignatureValidator getJwtSignatureValidator() {
         assert jwtSignatureValidator instanceof JwtSignatureValidator.DelegatingJwtSignatureValidator;
@@ -122,10 +133,19 @@ public class JwtAuthenticator implements Releasable {
         assert realmConfig.getSetting(JwtRealmSettings.TOKEN_TYPE) == JwtRealmSettings.TokenType.ID_TOKEN;
         final TimeValue allowedClockSkew = realmConfig.getSetting(JwtRealmSettings.ALLOWED_CLOCK_SKEW);
         final Clock clock = Clock.systemUTC();
+
+        final JwtStringClaimValidator subjectClaimValidator;
+        if (realmConfig.hasSetting(JwtRealmSettings.ALLOWED_SUBJECTS)) {
+            subjectClaimValidator = new JwtStringClaimValidator("sub", realmConfig.getSetting(JwtRealmSettings.ALLOWED_SUBJECTS), true);
+        } else {
+            // Allow any value for the sub claim as long as there is a non-null value
+            subjectClaimValidator = JwtStringClaimValidator.ALLOW_ALL_SUBJECTS;
+        }
+
         return List.of(
             JwtTypeValidator.INSTANCE,
-            // TODO: mandate "sub" claim once access token support is in place
             new JwtStringClaimValidator("iss", List.of(realmConfig.getSetting(JwtRealmSettings.ALLOWED_ISSUER)), true),
+            subjectClaimValidator,
             new JwtStringClaimValidator("aud", realmConfig.getSetting(JwtRealmSettings.ALLOWED_AUDIENCES), false),
             new JwtAlgorithmValidator(realmConfig.getSetting(JwtRealmSettings.ALLOWED_SIGNATURE_ALGORITHMS)),
             new JwtDateClaimValidator(clock, "iat", allowedClockSkew, JwtDateClaimValidator.Relationship.BEFORE_NOW, false),
@@ -135,8 +155,23 @@ public class JwtAuthenticator implements Releasable {
         );
     }
 
-    private static List<JwtFieldValidator> configureFieldValidatorsForAccessToken(RealmConfig realmConfig) {
+    private static List<JwtFieldValidator> configureFieldValidatorsForAccessToken(
+        RealmConfig realmConfig,
+        Map<String, String> fallbackClaimLookup
+    ) {
         assert realmConfig.getSetting(JwtRealmSettings.TOKEN_TYPE) == JwtRealmSettings.TokenType.ACCESS_TOKEN;
-        throw new UnsupportedOperationException("NYI");
+        final TimeValue allowedClockSkew = realmConfig.getSetting(JwtRealmSettings.ALLOWED_CLOCK_SKEW);
+        final Clock clock = Clock.systemUTC();
+
+        return List.of(
+            JwtTypeValidator.INSTANCE,
+            new JwtStringClaimValidator("iss", List.of(realmConfig.getSetting(JwtRealmSettings.ALLOWED_ISSUER)), true),
+            new JwtStringClaimValidator("sub", fallbackClaimLookup, realmConfig.getSetting(JwtRealmSettings.ALLOWED_SUBJECTS), true),
+            new JwtStringClaimValidator("aud", fallbackClaimLookup, realmConfig.getSetting(JwtRealmSettings.ALLOWED_AUDIENCES), false),
+            new JwtAlgorithmValidator(realmConfig.getSetting(JwtRealmSettings.ALLOWED_SIGNATURE_ALGORITHMS)),
+            new JwtDateClaimValidator(clock, "iat", allowedClockSkew, JwtDateClaimValidator.Relationship.BEFORE_NOW, false),
+            new JwtDateClaimValidator(clock, "exp", allowedClockSkew, JwtDateClaimValidator.Relationship.AFTER_NOW, false)
+        );
+
     }
 }

+ 6 - 10
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtDateClaimValidator.java

@@ -10,10 +10,8 @@ package org.elasticsearch.xpack.security.authc.jwt;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.rest.RestStatus;
 
 import java.text.ParseException;
 import java.time.Clock;
@@ -47,14 +45,14 @@ public class JwtDateClaimValidator implements JwtFieldValidator {
         try {
             claimValue = jwtClaimsSet.getDateClaim(claimName);
         } catch (ParseException e) {
-            throw new ElasticsearchSecurityException("cannot parse date claim [" + claimName + "]", RestStatus.BAD_REQUEST, e);
+            throw new IllegalArgumentException("cannot parse date claim [" + claimName + "]", e);
         }
 
         if (claimValue == null) {
             if (allowNull) {
                 return;
             } else {
-                throw new ElasticsearchSecurityException("missing required date claim [" + claimName + "]");
+                throw new IllegalArgumentException("missing required date claim [" + claimName + "]");
             }
         }
 
@@ -64,27 +62,25 @@ public class JwtDateClaimValidator implements JwtFieldValidator {
         switch (relationship) {
             case BEFORE_NOW:
                 if (false == claimInstant.isBefore(now.plusSeconds(allowedClockSkewSeconds))) {
-                    throw new ElasticsearchSecurityException(
+                    throw new IllegalArgumentException(
                         Strings.format(
                             "date claim [%s] value [%s] must be before now [%s]",
                             claimName,
                             claimInstant.toEpochMilli(),
                             now.toEpochMilli()
-                        ),
-                        RestStatus.BAD_REQUEST
+                        )
                     );
                 }
                 break;
             case AFTER_NOW:
                 if (false == claimInstant.isAfter(now.minusSeconds(allowedClockSkewSeconds))) {
-                    throw new ElasticsearchSecurityException(
+                    throw new IllegalArgumentException(
                         Strings.format(
                             "date claim [%s] value [%s] must be after now [%s]",
                             claimName,
                             claimInstant.toEpochMilli(),
                             now.toEpochMilli()
-                        ),
-                        RestStatus.BAD_REQUEST
+                        )
                     );
                 }
                 break;

+ 15 - 5
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java

@@ -83,11 +83,7 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
         this.userRoleMapper = userRoleMapper;
         this.userRoleMapper.refreshRealmOnChange(this);
         this.allowedClockSkew = realmConfig.getSetting(JwtRealmSettings.ALLOWED_CLOCK_SKEW);
-        this.claimParserPrincipal = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_PRINCIPAL, realmConfig, true);
-        this.claimParserGroups = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_GROUPS, realmConfig, false);
-        this.claimParserDn = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_DN, realmConfig, false);
-        this.claimParserMail = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_MAIL, realmConfig, false);
-        this.claimParserName = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_NAME, realmConfig, false);
+
         this.populateUserMetadata = realmConfig.getSetting(JwtRealmSettings.POPULATE_USER_METADATA);
         this.clientAuthenticationType = realmConfig.getSetting(JwtRealmSettings.CLIENT_AUTHENTICATION_TYPE);
         final SecureString sharedSecret = realmConfig.getSetting(JwtRealmSettings.CLIENT_AUTHENTICATION_SHARED_SECRET);
@@ -115,6 +111,20 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
             this.jwtCacheHelper = null;
         }
         jwtAuthenticator = new JwtAuthenticator(realmConfig, sslService, this::expireAll);
+
+        final Map<String, String> fallbackClaimNames = jwtAuthenticator.getFallbackClaimNames();
+
+        this.claimParserPrincipal = ClaimParser.forSetting(
+            logger,
+            JwtRealmSettings.CLAIMS_PRINCIPAL,
+            fallbackClaimNames,
+            realmConfig,
+            true
+        );
+        this.claimParserGroups = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_GROUPS, fallbackClaimNames, realmConfig, false);
+        this.claimParserDn = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_DN, fallbackClaimNames, realmConfig, false);
+        this.claimParserMail = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_MAIL, fallbackClaimNames, realmConfig, false);
+        this.claimParserName = ClaimParser.forSetting(logger, JwtRealmSettings.CLAIMS_NAME, fallbackClaimNames, realmConfig, false);
     }
 
     /**

+ 30 - 26
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtStringClaimValidator.java

@@ -10,12 +10,11 @@ package org.elasticsearch.xpack.security.authc.jwt;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.core.Nullable;
 
-import java.text.ParseException;
 import java.util.List;
+import java.util.Map;
 
 /**
  * Validates a string claim against a list of allowed values. The validation is successful
@@ -25,60 +24,65 @@ import java.util.List;
  * values.
  * Whether a claim's value can be an array of strings is customised with the {@link #singleValuedClaim}
  * field, which enforces the claim's value to be a single string if it is configured to {@code true}.
+ *
+ * NOTE the allowed values can be null which means skipping the actual value check, i.e. the validator
+ * succeeds as long as there is a (non-null) value.
  */
 public class JwtStringClaimValidator implements JwtFieldValidator {
 
+    public static JwtStringClaimValidator ALLOW_ALL_SUBJECTS = new JwtStringClaimValidator("sub", null, true);
+
     private final String claimName;
+    @Nullable
+    private final Map<String, String> fallbackClaimNames;
+    @Nullable
     private final List<String> allowedClaimValues;
     // Whether the claim should be a single string
     private final boolean singleValuedClaim;
 
     public JwtStringClaimValidator(String claimName, List<String> allowedClaimValues, boolean singleValuedClaim) {
+        this(claimName, null, allowedClaimValues, singleValuedClaim);
+    }
+
+    public JwtStringClaimValidator(
+        String claimName,
+        Map<String, String> fallbackClaimNames,
+        List<String> allowedClaimValues,
+        boolean singleValuedClaim
+    ) {
         this.claimName = claimName;
+        this.fallbackClaimNames = fallbackClaimNames;
         this.allowedClaimValues = allowedClaimValues;
         this.singleValuedClaim = singleValuedClaim;
     }
 
     @Override
     public void validate(JWSHeader jwsHeader, JWTClaimsSet jwtClaimsSet) {
-        final List<String> claimValues;
-        try {
-            claimValues = getStringClaimValues(jwtClaimsSet);
-        } catch (ParseException e) {
-            throw new ElasticsearchSecurityException("cannot parse string claim [" + claimName + "]", RestStatus.BAD_REQUEST, e);
-        }
+        final FallbackableClaim fallbackableClaim = new FallbackableClaim(claimName, fallbackClaimNames, jwtClaimsSet);
+        final List<String> claimValues = getStringClaimValues(fallbackableClaim);
         if (claimValues == null) {
-            throw new ElasticsearchSecurityException("missing required string claim [" + claimName + "]", RestStatus.BAD_REQUEST);
+            throw new IllegalArgumentException("missing required string claim [" + fallbackableClaim + "]");
         }
 
-        if (false == claimValues.stream().anyMatch(allowedClaimValues::contains)) {
-            throw new ElasticsearchSecurityException(
+        if (allowedClaimValues != null && false == claimValues.stream().anyMatch(allowedClaimValues::contains)) {
+            throw new IllegalArgumentException(
                 "string claim ["
-                    + claimName
+                    + fallbackableClaim
                     + "] has value ["
                     + Strings.collectionToCommaDelimitedString(claimValues)
                     + "] which does not match allowed claim values ["
                     + Strings.collectionToCommaDelimitedString(allowedClaimValues)
-                    + "]",
-                RestStatus.BAD_REQUEST
+                    + "]"
             );
         }
     }
 
-    private List<String> getStringClaimValues(JWTClaimsSet claimsSet) throws ParseException {
-        // TODO: fallback claims
-        final String actualClaimName = claimName;
-
+    private List<String> getStringClaimValues(FallbackableClaim fallbackableClaim) {
         if (singleValuedClaim) {
-            final String claimValue = claimsSet.getStringClaim(actualClaimName);
+            final String claimValue = fallbackableClaim.getStringClaimValue();
             return claimValue != null ? List.of(claimValue) : null;
         } else {
-            final Object claimValue = claimsSet.getClaim(actualClaimName);
-            if (claimValue instanceof String) {
-                return List.of((String) claimValue);
-            } else {
-                return claimsSet.getStringListClaim(actualClaimName);
-            }
+            return fallbackableClaim.getStringListClaimValue();
         }
     }
 }

+ 1 - 4
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtTypeValidator.java

@@ -15,9 +15,6 @@ import com.nimbusds.jose.proc.JOSEObjectTypeVerifier;
 import com.nimbusds.jose.proc.SecurityContext;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
-import org.elasticsearch.rest.RestStatus;
-
 public class JwtTypeValidator implements JwtFieldValidator {
 
     private static final JOSEObjectTypeVerifier<SecurityContext> JWT_HEADER_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier<>(
@@ -34,7 +31,7 @@ public class JwtTypeValidator implements JwtFieldValidator {
         try {
             JWT_HEADER_TYPE_VERIFIER.verify(jwtHeaderType, null);
         } catch (BadJOSEException e) {
-            throw new ElasticsearchSecurityException("invalid jwt typ header", RestStatus.BAD_REQUEST, e);
+            throw new IllegalArgumentException("invalid jwt typ header", e);
         }
     }
 }

+ 34 - 15
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/ClaimParser.java

@@ -15,9 +15,11 @@ import org.elasticsearch.common.settings.SettingsException;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.support.ClaimSetting;
+import org.elasticsearch.xpack.security.authc.jwt.FallbackableClaim;
 
 import java.util.Collection;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.function.Function;
 import java.util.regex.Matcher;
@@ -80,9 +82,9 @@ public final class ClaimParser {
     }
 
     @SuppressWarnings("unchecked")
-    private static Collection<String> parseClaimValues(JWTClaimsSet claimsSet, String claimName, String settingKey) {
+    private static Collection<String> parseClaimValues(JWTClaimsSet claimsSet, FallbackableClaim fallbackableClaim, String settingKey) {
         Collection<String> values;
-        final Object claimValueObject = claimsSet.getClaim(claimName);
+        final Object claimValueObject = claimsSet.getClaim(fallbackableClaim.getActualName());
         if (claimValueObject == null) {
             values = List.of();
         } else if (claimValueObject instanceof String) {
@@ -91,50 +93,67 @@ public final class ClaimParser {
             && ((Collection<?>) claimValueObject).stream().allMatch(c -> c instanceof String)) {
                 values = (Collection<String>) claimValueObject;
             } else {
-                throw new SettingsException("Setting [ " + settingKey + " expects a claim with String or a String Array value");
+                throw new SettingsException(
+                    "Setting [ " + settingKey + "] expects claim [" + fallbackableClaim + "] with String or a String Array value"
+                );
             }
         return values;
     }
 
     public static ClaimParser forSetting(Logger logger, ClaimSetting setting, RealmConfig realmConfig, boolean required) {
+        return forSetting(logger, setting, Map.of(), realmConfig, required);
+    }
+
+    public static ClaimParser forSetting(
+        Logger logger,
+        ClaimSetting setting,
+        Map<String, String> fallbackClaimNames,
+        RealmConfig realmConfig,
+        boolean required
+    ) {
 
         if (realmConfig.hasSetting(setting.getClaim())) {
-            String claimName = realmConfig.getSetting(setting.getClaim());
+            final String claimName = realmConfig.getSetting(setting.getClaim());
             if (realmConfig.hasSetting(setting.getPattern())) {
                 Pattern regex = Pattern.compile(realmConfig.getSetting(setting.getPattern()));
                 return new ClaimParser(setting.name(realmConfig), claimName, regex.pattern(), claims -> {
+                    final FallbackableClaim fallbackableClaim = new FallbackableClaim(claimName, fallbackClaimNames, claims);
                     Collection<String> values = parseClaimValues(
                         claims,
-                        claimName,
+                        fallbackableClaim,
                         RealmSettings.getFullSettingKey(realmConfig, setting.getClaim())
                     );
                     return values.stream().map(s -> {
                         if (s == null) {
-                            logger.debug("Claim [{}] is null", claimName);
+                            logger.debug("Claim [{}] is null", fallbackableClaim);
                             return null;
                         }
                         final Matcher matcher = regex.matcher(s);
                         if (matcher.find() == false) {
-                            logger.debug("Claim [{}] is [{}], which does not match [{}]", claimName, s, regex.pattern());
+                            logger.debug("Claim [{}] is [{}], which does not match [{}]", fallbackableClaim, s, regex.pattern());
                             return null;
                         }
                         final String value = matcher.group(1);
                         if (Strings.isNullOrEmpty(value)) {
-                            logger.debug("Claim [{}] is [{}], which does match [{}] but group(1) is empty", claimName, s, regex.pattern());
+                            logger.debug(
+                                "Claim [{}] is [{}], which does match [{}] but group(1) is empty",
+                                fallbackableClaim,
+                                s,
+                                regex.pattern()
+                            );
                             return null;
                         }
                         return value;
                     }).filter(Objects::nonNull).toList();
                 });
             } else {
-                return new ClaimParser(
-                    setting.name(realmConfig),
-                    claimName,
-                    null,
-                    claims -> parseClaimValues(claims, claimName, RealmSettings.getFullSettingKey(realmConfig, setting.getClaim())).stream()
+                return new ClaimParser(setting.name(realmConfig), claimName, null, claims -> {
+                    final FallbackableClaim fallbackableClaim = new FallbackableClaim(claimName, fallbackClaimNames, claims);
+                    return parseClaimValues(claims, fallbackableClaim, RealmSettings.getFullSettingKey(realmConfig, setting.getClaim()))
+                        .stream()
                         .filter(Objects::nonNull)
-                        .toList()
-                );
+                        .toList();
+                });
             }
         } else if (required) {
             throw new SettingsException("Setting [" + RealmSettings.getFullSettingKey(realmConfig, setting.getClaim()) + "] is required");

+ 53 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/FallbackableClaimTests.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.authc.jwt;
+
+import com.nimbusds.jwt.JWTClaimsSet;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.text.ParseException;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class FallbackableClaimTests extends ESTestCase {
+
+    public void testNoFallback() throws ParseException {
+        final String name = randomAlphaOfLength(10);
+        final String value = randomAlphaOfLength(10);
+        final FallbackableClaim fallbackableClaim = new FallbackableClaim(name, null, JWTClaimsSet.parse(Map.of(name, value)));
+        assertThat(fallbackableClaim.getActualName(), equalTo(name));
+        assertThat(fallbackableClaim.toString(), equalTo(name));
+        assertThat(fallbackableClaim.getStringClaimValue(), equalTo(value));
+        assertThat(fallbackableClaim.getStringListClaimValue(), equalTo(List.of(value)));
+    }
+
+    public void testFallback() throws ParseException {
+        final String name = randomAlphaOfLength(10);
+        final String fallbackName = randomAlphaOfLength(12);
+        final String value = randomAlphaOfLength(10);
+
+        // fallback ignored
+        final JWTClaimsSet claimSet1 = JWTClaimsSet.parse(Map.of(name, value, fallbackName, randomAlphaOfLength(16)));
+        final FallbackableClaim fallbackableClaim1 = new FallbackableClaim(name, Map.of(name, fallbackName), claimSet1);
+        assertThat(fallbackableClaim1.getActualName(), equalTo(name));
+        assertThat(fallbackableClaim1.toString(), equalTo(name));
+        assertThat(fallbackableClaim1.getStringClaimValue(), equalTo(value));
+        assertThat(fallbackableClaim1.getStringListClaimValue(), equalTo(List.of(value)));
+
+        // fallback active
+        final JWTClaimsSet claimSet2 = JWTClaimsSet.parse(Map.of(fallbackName, value));
+        final FallbackableClaim fallbackableClaim2 = new FallbackableClaim(name, Map.of(name, fallbackName), claimSet2);
+        assertThat(fallbackableClaim2.getActualName(), equalTo(fallbackName));
+        assertThat(fallbackableClaim2.toString(), equalTo(fallbackName + " (fallback of " + name + ")"));
+        assertThat(fallbackableClaim2.getStringClaimValue(), equalTo(value));
+        assertThat(fallbackableClaim2.getStringListClaimValue(), equalTo(List.of(value)));
+    }
+}

+ 2 - 3
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAlgorithmValidatorTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.security.authc.jwt;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.test.ESTestCase;
 
 import java.text.ParseException;
@@ -39,8 +38,8 @@ public class JwtAlgorithmValidatorTests extends ESTestCase {
         final JwtAlgorithmValidator validator = new JwtAlgorithmValidator(randomList(1, 5, () -> randomAlphaOfLength(8)));
 
         final JWSHeader jwsHeader = JWSHeader.parse(Map.of("alg", algorithm));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(jwsHeader, JWTClaimsSet.parse(Map.of()))
         );
         assertThat(e.getMessage(), containsString("invalid JWT algorithm"));

+ 58 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorAccessTokenTypeTests.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.authc.jwt;
+
+import org.elasticsearch.xpack.core.security.authc.RealmSettings;
+import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
+import org.junit.Before;
+
+import java.text.ParseException;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class JwtAuthenticatorAccessTokenTypeTests extends JwtAuthenticatorTests {
+
+    private String fallbackSub;
+    private String fallbackAud;
+
+    @Before
+    public void beforeTest() {
+        doBeforeTest();
+        fallbackSub = randomBoolean() ? "_" + randomAlphaOfLength(5) : null;
+        fallbackAud = randomBoolean() ? "_" + randomAlphaOfLength(8) : null;
+    }
+
+    @Override
+    protected JwtRealmSettings.TokenType getTokenType() {
+        return JwtRealmSettings.TokenType.ACCESS_TOKEN;
+    }
+
+    public void testSubjectIsRequired() throws ParseException {
+        final IllegalArgumentException e = doTestSubjectIsRequired(buildJwtAuthenticator(fallbackSub, fallbackAud));
+        if (fallbackSub != null) {
+            assertThat(e.getMessage(), containsString("missing required string claim [" + fallbackSub + " (fallback of sub)]"));
+        }
+    }
+
+    public void testAccessTokenTypeMandatesAllowedSubjects() {
+        allowedSubject = null;
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
+            () -> buildJwtAuthenticator(fallbackSub, fallbackAud)
+        );
+
+        assertThat(
+            e.getMessage(),
+            containsString("Invalid empty list for [" + RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_SUBJECTS) + "]")
+        );
+    }
+
+    public void testInvalidIssuerIsCheckedBeforeAlgorithm() throws ParseException {
+        doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator(fallbackSub, fallbackAud));
+    }
+}

+ 42 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorIdTokenTypeTests.java

@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.authc.jwt;
+
+import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
+import org.junit.Before;
+
+import java.text.ParseException;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class JwtAuthenticatorIdTokenTypeTests extends JwtAuthenticatorTests {
+
+    private String fallbackSub;
+    private String fallbackAud;
+
+    @Before
+    public void beforeTest() {
+        doBeforeTest();
+        fallbackSub = null;
+        fallbackAud = null;
+    }
+
+    @Override
+    protected JwtRealmSettings.TokenType getTokenType() {
+        return JwtRealmSettings.TokenType.ID_TOKEN;
+    }
+
+    public void testSubjectIsRequired() throws ParseException {
+        final IllegalArgumentException e = doTestSubjectIsRequired(buildJwtAuthenticator(fallbackSub, fallbackAud));
+        assertThat(e.getMessage(), containsString("missing required string claim [sub]"));
+    }
+
+    public void testInvalidIssuerIsCheckedBeforeAlgorithm() throws ParseException {
+        doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator(fallbackSub, fallbackAud));
+    }
+}

+ 59 - 13
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java

@@ -12,12 +12,12 @@ import com.nimbusds.jose.util.Base64URL;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.SignedJWT;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.MockSecureSettings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
@@ -31,20 +31,50 @@ import static org.elasticsearch.test.TestMatchers.throwableWithMessage;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
-public class JwtAuthenticatorTests extends ESTestCase {
+public abstract class JwtAuthenticatorTests extends ESTestCase {
 
-    public void testInvalidIssuerIsCheckedBeforeAlgorithm() throws ParseException {
-        final String realmName = randomAlphaOfLengthBetween(3, 8);
-        final String allowedIssuer = randomAlphaOfLength(6);
-        final String allowedAlgorithm = randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC);
-        final JwtAuthenticator jwtAuthenticator = buildJwtAuthenticator(realmName, allowedAlgorithm, allowedIssuer);
+    protected String realmName;
+    protected String allowedAlgorithm;
+    protected String allowedIssuer;
+    @Nullable
+    protected String allowedSubject;
+    protected String allowedAudience;
 
+    protected abstract JwtRealmSettings.TokenType getTokenType();
+
+    protected void doBeforeTest() {
+        realmName = randomAlphaOfLengthBetween(3, 8);
+        allowedIssuer = randomAlphaOfLength(6);
+        allowedAlgorithm = randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC);
+        if (getTokenType() == JwtRealmSettings.TokenType.ID_TOKEN) {
+            allowedSubject = randomBoolean() ? randomAlphaOfLength(8) : null;
+        } else {
+            allowedSubject = randomAlphaOfLength(8);
+        }
+        allowedAudience = randomAlphaOfLength(10);
+    }
+
+    protected IllegalArgumentException doTestSubjectIsRequired(JwtAuthenticator jwtAuthenticator) throws ParseException {
+        final SignedJWT signedJWT = new SignedJWT(
+            JWSHeader.parse(Map.of("alg", allowedAlgorithm)).toBase64URL(),
+            JWTClaimsSet.parse(Map.of("iss", allowedIssuer)).toPayload().toBase64URL(),
+            Base64URL.encode("signature")
+        );
+        final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
+        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+
+        final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
+        jwtAuthenticator.authenticate(jwtAuthenticationToken, future);
+        return expectThrows(IllegalArgumentException.class, future::actionGet);
+    }
+
+    protected void doTestInvalidIssuerIsCheckedBeforeAlgorithm(JwtAuthenticator jwtAuthenticator) throws ParseException {
         // A JWT token that has mismatch for both algorithm and issuer
         final String invalidAlgorithm = randomValueOtherThan(allowedAlgorithm, () -> randomAlphaOfLengthBetween(3, 8));
         final String invalidIssuer = randomValueOtherThan(allowedIssuer, () -> randomAlphaOfLengthBetween(3, 8));
         final SignedJWT signedJWT = new SignedJWT(
             JWSHeader.parse(Map.of("alg", invalidAlgorithm)).toBase64URL(),
-            JWTClaimsSet.parse(Map.of("iss", invalidIssuer)).toPayload().toBase64URL(),
+            JWTClaimsSet.parse(Map.of("iss", invalidIssuer, "sub", randomAlphaOfLengthBetween(3, 8))).toPayload().toBase64URL(),
             Base64URL.encode("signature")
         );
         final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
@@ -53,8 +83,7 @@ public class JwtAuthenticatorTests extends ESTestCase {
         final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
         jwtAuthenticator.authenticate(jwtAuthenticationToken, future);
 
-        final ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, future::actionGet);
-
+        final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, future::actionGet);
         assertThat(
             e,
             throwableWithMessage(
@@ -63,7 +92,7 @@ public class JwtAuthenticatorTests extends ESTestCase {
         );
     }
 
-    private JwtAuthenticator buildJwtAuthenticator(String realmName, String allowedAlgorithm, String allowedIssuer) {
+    protected JwtAuthenticator buildJwtAuthenticator(String fallbackSub, String fallbackAud) {
         final RealmConfig.RealmIdentifier realmIdentifier = new RealmConfig.RealmIdentifier(JwtRealmSettings.TYPE, realmName);
         final MockSecureSettings secureSettings = new MockSecureSettings();
         secureSettings.setString(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.HMAC_KEY), randomAlphaOfLength(40));
@@ -74,9 +103,26 @@ public class JwtAuthenticatorTests extends ESTestCase {
             .put(RealmSettings.getFullSettingKey(realmIdentifier, RealmSettings.ORDER_SETTING), randomIntBetween(0, 99))
             .put("path.home", randomAlphaOfLength(10))
             .setSecureSettings(secureSettings);
-        if (randomBoolean()) {
-            builder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.TOKEN_TYPE), "id_token");
+
+        if (allowedSubject != null) {
+            builder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_SUBJECTS), allowedSubject);
         }
+
+        if (getTokenType() == JwtRealmSettings.TokenType.ID_TOKEN) {
+            if (randomBoolean()) {
+                builder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.TOKEN_TYPE), "id_token");
+            }
+        } else {
+            builder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.TOKEN_TYPE), "access_token");
+        }
+
+        if (fallbackSub != null) {
+            builder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_SUB_CLAIM), fallbackSub);
+        }
+        if (fallbackAud != null) {
+            builder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM), fallbackAud);
+        }
+
         final Settings settings = builder.build();
 
         final RealmConfig realmConfig = new RealmConfig(

+ 8 - 9
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtDateClaimValidatorTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.security.authc.jwt;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.junit.Before;
@@ -44,8 +43,8 @@ public class JwtDateClaimValidatorTests extends ESTestCase {
         );
 
         final JWTClaimsSet jwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, randomAlphaOfLengthBetween(3, 8)));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), jwtClaimsSet)
         );
         assertThat(e.getMessage(), containsString("cannot parse date claim"));
@@ -64,8 +63,8 @@ public class JwtDateClaimValidatorTests extends ESTestCase {
         );
 
         final JWTClaimsSet jwtClaimsSet = JWTClaimsSet.parse(Map.of());
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), jwtClaimsSet)
         );
         assertThat(e.getMessage(), containsString("missing required date claim"));
@@ -112,8 +111,8 @@ public class JwtDateClaimValidatorTests extends ESTestCase {
         }
 
         final Instant after = now.plusSeconds(randomLongBetween(1 + allowedSkewInSeconds, 600));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), JWTClaimsSet.parse(Map.of(claimName, after.getEpochSecond())))
         );
         assertThat(
@@ -145,8 +144,8 @@ public class JwtDateClaimValidatorTests extends ESTestCase {
         when(clock.instant()).thenReturn(now);
 
         final Instant before = now.minusSeconds(randomLongBetween(1 + allowedSkewInSeconds, 600));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), JWTClaimsSet.parse(Map.of(claimName, before.getEpochSecond())))
         );
         assertThat(

+ 167 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateAccessTokenTypeTests.java

@@ -0,0 +1,167 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.authc.jwt;
+
+import com.nimbusds.jose.JOSEObjectType;
+import com.nimbusds.jose.jwk.JWK;
+import com.nimbusds.jwt.SignedJWT;
+import com.nimbusds.openid.connect.sdk.Nonce;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.xpack.core.security.authc.RealmSettings;
+import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
+import org.elasticsearch.xpack.core.security.user.User;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+public class JwtRealmAuthenticateAccessTokenTypeTests extends JwtRealmTestCase {
+
+    private String fallbackSub;
+    private String fallbackAud;
+    private SignedJWT unsignedJwt;
+
+    public void testAccessTokenTypeWorksWithNoFallback() throws Exception {
+        noFallback();
+
+        jwtIssuerAndRealms = generateJwtIssuerRealmPairs(
+            createJwtRealmsSettingsBuilder(),
+            randomIntBetween(1, 1), // realms
+            randomIntBetween(0, 1), // authz
+            randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algorithms
+            randomIntBetween(1, 3), // audiences
+            randomIntBetween(1, 3), // users
+            randomIntBetween(0, 3), // roles
+            randomIntBetween(0, 1), // jwtCacheSize
+            randomBoolean() // createHttpsServer
+        );
+        final JwtIssuerAndRealm jwtIssuerAndRealm = randomJwtIssuerRealmPair();
+        final User user = randomUser(jwtIssuerAndRealm.issuer());
+
+        final SecureString jwt = randomJwt(jwtIssuerAndRealm, user);
+        final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
+        doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, randomIntBetween(1, 3));
+    }
+
+    public void testAccessTokenTypeWorksWithFallbacks() throws Exception {
+        randomFallbacks();
+
+        jwtIssuerAndRealms = generateJwtIssuerRealmPairs(
+            createJwtRealmsSettingsBuilder(),
+            randomIntBetween(1, 1), // realms
+            randomIntBetween(0, 1), // authz
+            randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algorithms
+            randomIntBetween(1, 3), // audiences
+            randomIntBetween(1, 3), // users
+            randomIntBetween(0, 3), // roles
+            randomIntBetween(0, 1), // jwtCacheSize
+            randomBoolean() // createHttpsServer
+        );
+        final JwtIssuerAndRealm jwtIssuerAndRealm = randomJwtIssuerRealmPair();
+        final User user = randomUser(jwtIssuerAndRealm.issuer());
+
+        final SecureString jwt2 = randomJwt(jwtIssuerAndRealm, user);
+        final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
+        doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt2, clientSecret, randomIntBetween(1, 3));
+    }
+
+    @Override
+    protected JwtRealmSettingsBuilder createJwtRealmSettingsBuilder(JwtIssuer jwtIssuer, int authzCount, int jwtCacheSize)
+        throws Exception {
+        final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = super.createJwtRealmSettingsBuilder(jwtIssuer, authzCount, jwtCacheSize);
+        final String realmName = jwtRealmSettingsBuilder.name();
+        final Settings.Builder settingsBuilder = jwtRealmSettingsBuilder.settingsBuilder();
+        settingsBuilder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.TOKEN_TYPE), "access_token")
+            .putList(
+                RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_SUBJECTS),
+                jwtIssuer.principals.keySet().stream().toList()
+            );
+
+        if (fallbackSub != null) {
+            settingsBuilder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_SUB_CLAIM), fallbackSub);
+        }
+        if (fallbackAud != null) {
+            settingsBuilder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM), fallbackAud);
+        }
+
+        return jwtRealmSettingsBuilder;
+    }
+
+    @Override
+    protected SecureString randomJwt(JwtIssuerAndRealm jwtIssuerAndRealm, User user) throws Exception {
+        final JwtIssuer.AlgJwkPair algJwkPair = randomFrom(jwtIssuerAndRealm.issuer().algAndJwksAll);
+        final JWK jwk = algJwkPair.jwk();
+
+        final HashMap<String, Object> otherClaims = new HashMap<>();
+        if (randomBoolean()) {
+            otherClaims.putAll(Map.of("other1", randomAlphaOfLength(10), "other2", randomAlphaOfLength(10)));
+        }
+
+        // Randomly set the fallback claims, it can co-exist with the original one in which case it is ignored
+        String subClaimValue = user.principal();
+        if (fallbackSub != null) {
+            if (randomBoolean()) {
+                // original claim does not exist, so it's the effective fallback
+                otherClaims.put(fallbackSub, subClaimValue);
+                subClaimValue = null;
+            } else {
+                // original claim still exist, in this case, the fallback can be anything and it does not matter
+                otherClaims.put(fallbackSub, randomValueOtherThan(subClaimValue, () -> randomAlphaOfLength(15)));
+            }
+        }
+        // TODO: fallback aud
+        List<String> audClaimValue = JwtRealmInspector.getAllowedAudiences(jwtIssuerAndRealm.realm());
+
+        // A bogus auth_time but access_token type does not check it
+        if (randomBoolean()) {
+            otherClaims.put("auth_time", randomAlphaOfLengthBetween(6, 18));
+        }
+
+        final Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
+        unsignedJwt = JwtTestCase.buildUnsignedJwt(
+            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
+            randomBoolean() ? null : jwk.getKeyID(), // kid
+            algJwkPair.alg(), // alg
+            randomAlphaOfLengthBetween(10, 20), // jwtID
+            JwtRealmInspector.getAllowedIssuer(jwtIssuerAndRealm.realm()), // iss
+            audClaimValue,
+            subClaimValue,
+            JwtRealmInspector.getPrincipalClaimName(jwtIssuerAndRealm.realm()), // principal claim name
+            user.principal(), // principal claim value
+            JwtRealmInspector.getGroupsClaimName(jwtIssuerAndRealm.realm()), // group claim name
+            List.of(user.roles()), // group claim value
+            null,
+            Date.from(now.minusSeconds(randomBoolean() ? 0 : 60 * randomLongBetween(5, 10))), // iat
+            Date.from(now), // nbf
+            Date.from(now.plusSeconds(60 * randomLongBetween(3600, 7200))), // exp
+            randomBoolean() ? null : new Nonce(32).toString(),
+            otherClaims
+        );
+        final SecureString signedJWT = JwtValidateUtil.signJwt(jwk, unsignedJwt);
+        assertThat(JwtValidateUtil.verifyJwt(jwk, SignedJWT.parse(signedJWT.toString())), is(equalTo(true)));
+        return signedJWT;
+    }
+
+    private void noFallback() {
+        fallbackSub = null;
+        fallbackAud = null;
+    }
+
+    private void randomFallbacks() {
+        fallbackSub = randomBoolean() ? "_" + randomAlphaOfLength(5) : null;
+        fallbackAud = randomBoolean() || fallbackSub == null ? "_" + randomAlphaOfLength(8) : null;
+    }
+}

+ 57 - 57
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java

@@ -50,21 +50,21 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
     public void testJwtAuthcRealmAuthcAuthzWithEmptyRoles() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
             this.createJwtRealmsSettingsBuilder(),
-            new MinMax(1, 1), // realmsRange
-            new MinMax(0, 1), // authzRange
-            new MinMax(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
-            new MinMax(1, 3), // audiencesRange
-            new MinMax(1, 3), // usersRange
-            new MinMax(0, 0), // rolesRange
-            new MinMax(0, 1), // jwtCacheSizeRange
+            randomIntBetween(1, 1), // realmsRange
+            randomIntBetween(0, 1), // authzRange
+            randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
+            randomIntBetween(1, 3), // audiencesRange
+            randomIntBetween(1, 3), // usersRange
+            randomIntBetween(0, 0), // rolesRange
+            randomIntBetween(0, 1), // jwtCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SecureString jwt = this.randomJwt(jwtIssuerAndRealm, user);
         final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
-        final MinMax jwtAuthcRange = new MinMax(2, 3);
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcRange);
+        final int jwtAuthcCount = randomIntBetween(2, 3);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
     }
 
     /**
@@ -74,13 +74,13 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
     public void testJwtAuthcRealmAuthcAuthzWithoutAuthzRealms() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
             this.createJwtRealmsSettingsBuilder(),
-            new MinMax(1, 3), // realmsRange
-            new MinMax(0, 0), // authzRange
-            new MinMax(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
-            new MinMax(1, 3), // audiencesRange
-            new MinMax(1, 3), // usersRange
-            new MinMax(0, 3), // rolesRange
-            new MinMax(0, 1), // jwtCacheSizeRange
+            randomIntBetween(1, 3), // realmsRange
+            randomIntBetween(0, 0), // authzRange
+            randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
+            randomIntBetween(1, 3), // audiencesRange
+            randomIntBetween(1, 3), // usersRange
+            randomIntBetween(0, 3), // rolesRange
+            randomIntBetween(0, 1), // jwtCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
@@ -89,8 +89,8 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SecureString jwt = this.randomJwt(jwtIssuerAndRealm, user);
         final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
-        final MinMax jwtAuthcRange = new MinMax(2, 3);
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcRange);
+        final int jwtAuthcCount = randomIntBetween(2, 3);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
     }
 
     /**
@@ -100,13 +100,13 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
     public void testJwkSetUpdates() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
             this.createJwtRealmsSettingsBuilder(),
-            new MinMax(1, 3), // realmsRange
-            new MinMax(0, 0), // authzRange
-            new MinMax(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
-            new MinMax(1, 3), // audiencesRange
-            new MinMax(1, 3), // usersRange
-            new MinMax(0, 3), // rolesRange
-            new MinMax(0, 1), // jwtCacheSizeRange
+            randomIntBetween(1, 3), // realmsRange
+            randomIntBetween(0, 0), // authzRange
+            randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
+            randomIntBetween(1, 3), // audiencesRange
+            randomIntBetween(1, 3), // usersRange
+            randomIntBetween(0, 3), // rolesRange
+            randomIntBetween(0, 1), // jwtCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
@@ -115,8 +115,8 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SecureString jwtJwks1 = this.randomJwt(jwtIssuerAndRealm, user);
         final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
-        final MinMax jwtAuthcRange = new MinMax(2, 3);
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        final int jwtAuthcCount = randomIntBetween(2, 3);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcCount);
 
         // Details about first JWT using the JWT issuer original JWKs
         final String jwt1JwksAlg = SignedJWT.parse(jwtJwks1.toString()).getHeader().getAlgorithm().getName();
@@ -138,7 +138,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         LOGGER.debug("JWKs 1 emptied, algs=[{}]", String.join(",", jwtIssuerAndRealm.issuer().algorithmsAll));
 
         // Original JWT continues working, because JWT realm cached old JWKs in memory.
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcCount);
         LOGGER.debug("JWT 1 still worked, because JWT realm has old JWKs cached in memory");
 
         // Restore original JWKs 1 into the JWT issuer.
@@ -148,7 +148,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         LOGGER.debug("JWKs 1 restored, algs=[{}]", String.join(",", jwtIssuerAndRealm.issuer().algorithmsAll));
 
         // Original JWT continues working, because JWT realm cached old JWKs in memory.
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcCount);
         LOGGER.debug("JWT 1 still worked, because JWT realm has old JWKs cached in memory");
 
         // Generate a replacement set of JWKs 2 for the JWT issuer.
@@ -164,7 +164,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         // Original JWT continues working, because JWT realm still has original JWKs cached in memory.
         // - jwtJwks1(PKC): Pass (Original PKC JWKs are still in the realm)
         // - jwtJwks1(HMAC): Pass (Original HMAC JWKs are still in the realm)
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcCount);
         LOGGER.debug("JWT 1 still worked, because JWT realm has old JWKs cached in memory");
 
         // Create a JWT using the new JWKs.
@@ -177,7 +177,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         // - jwtJwks2(PKC): PKC reload triggered and loaded new JWKs, so PASS
         // - jwtJwks2(HMAC): HMAC reload triggered but it is a no-op, so FAIL
         if (isPkcJwtJwks2) {
-            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcCount);
             LOGGER.debug("PKC JWT 2 worked with JWKs 2");
         } else {
             this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
@@ -190,7 +190,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         // - jwtJwks2(HMAC): HMAC reload triggered but it is a no-op, jwtJwks1(PKC): PKC reload not triggered, so PASS
         // - jwtJwks2(HMAC): HMAC reload triggered but it is a no-op, jwtJwks1(HMAC): HMAC reload not triggered, so PASS
         if (isPkcJwtJwks1 == false || isPkcJwtJwks2 == false) {
-            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcCount);
         } else {
             this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks1, clientSecret);
         }
@@ -202,7 +202,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
 
         // New JWT continues working because JWT realm will end up with PKC JWKs 2 and HMAC JWKs 1 in memory
         if (isPkcJwtJwks2) {
-            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcCount);
         } else {
             this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
         }
@@ -211,7 +211,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         // - jwtJwks1(HMAC): HMAC reload not triggered, so PASS
         // - jwtJwks1(PKC): PKC reload triggered and loaded new JWKs, so FAIL
         if (isPkcJwtJwks1 == false || isPkcJwtJwks2 == false) {
-            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcCount);
         } else {
             this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks1, clientSecret);
         }
@@ -222,7 +222,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         // - jwtJwks1(PKC) + jwtJwks2(HMAC): If second JWT is HMAC, it always fails because HMAC reload not supported.
         // - jwtJwks1(HMAC) + jwtJwks2(HMAC): If second JWT is HMAC, it always fails because HMAC reload not supported.
         if (isPkcJwtJwks1 == false && isPkcJwtJwks2) {
-            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcCount);
         } else {
             this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
         }
@@ -238,12 +238,12 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         // - jwtJwks2(HMAC): Fail (Triggers HMAC reload, but it is a no-op), jwtJwks1(PKC): Fail (Triggers PKC reload, gets new PKC JWKs)
         // - jwtJwks2(HMAC): Fail (Triggers HMAC reload, but it is a no-op), jwtJwks1(HMAC): Pass (HMAC reload was a no-op)
         if (isPkcJwtJwks2) {
-            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcRange);
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks2, clientSecret, jwtAuthcCount);
         } else {
             this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks2, clientSecret);
         }
         if (isPkcJwtJwks1 == false || isPkcJwtJwks2 == false) {
-            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcRange);
+            this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwtJwks1, clientSecret, jwtAuthcCount);
         } else {
             this.verifyAuthenticateFailureHelper(jwtIssuerAndRealm, jwtJwks1, clientSecret);
         }
@@ -256,13 +256,13 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
     public void testJwtAuthcRealmAuthcAuthzWithAuthzRealms() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
             this.createJwtRealmsSettingsBuilder(),
-            new MinMax(1, 3), // realmsRange
-            new MinMax(1, 3), // authzRange
-            new MinMax(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
-            new MinMax(1, 3), // audiencesRange
-            new MinMax(1, 3), // usersRange
-            new MinMax(0, 3), // rolesRange
-            new MinMax(0, 1), // jwtCacheSizeRange
+            randomIntBetween(1, 3), // realmsRange
+            randomIntBetween(1, 3), // authzRange
+            randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
+            randomIntBetween(1, 3), // audiencesRange
+            randomIntBetween(1, 3), // usersRange
+            randomIntBetween(0, 3), // rolesRange
+            randomIntBetween(0, 1), // jwtCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
@@ -271,8 +271,8 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SecureString jwt = this.randomJwt(jwtIssuerAndRealm, user);
         final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
-        final MinMax jwtAuthcRange = new MinMax(2, 3);
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcRange);
+        final int jwtAuthcCount = randomIntBetween(2, 3);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
 
         // After the above success path test, do a negative path test for an authc user that does not exist in any authz realm.
         // In other words, above the `user` was found in an authz realm, but below `otherUser` will not be found in any authz realm.
@@ -336,23 +336,23 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
     public void testJwtValidationFailures() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
             this.createJwtRealmsSettingsBuilder(),
-            new MinMax(1, 1), // realmsRange
-            new MinMax(0, 0), // authzRange
-            new MinMax(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
-            new MinMax(1, 1), // audiencesRange
-            new MinMax(1, 1), // usersRange
-            new MinMax(1, 1), // rolesRange
-            new MinMax(0, 1), // jwtCacheSizeRange
+            randomIntBetween(1, 1), // realmsRange
+            randomIntBetween(0, 0), // authzRange
+            randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
+            randomIntBetween(1, 1), // audiencesRange
+            randomIntBetween(1, 1), // usersRange
+            randomIntBetween(1, 1), // rolesRange
+            randomIntBetween(0, 1), // jwtCacheSizeRange
             randomBoolean() // createHttpsServer
         );
         final JwtIssuerAndRealm jwtIssuerAndRealm = this.randomJwtIssuerRealmPair();
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SecureString jwt = this.randomJwt(jwtIssuerAndRealm, user);
         final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
-        final MinMax jwtAuthcRange = new MinMax(2, 3);
+        final int jwtAuthcCount = randomIntBetween(2, 3);
 
         // Indirectly verify authentication works before performing any failure scenarios
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcRange);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
 
         // The above confirmed JWT realm authc/authz is working.
         // Now perform negative path tests to confirm JWT validation rejects invalid JWTs for different scenarios.
@@ -544,7 +544,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
         final User user = this.randomUser(jwtIssuerAndRealm.issuer());
         final SecureString jwt = this.randomJwt(jwtIssuerAndRealm, user);
         final SecureString clientSecret = JwtRealmInspector.getClientAuthenticationSharedSecret(jwtIssuerAndRealm.realm());
-        final MinMax jwtAuthcRange = new MinMax(2, 3);
-        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcRange);
+        final int jwtAuthcCount = randomIntBetween(2, 3);
+        this.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, jwtAuthcCount);
     }
 }

+ 5 - 5
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java

@@ -50,7 +50,7 @@ import static org.hamcrest.Matchers.is;
 public class JwtRealmGenerateTests extends JwtRealmTestCase {
     private static final Logger LOGGER = LogManager.getLogger(JwtRealmGenerateTests.class);
 
-    private static final MinMax JWT_AUTHC_RANGE_1 = new MinMax(1, 1);
+    private static final int JWT_AUTHC_REPEATS_1 = 1;
     private static final Date DATE_2000_1_1 = Date.from(ZonedDateTime.of(2000, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC).toInstant());
     private static final Date DATE_2099_1_1 = Date.from(ZonedDateTime.of(2099, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC).toInstant());
 
@@ -147,7 +147,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         assertThat(JwtValidateUtil.verifyJwt(algJwkPairHmac.jwk(), SignedJWT.parse(jwt.toString())), is(equalTo(true)));
 
         // Verify authc+authz, then print all artifacts
-        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, JWT_AUTHC_RANGE_1);
+        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, JWT_AUTHC_REPEATS_1);
         this.printArtifacts(jwtIssuer, config, clientSecret, jwt);
     }
 
@@ -239,7 +239,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         assertThat(JwtValidateUtil.verifyJwt(algJwkPairPkc.jwk(), SignedJWT.parse(jwt.toString())), is(equalTo(true)));
 
         // Verify authc+authz, then print all artifacts
-        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, null, JWT_AUTHC_RANGE_1);
+        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, null, JWT_AUTHC_REPEATS_1);
         this.printArtifacts(jwtIssuer, config, null, jwt);
     }
 
@@ -345,7 +345,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         assertThat(JwtValidateUtil.verifyJwt(algJwkPairHmac.jwk(), SignedJWT.parse(jwt.toString())), is(equalTo(true)));
 
         // Verify authc+authz, then print all artifacts
-        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, JWT_AUTHC_RANGE_1);
+        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, JWT_AUTHC_REPEATS_1);
         this.printArtifacts(jwtIssuer, config, clientSecret, jwt);
     }
 
@@ -442,7 +442,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         assertThat(JwtValidateUtil.verifyJwt(selectedHmac.jwk(), SignedJWT.parse(jwt.toString())), is(equalTo(true)));
 
         // Verify authc+authz, then print all artifacts
-        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, JWT_AUTHC_RANGE_1);
+        super.doMultipleAuthcAuthzAndVerifySuccess(jwtIssuerAndRealm.realm(), user, jwt, clientSecret, JWT_AUTHC_REPEATS_1);
         this.printArtifacts(jwtIssuer, config, clientSecret, jwt);
     }
 

+ 4 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java

@@ -20,6 +20,10 @@ class JwtRealmInspector {
 
     private JwtRealmInspector() {}
 
+    public static JwtRealmSettings.TokenType getTokenType(JwtRealm realm) {
+        return realm.getConfig().getSetting(JwtRealmSettings.TOKEN_TYPE);
+    }
+
     public static String getJwkSetPath(JwtRealm realm) {
         return realm.getConfig().getSetting(JwtRealmSettings.PKC_JWKSET_PATH);
     }

+ 107 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSettingsTests.java

@@ -10,6 +10,7 @@ import org.elasticsearch.common.settings.MockSecureSettings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
@@ -428,4 +429,110 @@ public class JwtRealmSettingsTests extends JwtTestCase {
         );
         assertThat(e.getMessage(), containsString("Invalid value"));
     }
+
+    public void testFallbackClaimSettingsNotAllowedForIdTokenType() {
+        final String realmName = randomAlphaOfLengthBetween(3, 8);
+        final Settings.Builder settingsBuilder = Settings.builder();
+        if (randomBoolean()) {
+            settingsBuilder.put(
+                RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.TOKEN_TYPE),
+                JwtRealmSettings.TokenType.ID_TOKEN.value()
+            );
+        }
+        settingsBuilder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_SUB_CLAIM), randomAlphaOfLength(8))
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM), randomAlphaOfLength(8));
+
+        final RealmConfig realmConfig = buildRealmConfig(JwtRealmSettings.TYPE, realmName, settingsBuilder.build(), randomInt());
+
+        final IllegalArgumentException e1 = expectThrows(
+            IllegalArgumentException.class,
+            () -> realmConfig.getSetting(JwtRealmSettings.FALLBACK_SUB_CLAIM)
+        );
+        assertThat(
+            e1.getMessage(),
+            containsString(
+                "fallback claim setting ["
+                    + RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_SUB_CLAIM)
+                    + "] is not allowed when JWT realm ["
+                    + realmName
+                    + "] is [id_token] type"
+            )
+        );
+
+        final IllegalArgumentException e2 = expectThrows(
+            IllegalArgumentException.class,
+            () -> realmConfig.getSetting(JwtRealmSettings.FALLBACK_AUD_CLAIM)
+        );
+        assertThat(
+            e2.getMessage(),
+            containsString(
+                "fallback claim setting ["
+                    + RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM)
+                    + "] is not allowed when JWT realm ["
+                    + realmName
+                    + "] is [id_token] type"
+            )
+        );
+    }
+
+    public void testFallbackSettingsForAccessTokenType() {
+        final String realmName = randomAlphaOfLengthBetween(3, 8);
+        final String fallbackSub = randomAlphaOfLength(8);
+        final String fallbackAud = randomAlphaOfLength(8);
+        final Settings settings = Settings.builder()
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.TOKEN_TYPE), JwtRealmSettings.TokenType.ACCESS_TOKEN.value())
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_SUB_CLAIM), fallbackSub)
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM), fallbackAud)
+            .build();
+
+        final RealmConfig realmConfig = buildRealmConfig(JwtRealmSettings.TYPE, realmName, settings, randomInt());
+        assertThat(realmConfig.getSetting(JwtRealmSettings.FALLBACK_SUB_CLAIM), equalTo(fallbackSub));
+        assertThat(realmConfig.getSetting(JwtRealmSettings.FALLBACK_AUD_CLAIM), equalTo(fallbackAud));
+    }
+
+    public void testRegisteredClaimsCannotBeUsedForFallbackSettings() {
+        final String realmName = randomAlphaOfLengthBetween(3, 8);
+        final String fallbackSub = randomValueOtherThan("sub", () -> randomFrom(JwtRealmSettings.REGISTERED_CLAIM_NAMES));
+        final String fallbackAud = randomValueOtherThan("aud", () -> randomFrom(JwtRealmSettings.REGISTERED_CLAIM_NAMES));
+        final Settings settings = Settings.builder()
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.TOKEN_TYPE), JwtRealmSettings.TokenType.ACCESS_TOKEN.value())
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_SUB_CLAIM), fallbackSub)
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM), fallbackAud)
+            .build();
+
+        final RealmConfig realmConfig = buildRealmConfig(JwtRealmSettings.TYPE, realmName, settings, randomInt());
+
+        final IllegalArgumentException e1 = expectThrows(
+            IllegalArgumentException.class,
+            () -> realmConfig.getSetting(JwtRealmSettings.FALLBACK_SUB_CLAIM)
+        );
+        assertThat(
+            e1.getMessage(),
+            containsString(
+                Strings.format(
+                    "Invalid fallback claims setting [%s]. Claim [%s] cannot fallback to a registered claim [%s]",
+                    RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_SUB_CLAIM),
+                    "sub",
+                    fallbackSub
+                )
+            )
+        );
+
+        final IllegalArgumentException e2 = expectThrows(
+            IllegalArgumentException.class,
+            () -> realmConfig.getSetting(JwtRealmSettings.FALLBACK_AUD_CLAIM)
+        );
+        assertThat(
+            e2.getMessage(),
+            containsString(
+                Strings.format(
+                    "Invalid fallback claims setting [%s]. Claim [%s] cannot fallback to a registered claim [%s]",
+                    RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM),
+                    "aud",
+                    fallbackAud
+                )
+            )
+        );
+    }
+
 }

+ 14 - 35
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java

@@ -59,7 +59,6 @@ import java.util.stream.IntStream;
 import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasEntry;
 import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.is;
@@ -80,12 +79,6 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
 
     record JwtIssuerAndRealm(JwtIssuer issuer, JwtRealm realm, JwtRealmSettingsBuilder realmSettingsBuilder) {}
 
-    record MinMax(int min, int max) {
-        MinMax {
-            assert min >= 0 && max >= min : "Invalid min=" + min + " max=" + max;
-        }
-    }
-
     protected ThreadPool threadPool;
     protected ResourceWatcherService resourceWatcherService;
     protected MockLicenseState licenseState;
@@ -130,35 +123,20 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
 
     protected List<JwtIssuerAndRealm> generateJwtIssuerRealmPairs(
         final JwtRealmsServiceSettingsBuilder jwtRealmsServiceSettingsBuilder,
-        final MinMax realmsRange,
-        final MinMax authzRange,
-        final MinMax algsRange,
-        final MinMax audiencesRange,
-        final MinMax usersRange,
-        final MinMax rolesRange,
-        final MinMax jwtCacheSizeRange,
+        final int realmsCount,
+        final int authzCount,
+        final int algsCount,
+        final int audiencesCount,
+        final int usersCount,
+        final int rolesCount,
+        final int jwtCacheSize,
         final boolean createHttpsServer
     ) throws Exception {
-        assertThat(realmsRange.min(), is(greaterThanOrEqualTo(1)));
-        assertThat(authzRange.min(), is(greaterThanOrEqualTo(0)));
-        assertThat(algsRange.min(), is(greaterThanOrEqualTo(1)));
-        assertThat(audiencesRange.min(), is(greaterThanOrEqualTo(1)));
-        assertThat(usersRange.min(), is(greaterThanOrEqualTo(1)));
-        assertThat(rolesRange.min(), is(greaterThanOrEqualTo(0)));
-        assertThat(jwtCacheSizeRange.min(), is(greaterThanOrEqualTo(0)));
-
         // Create JWT authc realms and mocked authz realms. Initialize each JWT realm, and test ensureInitialized() before and after.
         final JwtRealmsService jwtRealmsService = this.generateJwtRealmsService(jwtRealmsServiceSettingsBuilder);
-        final int realmsCount = randomIntBetween(realmsRange.min(), realmsRange.max());
         final List<Realm> allRealms = new ArrayList<>(); // authc and authz realms
         this.jwtIssuerAndRealms = new ArrayList<>(realmsCount);
         for (int i = 0; i < realmsCount; i++) {
-            final int authzCount = randomIntBetween(authzRange.min(), authzRange.max());
-            final int algsCount = randomIntBetween(algsRange.min(), algsRange.max());
-            final int audiencesCount = randomIntBetween(audiencesRange.min(), audiencesRange.max());
-            final int usersCount = randomIntBetween(usersRange.min(), usersRange.max());
-            final int rolesCount = randomIntBetween(rolesRange.min(), rolesRange.max());
-            final int jwtCacheSize = randomIntBetween(jwtCacheSizeRange.min(), jwtCacheSizeRange.max());
 
             final JwtIssuer jwtIssuer = this.createJwtIssuer(
                 i,
@@ -444,15 +422,13 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         final User user,
         final SecureString jwt,
         final SecureString sharedSecret,
-        final MinMax jwtAuthcRange
+        final int jwtAuthcRepeats
     ) throws Exception {
-        assertThat(jwtAuthcRange.min(), is(greaterThanOrEqualTo(1)));
 
         // Select one JWT authc Issuer/Realm pair. Select one test user, to use inside the authc test loop.
         final List<JwtRealm> jwtRealmsList = this.jwtIssuerAndRealms.stream().map(p -> p.realm).toList();
 
         // Select different test JWKs from the JWT realm, and generate test JWTs for the test user. Run the JWT through the chain.
-        final int jwtAuthcRepeats = randomIntBetween(jwtAuthcRange.min(), jwtAuthcRange.max());
         for (int authcRun = 1; authcRun <= jwtAuthcRepeats; authcRun++) {
             // Create request with headers set
             final ThreadContext requestThreadContext = super.createThreadContext(jwt, sharedSecret);
@@ -571,10 +547,13 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                 if (jwtRealm.delegatedAuthorizationSupport.hasDelegation()) {
                     assertThat(user.metadata(), is(equalTo(authenticatedUser.metadata()))); // delegated authz returns user's metadata
                 } else if (JwtRealmInspector.shouldPopulateUserMetadata(jwtRealm)) {
-                    assertThat(authenticatedUser.metadata(), hasEntry("jwt_token_type", "id_token"));
+                    assertThat(authenticatedUser.metadata(), hasEntry("jwt_token_type", JwtRealmInspector.getTokenType(jwtRealm).value()));
                     assertThat(authenticatedUser.metadata(), hasKey(startsWith("jwt_claim_")));
                 } else {
-                    assertThat(authenticatedUser.metadata(), equalTo(Map.of("jwt_token_type", "id_token")));
+                    assertThat(
+                        authenticatedUser.metadata(),
+                        equalTo(Map.of("jwt_token_type", JwtRealmInspector.getTokenType(jwtRealm).value()))
+                    );
                 }
             } catch (Throwable t) {
                 realmFailureExceptions.forEach(t::addSuppressed); // all previous realm exceptions
@@ -625,7 +604,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
             randomAlphaOfLengthBetween(10, 20), // jwtID
             JwtRealmInspector.getAllowedIssuer(jwtIssuerAndRealm.realm), // iss
             JwtRealmInspector.getAllowedAudiences(jwtIssuerAndRealm.realm), // aud
-            randomBoolean() ? null : randomBoolean() ? user.principal() : user.principal() + "_" + randomInt(9), // sub claim value
+            randomBoolean() ? user.principal() : user.principal() + "_" + randomInt(9), // sub claim value
             JwtRealmInspector.getPrincipalClaimName(jwtIssuerAndRealm.realm), // principal claim name
             user.principal(), // principal claim value
             JwtRealmInspector.getGroupsClaimName(jwtIssuerAndRealm.realm), // group claim name

+ 105 - 30
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtStringClaimValidatorTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.security.authc.jwt;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.test.ESTestCase;
 
 import java.text.ParseException;
@@ -23,12 +22,22 @@ import static org.hamcrest.core.IsInstanceOf.instanceOf;
 public class JwtStringClaimValidatorTests extends ESTestCase {
 
     public void testClaimIsNotString() throws ParseException {
-        final String claimName = randomAlphaOfLengthBetween(10, 18);
-        final JwtStringClaimValidator validator = new JwtStringClaimValidator(claimName, List.of(), randomBoolean());
+        final String claimName = randomAlphaOfLength(10);
+        final String fallbackClaimName = randomAlphaOfLength(12);
+
+        final JwtStringClaimValidator validator;
+        final JWTClaimsSet jwtClaimsSet;
+        if (randomBoolean()) {
+            validator = new JwtStringClaimValidator(claimName, List.of(), randomBoolean());
+            // fallback claim is ignored
+            jwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, List.of(42), fallbackClaimName, randomAlphaOfLength(8)));
+        } else {
+            validator = new JwtStringClaimValidator(claimName, Map.of(claimName, fallbackClaimName), List.of(), randomBoolean());
+            jwtClaimsSet = JWTClaimsSet.parse(Map.of(fallbackClaimName, List.of(42)));
+        }
 
-        final JWTClaimsSet jwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, List.of(42)));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), jwtClaimsSet)
         );
         assertThat(e.getMessage(), containsString("cannot parse string claim"));
@@ -37,11 +46,21 @@ public class JwtStringClaimValidatorTests extends ESTestCase {
 
     public void testClaimIsNotSingleValued() throws ParseException {
         final String claimName = randomAlphaOfLengthBetween(10, 18);
-        final JwtStringClaimValidator validator = new JwtStringClaimValidator(claimName, List.of(), true);
+        final String fallbackClaimName = randomAlphaOfLength(12);
+
+        final JwtStringClaimValidator validator;
+        final JWTClaimsSet jwtClaimsSet;
+        if (randomBoolean()) {
+            validator = new JwtStringClaimValidator(claimName, List.of(), true);
+            // fallback claim is ignored
+            jwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, List.of("foo", "bar"), fallbackClaimName, randomAlphaOfLength(8)));
+        } else {
+            validator = new JwtStringClaimValidator(claimName, Map.of(claimName, fallbackClaimName), List.of(), true);
+            jwtClaimsSet = JWTClaimsSet.parse(Map.of(fallbackClaimName, List.of("foo", "bar")));
+        }
 
-        final JWTClaimsSet jwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, List.of("foo", "bar")));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), jwtClaimsSet)
         );
         assertThat(e.getMessage(), containsString("cannot parse string claim"));
@@ -50,11 +69,19 @@ public class JwtStringClaimValidatorTests extends ESTestCase {
 
     public void testClaimDoesNotExist() throws ParseException {
         final String claimName = randomAlphaOfLengthBetween(10, 18);
-        final JwtStringClaimValidator validator = new JwtStringClaimValidator(claimName, List.of(), randomBoolean());
+        final String fallbackClaimName = randomAlphaOfLength(12);
+
+        final JwtStringClaimValidator validator;
+        final JWTClaimsSet jwtClaimsSet;
+        if (randomBoolean()) {
+            validator = new JwtStringClaimValidator(claimName, List.of(), randomBoolean());
+        } else {
+            validator = new JwtStringClaimValidator(claimName, Map.of(claimName, fallbackClaimName), List.of(), randomBoolean());
+        }
+        jwtClaimsSet = JWTClaimsSet.parse(Map.of());
 
-        final JWTClaimsSet jwtClaimsSet = JWTClaimsSet.parse(Map.of());
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), jwtClaimsSet)
         );
         assertThat(e.getMessage(), containsString("missing required string claim"));
@@ -62,26 +89,40 @@ public class JwtStringClaimValidatorTests extends ESTestCase {
 
     public void testMatchingClaimValues() throws ParseException {
         final String claimName = randomAlphaOfLengthBetween(10, 18);
+        final String fallbackClaimName = randomAlphaOfLength(12);
         final String claimValue = randomAlphaOfLength(10);
         final boolean singleValuedClaim = randomBoolean();
-        final JwtStringClaimValidator validator = new JwtStringClaimValidator(
-            claimName,
-            List.of(claimValue, randomAlphaOfLengthBetween(11, 20)),
-            singleValuedClaim
-        );
+        final List<String> allowedClaimValues = List.of(claimValue, randomAlphaOfLengthBetween(11, 20));
+        final Object incomingClaimValue = singleValuedClaim ? claimValue : randomFrom(claimValue, List.of(claimValue, "other-stuff"));
+
+        final JwtStringClaimValidator validator;
+        final JWTClaimsSet validJwtClaimsSet;
+        final boolean noFallback = randomBoolean();
+        if (noFallback) {
+            validator = new JwtStringClaimValidator(claimName, allowedClaimValues, singleValuedClaim);
+            // fallback claim is ignored
+            validJwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, incomingClaimValue, fallbackClaimName, List.of(42)));
+        } else {
+            validator = new JwtStringClaimValidator(claimName, Map.of(claimName, fallbackClaimName), allowedClaimValues, randomBoolean());
+            validJwtClaimsSet = JWTClaimsSet.parse(Map.of(fallbackClaimName, incomingClaimValue));
+        }
 
-        final JWTClaimsSet validJwtClaimsSet = JWTClaimsSet.parse(
-            Map.of(claimName, singleValuedClaim ? claimValue : randomFrom(claimValue, List.of(claimValue, "other-stuff")))
-        );
         try {
             validator.validate(getJwsHeader(), validJwtClaimsSet);
         } catch (Exception e) {
             throw new AssertionError("validation should have passed without exception", e);
         }
 
-        final JWTClaimsSet invalidJwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, "not-" + claimValue));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final JWTClaimsSet invalidJwtClaimsSet;
+        if (noFallback) {
+            // fallback is ignored (even when it has a valid value) since the main claim exists
+            invalidJwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, "not-" + claimValue, fallbackClaimName, claimValue));
+        } else {
+            invalidJwtClaimsSet = JWTClaimsSet.parse(Map.of(fallbackClaimName, "not-" + claimValue));
+        }
+
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), invalidJwtClaimsSet)
         );
         assertThat(e.getMessage(), containsString("does not match allowed claim values"));
@@ -89,19 +130,36 @@ public class JwtStringClaimValidatorTests extends ESTestCase {
 
     public void testDoesNotSupportWildcardOrRegex() throws ParseException {
         final String claimName = randomAlphaOfLengthBetween(10, 18);
+        final String fallbackClaimName = randomAlphaOfLength(12);
         final String claimValue = randomFrom("*", "/.*/");
-        final JwtStringClaimValidator validator = new JwtStringClaimValidator(claimName, List.of(claimValue), randomBoolean());
+
+        final JwtStringClaimValidator validator;
+        final JWTClaimsSet invalidJwtClaimsSet;
+        final boolean noFallback = randomBoolean();
+        if (noFallback) {
+            validator = new JwtStringClaimValidator(claimName, List.of(claimValue), randomBoolean());
+            // fallback is ignored (even when it has a valid value) since the main claim exists
+            invalidJwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, randomAlphaOfLengthBetween(1, 10), fallbackClaimName, claimValue));
+        } else {
+            validator = new JwtStringClaimValidator(claimName, Map.of(claimName, fallbackClaimName), List.of(claimValue), randomBoolean());
+            invalidJwtClaimsSet = JWTClaimsSet.parse(Map.of(fallbackClaimName, randomAlphaOfLengthBetween(1, 10)));
+        }
 
         // It should not match arbitrary claim value because wildcard or regex is not supported
-        final JWTClaimsSet invalidJwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, randomAlphaOfLengthBetween(1, 10)));
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> validator.validate(getJwsHeader(), invalidJwtClaimsSet)
         );
         assertThat(e.getMessage(), containsString("does not match allowed claim values"));
 
         // It should support literal matching
-        final JWTClaimsSet validJwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, claimValue));
+        final JWTClaimsSet validJwtClaimsSet;
+        if (noFallback) {
+            // fallback claim is ignored
+            validJwtClaimsSet = JWTClaimsSet.parse(Map.of(claimName, claimValue, fallbackClaimName, randomAlphaOfLength(10)));
+        } else {
+            validJwtClaimsSet = JWTClaimsSet.parse(Map.of(fallbackClaimName, claimValue));
+        }
         try {
             validator.validate(getJwsHeader(), validJwtClaimsSet);
         } catch (Exception e2) {
@@ -109,6 +167,23 @@ public class JwtStringClaimValidatorTests extends ESTestCase {
         }
     }
 
+    public void testAllowAllSubjects() {
+        try {
+            JwtStringClaimValidator.ALLOW_ALL_SUBJECTS.validate(
+                getJwsHeader(),
+                JWTClaimsSet.parse(Map.of("sub", randomAlphaOfLengthBetween(1, 10)))
+            );
+        } catch (Exception e) {
+            throw new AssertionError("validation should have passed without exception", e);
+        }
+
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
+            () -> JwtStringClaimValidator.ALLOW_ALL_SUBJECTS.validate(getJwsHeader(), JWTClaimsSet.parse(Map.of()))
+        );
+        assertThat(e.getMessage(), containsString("missing required string claim"));
+    }
+
     private JWSHeader getJwsHeader() throws ParseException {
         return JWSHeader.parse(Map.of("alg", randomAlphaOfLengthBetween(3, 8)));
     }

+ 2 - 3
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtTypeValidatorTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.security.authc.jwt;
 import com.nimbusds.jose.JWSHeader;
 import com.nimbusds.jwt.JWTClaimsSet;
 
-import org.elasticsearch.ElasticsearchSecurityException;
 import org.elasticsearch.test.ESTestCase;
 
 import java.text.ParseException;
@@ -41,8 +40,8 @@ public class JwtTypeValidatorTests extends ESTestCase {
             Map.of("typ", randomAlphaOfLengthBetween(4, 8), "alg", randomAlphaOfLengthBetween(3, 8))
         );
 
-        final ElasticsearchSecurityException e = expectThrows(
-            ElasticsearchSecurityException.class,
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
             () -> JwtTypeValidator.INSTANCE.validate(jwsHeader, JWTClaimsSet.parse(Map.of()))
         );
         assertThat(e.getMessage(), containsString("invalid jwt typ header"));

+ 2 - 2
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectRealmTests.java

@@ -133,8 +133,8 @@ public class OpenIdConnectRealmTests extends OpenIdConnectTestCase {
             Exception.class,
             () -> authenticateWithOidc(principal, roleMapper, false, false, REALM_NAME, claimsWithNumber)
         );
-        assertThat(e.getCause().getMessage(), containsString("expects a claim with String or a String Array value"));
-        assertThat(e2.getCause().getMessage(), containsString("expects a claim with String or a String Array value"));
+        assertThat(e.getCause().getMessage(), containsString("expects claim [groups] with String or a String Array value"));
+        assertThat(e2.getCause().getMessage(), containsString("expects claim [groups] with String or a String Array value"));
     }
 
     public void testClaimMetadataMapping() throws Exception {