Преглед на файлове

JWT realm - add support for required claims (#92314)

This PR adds a new required_claims group setting that can be used to
specify additional mandatory claim checks for either ID tokens or access
tokens. A required claim must have either string or string list value.
Yang Wang преди 2 години
родител
ревизия
9d605b576f

+ 5 - 0
docs/changelog/92314.yaml

@@ -0,0 +1,5 @@
+pr: 92314
+summary: JWT realm - add support for required claims
+area: Authentication
+type: enhancement
+issues: []

+ 22 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmSettings.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.security.authc.jwt;
 
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
@@ -160,6 +161,7 @@ public class JwtRealmSettings {
                 ALLOWED_SUBJECTS,
                 FALLBACK_SUB_CLAIM,
                 FALLBACK_AUD_CLAIM,
+                REQUIRED_CLAIMS,
                 CLAIMS_PRINCIPAL.getClaim(),
                 CLAIMS_PRINCIPAL.getPattern(),
                 CLAIMS_GROUPS.getClaim(),
@@ -302,6 +304,26 @@ public class JwtRealmSettings {
         }, Setting.Property.NodeScope)
     );
 
+    public static final Setting.AffixSetting<Settings> REQUIRED_CLAIMS = Setting.affixKeySetting(
+        RealmSettings.realmSettingPrefix(TYPE),
+        "required_claims",
+        key -> Setting.groupSetting(key + ".", settings -> {
+            final List<String> invalidRequiredClaims = List.of("iss", "sub", "aud", "exp", "nbf", "iat");
+            for (String name : settings.names()) {
+                final String fullName = key + "." + name;
+                if (invalidRequiredClaims.contains(name)) {
+                    throw new IllegalArgumentException(
+                        Strings.format("required claim [%s] cannot be one of [%s]", fullName, String.join(",", invalidRequiredClaims))
+                    );
+                }
+                final List<String> values = settings.getAsList(name);
+                if (values.isEmpty()) {
+                    throw new IllegalArgumentException(Strings.format("required claim [%s] cannot be empty", fullName));
+                }
+            }
+        }, Setting.Property.NodeScope)
+    );
+
     // Note: ClaimSetting is a wrapper for two individual settings: getClaim(), getPattern()
     public static final ClaimSetting CLAIMS_PRINCIPAL = new ClaimSetting(TYPE, "principal");
     public static final ClaimSetting CLAIMS_GROUPS = new ClaimSetting(TYPE, "groups");

+ 4 - 1
x-pack/plugin/security/qa/jwt-realm/build.gradle

@@ -62,6 +62,8 @@ testClusters.matching { it.name == 'javaRestTest' }.configureEach {
   setting 'xpack.security.authc.realms.jwt.jwt1.claims.dn', 'dn'
   setting 'xpack.security.authc.realms.jwt.jwt1.claims.name', 'name'
   setting 'xpack.security.authc.realms.jwt.jwt1.claims.mail', 'mail'
+  setting 'xpack.security.authc.realms.jwt.jwt1.required_claims.token_use', 'id'
+  setting 'xpack.security.authc.realms.jwt.jwt1.required_claims.version', '2.0'
   setting 'xpack.security.authc.realms.jwt.jwt1.client_authentication.type', 'NONE'
   // Use default value (RS256) for signature algorithm
   setting 'xpack.security.authc.realms.jwt.jwt1.pkc_jwkset_path', 'rsa.jwkset'
@@ -84,12 +86,13 @@ testClusters.matching { it.name == 'javaRestTest' }.configureEach {
     setting 'xpack.security.authc.realms.jwt.jwt2.claims.principal', 'sub'
   }
   setting 'xpack.security.authc.realms.jwt.jwt2.claim_patterns.principal', '^(.*)@[^.]*[.]example[.]com$'
+  setting 'xpack.security.authc.realms.jwt.jwt2.required_claims.token_use', 'access'
   setting 'xpack.security.authc.realms.jwt.jwt2.authorization_realms', 'lookup_native'
   setting 'xpack.security.authc.realms.jwt.jwt2.client_authentication.type', 'shared_secret'
   keystore 'xpack.security.authc.realms.jwt.jwt2.client_authentication.shared_secret', 'test-secret'
   keystore 'xpack.security.authc.realms.jwt.jwt2.hmac_key', 'test-HMAC/secret passphrase-value'
 
-  // Place PKI realm after JWT realm to verify realm chain fall-throug
+  // Place PKI realm after JWT realm to verify realm chain fall-through
   setting 'xpack.security.authc.realms.pki.pki_realm.order', '4'
 
   setting 'xpack.security.authc.realms.jwt.jwt3.order', '5'

+ 28 - 2
x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRestIT.java

@@ -338,6 +338,28 @@ public class JwtRestIT extends ESRestTestCase {
 
     }
 
+    public void testFailureOnRequiredClaims() throws JOSEException, IOException {
+        final String principal = System.getProperty("jwt2.service_subject");
+        final String username = getUsernameFromPrincipal(principal);
+        final List<String> roles = randomRoles();
+        createUser(username, roles, Map.of());
+        try {
+            final String audience = "es0" + randomIntBetween(1, 3);
+            final Map<String, Object> data = new HashMap<>(Map.of("iss", "my-issuer", "aud", audience, "email", principal));
+            // The required claim is either missing or mismatching
+            if (randomBoolean()) {
+                data.put("token_use", randomValueOtherThan("access", () -> randomAlphaOfLengthBetween(3, 10)));
+            }
+            final JWTClaimsSet claimsSet = buildJwt(data, Instant.now(), false);
+            final SignedJWT jwt = signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value");
+            final TestSecurityClient client = getSecurityClient(jwt, VALID_SHARED_SECRET);
+            final ResponseException exception = expectThrows(ResponseException.class, client::authenticate);
+            assertThat(exception.getResponse(), hasStatusCode(RestStatus.UNAUTHORIZED));
+        } finally {
+            deleteUser(username);
+        }
+    }
+
     public void testAuthenticationFailureIfDelegatedAuthorizationFails() throws Exception {
         final String principal = System.getProperty("jwt2.service_subject");
         final String username = getUsernameFromPrincipal(principal);
@@ -486,7 +508,9 @@ public class JwtRestIT extends ESRestTestCase {
                 Map.entry("dn", dn),
                 Map.entry("name", name),
                 Map.entry("mail", mail),
-                Map.entry("roles", groups) // Realm realm config has `claim.groups: "roles"`
+                Map.entry("roles", groups), // Realm realm config has `claim.groups: "roles"`
+                Map.entry("token_use", "id"),
+                Map.entry("version", "2.0")
             ),
             issueTime
         );
@@ -505,7 +529,9 @@ public class JwtRestIT extends ESRestTestCase {
     private JWTClaimsSet buildJwtForRealm2(String principal, Instant issueTime) {
         // The "jwt2" realm, supports 3 audiences (es01/02/03)
         final String audience = "es0" + randomIntBetween(1, 3);
-        final Map<String, Object> data = new HashMap<>(Map.of("iss", "my-issuer", "aud", audience, "email", principal));
+        final Map<String, Object> data = new HashMap<>(
+            Map.of("iss", "my-issuer", "aud", audience, "email", principal, "token_use", "access")
+        );
         // scope (fallback audience) is ignored since aud exists
         if (randomBoolean()) {
             data.put("scope", randomAlphaOfLength(20));

+ 20 - 3
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java

@@ -15,6 +15,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
@@ -23,6 +24,7 @@ import org.elasticsearch.xpack.core.ssl.SSLService;
 
 import java.text.ParseException;
 import java.time.Clock;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
@@ -47,16 +49,19 @@ public class JwtAuthenticator implements Releasable {
     ) {
         this.realmConfig = realmConfig;
         this.tokenType = realmConfig.getSetting(JwtRealmSettings.TOKEN_TYPE);
+        final List<JwtFieldValidator> jwtFieldValidators = new ArrayList<>();
         if (tokenType == JwtRealmSettings.TokenType.ID_TOKEN) {
             this.fallbackClaimNames = Map.of();
-            this.jwtFieldValidators = configureFieldValidatorsForIdToken(realmConfig);
+            jwtFieldValidators.addAll(configureFieldValidatorsForIdToken(realmConfig));
         } else {
             this.fallbackClaimNames = Map.ofEntries(
                 Map.entry("sub", realmConfig.getSetting(JwtRealmSettings.FALLBACK_SUB_CLAIM)),
                 Map.entry("aud", realmConfig.getSetting(JwtRealmSettings.FALLBACK_AUD_CLAIM))
             );
-            this.jwtFieldValidators = configureFieldValidatorsForAccessToken(realmConfig, fallbackClaimNames);
+            jwtFieldValidators.addAll(configureFieldValidatorsForAccessToken(realmConfig, fallbackClaimNames));
         }
+        jwtFieldValidators.addAll(getRequireClaimsValidators());
+        this.jwtFieldValidators = List.copyOf(jwtFieldValidators);
         this.jwtSignatureValidator = new JwtSignatureValidator.DelegatingJwtSignatureValidator(realmConfig, sslService, reloadNotifier);
     }
 
@@ -104,12 +109,17 @@ public class JwtAuthenticator implements Releasable {
         }
 
         try {
-            jwtSignatureValidator.validate(tokenPrincipal, signedJWT, listener.map(ignored -> jwtClaimsSet));
+            validateSignature(tokenPrincipal, signedJWT, listener.map(ignored -> jwtClaimsSet));
         } catch (Exception e) {
             listener.onFailure(e);
         }
     }
 
+    // Package private for testing
+    void validateSignature(String tokenPrincipal, SignedJWT signedJWT, ActionListener<Void> listener) {
+        jwtSignatureValidator.validate(tokenPrincipal, signedJWT, listener);
+    }
+
     @Override
     public void close() {
         jwtSignatureValidator.close();
@@ -172,6 +182,13 @@ public class JwtAuthenticator implements Releasable {
             new JwtDateClaimValidator(clock, "iat", allowedClockSkew, JwtDateClaimValidator.Relationship.BEFORE_NOW, false),
             new JwtDateClaimValidator(clock, "exp", allowedClockSkew, JwtDateClaimValidator.Relationship.AFTER_NOW, false)
         );
+    }
 
+    private List<JwtStringClaimValidator> getRequireClaimsValidators() {
+        final Settings requiredClaims = realmConfig.getSetting(JwtRealmSettings.REQUIRED_CLAIMS);
+        return requiredClaims.names().stream().map(name -> {
+            final List<String> allowedValues = requiredClaims.getAsList(name);
+            return new JwtStringClaimValidator(name, allowedValues, false);
+        }).toList();
     }
 }

+ 3 - 17
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorAccessTokenTypeTests.java

@@ -9,7 +9,6 @@ package org.elasticsearch.xpack.security.authc.jwt;
 
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
-import org.junit.Before;
 
 import java.text.ParseException;
 
@@ -17,23 +16,13 @@ import static org.hamcrest.Matchers.containsString;
 
 public class JwtAuthenticatorAccessTokenTypeTests extends JwtAuthenticatorTests {
 
-    private String fallbackSub;
-    private String fallbackAud;
-
-    @Before
-    public void beforeTest() {
-        doBeforeTest();
-        fallbackSub = randomBoolean() ? "_" + randomAlphaOfLength(5) : null;
-        fallbackAud = randomBoolean() ? "_" + randomAlphaOfLength(8) : null;
-    }
-
     @Override
     protected JwtRealmSettings.TokenType getTokenType() {
         return JwtRealmSettings.TokenType.ACCESS_TOKEN;
     }
 
     public void testSubjectIsRequired() throws ParseException {
-        final IllegalArgumentException e = doTestSubjectIsRequired(buildJwtAuthenticator(fallbackSub, fallbackAud));
+        final IllegalArgumentException e = doTestSubjectIsRequired(buildJwtAuthenticator());
         if (fallbackSub != null) {
             assertThat(e.getMessage(), containsString("missing required string claim [" + fallbackSub + " (fallback of sub)]"));
         }
@@ -41,10 +30,7 @@ public class JwtAuthenticatorAccessTokenTypeTests extends JwtAuthenticatorTests
 
     public void testAccessTokenTypeMandatesAllowedSubjects() {
         allowedSubject = null;
-        final IllegalArgumentException e = expectThrows(
-            IllegalArgumentException.class,
-            () -> buildJwtAuthenticator(fallbackSub, fallbackAud)
-        );
+        final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> buildJwtAuthenticator());
 
         assertThat(
             e.getMessage(),
@@ -53,6 +39,6 @@ public class JwtAuthenticatorAccessTokenTypeTests extends JwtAuthenticatorTests
     }
 
     public void testInvalidIssuerIsCheckedBeforeAlgorithm() throws ParseException {
-        doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator(fallbackSub, fallbackAud));
+        doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator());
     }
 }

+ 2 - 13
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorIdTokenTypeTests.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.xpack.security.authc.jwt;
 
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
-import org.junit.Before;
 
 import java.text.ParseException;
 
@@ -16,27 +15,17 @@ import static org.hamcrest.Matchers.containsString;
 
 public class JwtAuthenticatorIdTokenTypeTests extends JwtAuthenticatorTests {
 
-    private String fallbackSub;
-    private String fallbackAud;
-
-    @Before
-    public void beforeTest() {
-        doBeforeTest();
-        fallbackSub = null;
-        fallbackAud = null;
-    }
-
     @Override
     protected JwtRealmSettings.TokenType getTokenType() {
         return JwtRealmSettings.TokenType.ID_TOKEN;
     }
 
     public void testSubjectIsRequired() throws ParseException {
-        final IllegalArgumentException e = doTestSubjectIsRequired(buildJwtAuthenticator(fallbackSub, fallbackAud));
+        final IllegalArgumentException e = doTestSubjectIsRequired(buildJwtAuthenticator());
         assertThat(e.getMessage(), containsString("missing required string claim [sub]"));
     }
 
     public void testInvalidIssuerIsCheckedBeforeAlgorithm() throws ParseException {
-        doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator(fallbackSub, fallbackAud));
+        doTestInvalidIssuerIsCheckedBeforeAlgorithm(buildJwtAuthenticator());
     }
 }

+ 165 - 4
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java

@@ -12,23 +12,36 @@ import com.nimbusds.jose.util.Base64URL;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.SignedJWT;
 
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.MockSecureSettings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
+import org.junit.Before;
 
 import java.text.ParseException;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
+import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
 import static org.elasticsearch.test.TestMatchers.throwableWithMessage;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.when;
 
 public abstract class JwtAuthenticatorTests extends ESTestCase {
@@ -39,19 +52,149 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
     @Nullable
     protected String allowedSubject;
     protected String allowedAudience;
+    protected String fallbackSub;
+    protected String fallbackAud;
+    protected Tuple<String, List<String>> requiredClaim;
 
     protected abstract JwtRealmSettings.TokenType getTokenType();
 
-    protected void doBeforeTest() {
+    @Before
+    public void beforeTest() {
         realmName = randomAlphaOfLengthBetween(3, 8);
         allowedIssuer = randomAlphaOfLength(6);
         allowedAlgorithm = randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS_HMAC);
         if (getTokenType() == JwtRealmSettings.TokenType.ID_TOKEN) {
             allowedSubject = randomBoolean() ? randomAlphaOfLength(8) : null;
+            fallbackSub = null;
+            fallbackAud = null;
         } else {
             allowedSubject = randomAlphaOfLength(8);
+            fallbackSub = randomBoolean() ? "_" + randomAlphaOfLength(5) : null;
+            fallbackAud = randomBoolean() ? "_" + randomAlphaOfLength(8) : null;
         }
         allowedAudience = randomAlphaOfLength(10);
+        requiredClaim = Tuple.tuple(randomAlphaOfLength(8), randomList(1, 3, () -> randomAlphaOfLengthBetween(8, 18)));
+    }
+
+    public void testRequiredClaims() throws ParseException {
+        final Instant now = Instant.now();
+        final String requireClaimValue = randomFrom(requiredClaim.v2());
+        final JWTClaimsSet claimsSet = JWTClaimsSet.parse(
+            Map.of(
+                "iss",
+                allowedIssuer,
+                "sub",
+                allowedSubject == null ? randomAlphaOfLengthBetween(10, 18) : allowedSubject,
+                "aud",
+                allowedAudience,
+                requiredClaim.v1(),
+                randomBoolean() ? requireClaimValue : List.of(requireClaimValue, "some-other-value"),
+                "iat",
+                now.minus(1, ChronoUnit.DAYS).getEpochSecond(),
+                "exp",
+                now.plus(1, ChronoUnit.DAYS).getEpochSecond()
+            )
+        );
+        final SignedJWT signedJWT = new SignedJWT(
+            JWSHeader.parse(Map.of("alg", allowedAlgorithm)).toBase64URL(),
+            claimsSet.toPayload().toBase64URL(),
+            Base64URL.encode("signature")
+        );
+
+        final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
+        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+
+        final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
+        final JwtAuthenticator jwtAuthenticator = buildJwtAuthenticator();
+        jwtAuthenticator.authenticate(jwtAuthenticationToken, future);
+        assertThat(future.actionGet(), equalTo(claimsSet));
+    }
+
+    public void testMismatchedRequiredClaims() throws ParseException {
+        final Instant now = Instant.now();
+        final String mismatchRequiredClaimValue = randomValueOtherThanMany(
+            requiredClaim.v2()::contains,
+            () -> randomAlphaOfLengthBetween(3, 18)
+        );
+        final JWTClaimsSet claimsSet = JWTClaimsSet.parse(
+            Map.of(
+                "iss",
+                allowedIssuer,
+                "sub",
+                allowedSubject == null ? randomAlphaOfLengthBetween(10, 18) : allowedSubject,
+                "aud",
+                allowedAudience,
+                requiredClaim.v1(),
+                mismatchRequiredClaimValue,
+                "iat",
+                now.minus(1, ChronoUnit.DAYS).getEpochSecond(),
+                "exp",
+                now.plus(1, ChronoUnit.DAYS).getEpochSecond()
+            )
+        );
+        final SignedJWT signedJWT = new SignedJWT(
+            JWSHeader.parse(Map.of("alg", allowedAlgorithm)).toBase64URL(),
+            claimsSet.toPayload().toBase64URL(),
+            Base64URL.encode("signature")
+        );
+
+        final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
+        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+
+        final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
+        final JwtAuthenticator jwtAuthenticator = buildJwtAuthenticator();
+        jwtAuthenticator.authenticate(jwtAuthenticationToken, future);
+        final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, future::actionGet);
+        assertThat(
+            e.getMessage(),
+            containsString(
+                "string claim ["
+                    + requiredClaim.v1()
+                    + "] has value ["
+                    + mismatchRequiredClaimValue
+                    + "] which does not match allowed claim values ["
+                    + requiredClaim.v2().stream().collect(Collectors.joining(","))
+                    + "]"
+            )
+        );
+    }
+
+    public void testMissingRequiredClaims() throws ParseException {
+        final Instant now = Instant.now();
+        final JWTClaimsSet claimsSet = JWTClaimsSet.parse(
+            Map.of(
+                "iss",
+                allowedIssuer,
+                "sub",
+                allowedSubject == null ? randomAlphaOfLengthBetween(10, 18) : allowedSubject,
+                "aud",
+                allowedAudience,
+                "iat",
+                now.minus(1, ChronoUnit.DAYS).getEpochSecond(),
+                "exp",
+                now.plus(1, ChronoUnit.DAYS).getEpochSecond()
+            )
+        );
+        final SignedJWT signedJWT = new SignedJWT(
+            JWSHeader.parse(Map.of("alg", allowedAlgorithm)).toBase64URL(),
+            claimsSet.toPayload().toBase64URL(),
+            Base64URL.encode("signature")
+        );
+
+        final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
+        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+
+        // Required claim is mandatory when configured
+        final PlainActionFuture<JWTClaimsSet> future1 = new PlainActionFuture<>();
+        buildJwtAuthenticator().authenticate(jwtAuthenticationToken, future1);
+        final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, future1::actionGet);
+        assertThat(e.getMessage(), containsString("missing required string claim [" + requiredClaim.v1() + "]"));
+
+        // Remove required claim from settings, the JWT now works
+        requiredClaim = null;
+        final PlainActionFuture<JWTClaimsSet> future2 = new PlainActionFuture<>();
+        buildJwtAuthenticator().authenticate(jwtAuthenticationToken, future2);
+        assertThat(future2.actionGet(), equalTo(claimsSet));
     }
 
     protected IllegalArgumentException doTestSubjectIsRequired(JwtAuthenticator jwtAuthenticator) throws ParseException {
@@ -92,14 +235,14 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
         );
     }
 
-    protected JwtAuthenticator buildJwtAuthenticator(String fallbackSub, String fallbackAud) {
+    protected JwtAuthenticator buildJwtAuthenticator() {
         final RealmConfig.RealmIdentifier realmIdentifier = new RealmConfig.RealmIdentifier(JwtRealmSettings.TYPE, realmName);
         final MockSecureSettings secureSettings = new MockSecureSettings();
         secureSettings.setString(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.HMAC_KEY), randomAlphaOfLength(40));
         final Settings.Builder builder = Settings.builder()
             .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_SIGNATURE_ALGORITHMS), allowedAlgorithm)
             .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_ISSUER), allowedIssuer)
-            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_AUDIENCES), randomAlphaOfLength(7))
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_AUDIENCES), allowedAudience)
             .put(RealmSettings.getFullSettingKey(realmIdentifier, RealmSettings.ORDER_SETTING), randomIntBetween(0, 99))
             .put("path.home", randomAlphaOfLength(10))
             .setSecureSettings(secureSettings);
@@ -123,6 +266,16 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
             builder.put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.FALLBACK_AUD_CLAIM), fallbackAud);
         }
 
+        if (requiredClaim != null) {
+            final String requiredClaimsKey = RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.REQUIRED_CLAIMS) + requiredClaim
+                .v1();
+            if (requiredClaim.v2().size() == 1 && randomBoolean()) {
+                builder.put(requiredClaimsKey, requiredClaim.v2().get(0));
+            } else {
+                builder.putList(requiredClaimsKey, requiredClaim.v2());
+            }
+        }
+
         final Settings settings = builder.build();
 
         final RealmConfig realmConfig = new RealmConfig(
@@ -132,6 +285,14 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
             new ThreadContext(settings)
         );
 
-        return new JwtAuthenticator(realmConfig, null, () -> {});
+        final JwtAuthenticator jwtAuthenticator = spy(new JwtAuthenticator(realmConfig, null, () -> {}));
+        // Short circuit signature validation to be always successful since this test class does not test it
+        doAnswer(invocation -> {
+            final ActionListener<Void> listener = invocation.getArgument(2);
+            listener.onResponse(null);
+            return null;
+        }).when(jwtAuthenticator).validateSignature(any(), any(), anyActionListener());
+
+        return jwtAuthenticator;
     }
 }

+ 53 - 0
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSettingsTests.java

@@ -23,7 +23,9 @@ import java.util.List;
 import java.util.Locale;
 
 import static org.elasticsearch.common.Strings.capitalize;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.emptyIterable;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
@@ -535,4 +537,55 @@ public class JwtRealmSettingsTests extends JwtTestCase {
         );
     }
 
+    public void testRequiredClaims() {
+        final String realmName = randomAlphaOfLengthBetween(3, 8);
+
+        // Required claims are optional
+        final RealmConfig realmConfig1 = buildRealmConfig(JwtRealmSettings.TYPE, realmName, Settings.EMPTY, randomInt());
+        assertThat(realmConfig1.getSetting(JwtRealmSettings.REQUIRED_CLAIMS).names(), emptyIterable());
+
+        // Multiple required claims with different value types
+        final String prefix = RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.REQUIRED_CLAIMS);
+        final Settings settings = Settings.builder()
+            .put(prefix + "extra_1", "foo")
+            .put(prefix + "extra_2", "hello,world")
+            .put(prefix + "extra_3", 42)
+            .build();
+        final RealmConfig realmConfig2 = buildRealmConfig(JwtRealmSettings.TYPE, realmName, settings, randomInt());
+        final Settings requireClaimsSettings = realmConfig2.getSetting(JwtRealmSettings.REQUIRED_CLAIMS);
+        assertThat(requireClaimsSettings.names(), containsInAnyOrder("extra_1", "extra_2", "extra_3"));
+        assertThat(requireClaimsSettings.getAsList("extra_1"), equalTo(List.of("foo")));
+        assertThat(requireClaimsSettings.getAsList("extra_2"), equalTo(List.of("hello", "world")));
+        assertThat(requireClaimsSettings.getAsList("extra_3"), equalTo(List.of("42")));
+    }
+
+    public void testInvalidRequiredClaims() {
+        final String realmName = randomAlphaOfLengthBetween(3, 8);
+        final String invalidRequiredClaim = randomFrom("iss", "sub", "aud", "exp", "nbf", "iat");
+        final String fullSettingKey = RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.REQUIRED_CLAIMS) + invalidRequiredClaim;
+        final Settings settings = Settings.builder().put(fullSettingKey, randomAlphaOfLength(8)).build();
+
+        final RealmConfig realmConfig = buildRealmConfig(JwtRealmSettings.TYPE, realmName, settings, randomInt());
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
+            () -> realmConfig.getSetting(JwtRealmSettings.REQUIRED_CLAIMS)
+        );
+
+        assertThat(e.getMessage(), containsString("required claim [" + fullSettingKey + "] cannot be one of [iss,sub,aud,exp,nbf,iat]"));
+    }
+
+    public void testRequiredClaimsCannotBeEmpty() {
+        final String realmName = randomAlphaOfLengthBetween(3, 8);
+        final String invalidRequiredClaim = randomAlphaOfLengthBetween(4, 8);
+        final String fullSettingKey = RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.REQUIRED_CLAIMS) + invalidRequiredClaim;
+        final Settings settings = Settings.builder().put(fullSettingKey, "").build();
+
+        final RealmConfig realmConfig = buildRealmConfig(JwtRealmSettings.TYPE, realmName, settings, randomInt());
+        final IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
+            () -> realmConfig.getSetting(JwtRealmSettings.REQUIRED_CLAIMS)
+        );
+
+        assertThat(e.getMessage(), containsString("required claim [" + fullSettingKey + "] cannot be empty"));
+    }
 }