1
0
Эх сурвалжийг харах

JWT realm - Simplify token principal calculation (#92315)

Token principal is an implementation detail because it is useful only
for realm ordering cache and logging. Hence it should *not* be exposed
as a user-configurable setting. This PR removes the setting and related
support classes. The token principal is now computed by the realms which
also has the advantage of working correctly with fallback claims.
Yang Wang 2 жил өмнө
parent
commit
1daf314a5d
18 өөрчлөгдсөн 440 нэмэгдсэн , 510 устгасан
  1. 5 0
      docs/changelog/92315.yaml
  2. 0 38
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmsServiceSettings.java
  3. 0 2
      x-pack/plugin/security/qa/jwt-realm/build.gradle
  4. 12 9
      x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRestIT.java
  5. 224 0
      x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java
  6. 0 2
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
  7. 2 3
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/InternalRealms.java
  8. 42 101
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java
  9. 3 22
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java
  10. 113 45
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java
  11. 0 88
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmsService.java
  12. 0 97
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationTokenTests.java
  13. 10 6
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java
  14. 4 3
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java
  15. 8 3
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateAccessTokenTypeTests.java
  16. 7 21
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java
  17. 5 26
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java
  18. 5 44
      x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java

+ 5 - 0
docs/changelog/92315.yaml

@@ -0,0 +1,5 @@
+pr: 92315
+summary: JWT realm - Simplify token principal calculation
+area: Authentication
+type: enhancement
+issues: []

+ 0 - 38
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtRealmsServiceSettings.java

@@ -1,38 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-package org.elasticsearch.xpack.core.security.authc.jwt;
-
-import org.elasticsearch.common.settings.Setting;
-
-import java.util.Collection;
-import java.util.List;
-import java.util.function.Function;
-
-/**
- * Settings used by JwtRealmsService for common handling of JWT realm instances.
- */
-public class JwtRealmsServiceSettings {
-
-    public static final List<String> DEFAULT_PRINCIPAL_CLAIMS = List.of("sub", "oid", "client_id", "appid", "azp", "email");
-
-    public static final Setting<List<String>> PRINCIPAL_CLAIMS_SETTING = Setting.listSetting(
-        "xpack.security.authc.jwt.principal_claims",
-        DEFAULT_PRINCIPAL_CLAIMS,
-        Function.identity(),
-        Setting.Property.NodeScope
-    );
-
-    /**
-     * Get all settings shared by all JWT Realms.
-     * @return All settings shared by all JWT Realms.
-     */
-    public static Collection<Setting<?>> getSettings() {
-        return List.of(PRINCIPAL_CLAIMS_SETTING);
-    }
-
-    private JwtRealmsServiceSettings() {}
-}

+ 0 - 2
x-pack/plugin/security/qa/jwt-realm/build.gradle

@@ -46,8 +46,6 @@ testClusters.matching { it.name == 'javaRestTest' }.configureEach {
   setting 'xpack.security.http.ssl.certificate_authorities', 'ca.crt'
   setting 'xpack.security.http.ssl.client_authentication', 'optional'
 
-  setting 'xpack.security.authc.jwt.principal_claims', 'sub,oid,client_id,azp,appid,email'
-
   setting 'xpack.security.authc.realms.file.admin_file.order', '0'
 
   // These realm settings are generated by JwtRealmGenerateTests

+ 12 - 9
x-pack/plugin/security/qa/jwt-realm/src/javaRestTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRestIT.java

@@ -350,7 +350,7 @@ public class JwtRestIT extends ESRestTestCase {
             if (randomBoolean()) {
                 data.put("token_use", randomValueOtherThan("access", () -> randomAlphaOfLengthBetween(3, 10)));
             }
-            final JWTClaimsSet claimsSet = buildJwt(data, Instant.now(), false);
+            final JWTClaimsSet claimsSet = buildJwt(data, Instant.now(), false, false);
             final SignedJWT jwt = signHmacJwt(claimsSet, "test-HMAC/secret passphrase-value");
             final TestSecurityClient client = getSecurityClient(jwt, VALID_SHARED_SECRET);
             final ResponseException exception = expectThrows(ResponseException.class, client::authenticate);
@@ -529,15 +529,16 @@ public class JwtRestIT extends ESRestTestCase {
     private JWTClaimsSet buildJwtForRealm2(String principal, Instant issueTime) {
         // The "jwt2" realm, supports 3 audiences (es01/02/03)
         final String audience = "es0" + randomIntBetween(1, 3);
-        final Map<String, Object> data = new HashMap<>(
-            Map.of("iss", "my-issuer", "aud", audience, "email", principal, "token_use", "access")
-        );
-        // scope (fallback audience) is ignored since aud exists
+        final Map<String, Object> data = new HashMap<>(Map.of("iss", "my-issuer", "email", principal, "token_use", "access"));
         if (randomBoolean()) {
+            data.put("aud", audience);
+            // scope (fallback audience) is ignored since aud exists
             data.put("scope", randomAlphaOfLength(20));
+        } else {
+            data.put("scope", audience);
         }
 
-        final JWTClaimsSet claimsSet = buildJwt(data, issueTime, false);
+        final JWTClaimsSet claimsSet = buildJwt(data, issueTime, false, false);
         return claimsSet;
     }
 
@@ -597,16 +598,18 @@ public class JwtRestIT extends ESRestTestCase {
 
     // JWT construction
     private JWTClaimsSet buildJwt(Map<String, Object> claims, Instant issueTime) {
-        return buildJwt(claims, issueTime, true);
+        return buildJwt(claims, issueTime, true, true);
     }
 
-    private JWTClaimsSet buildJwt(Map<String, Object> claims, Instant issueTime, boolean includeSub) {
+    private JWTClaimsSet buildJwt(Map<String, Object> claims, Instant issueTime, boolean includeSub, boolean includeAud) {
         final JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder();
         builder.issuer(randomAlphaOfLengthBetween(4, 24));
         if (includeSub) {
             builder.subject(randomAlphaOfLengthBetween(4, 24));
         }
-        builder.audience(randomList(1, 6, () -> randomAlphaOfLengthBetween(4, 12)));
+        if (includeAud) {
+            builder.audience(randomList(1, 6, () -> randomAlphaOfLengthBetween(4, 12)));
+        }
         if (randomBoolean()) {
             builder.jwtID(UUIDs.randomBase64UUID(random()));
         }

+ 224 - 0
x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java

@@ -0,0 +1,224 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.authc.jwt;
+
+import com.nimbusds.jose.JWSHeader;
+import com.nimbusds.jose.util.Base64URL;
+import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.SignedJWT;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.test.SecuritySettingsSource;
+import org.elasticsearch.test.SecuritySingleNodeTestCase;
+import org.elasticsearch.xpack.security.authc.Realms;
+
+import java.text.ParseException;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+
+public class JwtRealmSingleNodeTests extends SecuritySingleNodeTestCase {
+
+    @Override
+    protected Settings nodeSettings() {
+        final Settings.Builder builder = Settings.builder()
+            .put(super.nodeSettings())
+            // 1st JWT realm
+            .put("xpack.security.authc.realms.jwt.jwt0.order", 10)
+            .put(
+                randomBoolean()
+                    ? Settings.builder().put("xpack.security.authc.realms.jwt.jwt0.token_type", "id_token").build()
+                    : Settings.EMPTY
+            )
+            .put("xpack.security.authc.realms.jwt.jwt0.allowed_issuer", "my-issuer-01")
+            .put("xpack.security.authc.realms.jwt.jwt0.allowed_audiences", "es-01")
+            .put("xpack.security.authc.realms.jwt.jwt0.claims.principal", "sub")
+            .put("xpack.security.authc.realms.jwt.jwt0.claims.groups", "groups")
+            .put("xpack.security.authc.realms.jwt.jwt0.client_authentication.type", "shared_secret")
+            .putList("xpack.security.authc.realms.jwt.jwt0.allowed_signature_algorithms", "HS256", "HS384")
+            // 2nd JWT realm
+            .put("xpack.security.authc.realms.jwt.jwt1.order", 20)
+            .put("xpack.security.authc.realms.jwt.jwt1.token_type", "access_token")
+            .put("xpack.security.authc.realms.jwt.jwt1.allowed_issuer", "my-issuer-02")
+            .put("xpack.security.authc.realms.jwt.jwt1.allowed_subjects", "user-02")
+            .put("xpack.security.authc.realms.jwt.jwt1.allowed_audiences", "es-02")
+            .put("xpack.security.authc.realms.jwt.jwt1.fallback_claims.sub", "client_id")
+            .put("xpack.security.authc.realms.jwt.jwt1.claims.principal", "appid")
+            .put("xpack.security.authc.realms.jwt.jwt1.claims.groups", "groups")
+            .put("xpack.security.authc.realms.jwt.jwt1.client_authentication.type", "shared_secret")
+            .putList("xpack.security.authc.realms.jwt.jwt1.allowed_signature_algorithms", "HS256", "HS384")
+            // 3rd JWT realm
+            .put("xpack.security.authc.realms.jwt.jwt2.order", 30)
+            .put("xpack.security.authc.realms.jwt.jwt2.token_type", "access_token")
+            .put("xpack.security.authc.realms.jwt.jwt2.allowed_issuer", "my-issuer-03")
+            .put("xpack.security.authc.realms.jwt.jwt2.allowed_subjects", "user-03")
+            .put("xpack.security.authc.realms.jwt.jwt2.allowed_audiences", "es-03")
+            .put("xpack.security.authc.realms.jwt.jwt2.fallback_claims.sub", "oid")
+            .put("xpack.security.authc.realms.jwt.jwt2.claims.principal", "email")
+            .put("xpack.security.authc.realms.jwt.jwt2.claims.groups", "groups")
+            .put("xpack.security.authc.realms.jwt.jwt2.client_authentication.type", "shared_secret")
+            .putList("xpack.security.authc.realms.jwt.jwt2.allowed_signature_algorithms", "HS256", "HS384");
+
+        SecuritySettingsSource.addSecureSettings(builder, secureSettings -> {
+            secureSettings.setString("xpack.security.authc.realms.jwt.jwt0.hmac_key", "jwt0_hmac_key");
+            secureSettings.setString("xpack.security.authc.realms.jwt.jwt0.client_authentication.shared_secret", "jwt0_shared_secret");
+            secureSettings.setString("xpack.security.authc.realms.jwt.jwt1.hmac_key", "jwt1_hmac_key");
+            secureSettings.setString("xpack.security.authc.realms.jwt.jwt1.client_authentication.shared_secret", "jwt1_shared_secret");
+            secureSettings.setString("xpack.security.authc.realms.jwt.jwt2.hmac_key", "jwt2_hmac_key");
+            secureSettings.setString("xpack.security.authc.realms.jwt.jwt2.client_authentication.shared_secret", "jwt2_shared_secret");
+        });
+
+        return builder.build();
+    }
+
+    public void testAnyJwtRealmWillExtractTheToken() throws ParseException {
+        final List<JwtRealm> jwtRealms = getJwtRealms();
+        final JwtRealm jwtRealm = randomFrom(jwtRealms);
+
+        final String sharedSecret = randomBoolean() ? randomAlphaOfLengthBetween(10, 20) : null;
+        final String iss = randomAlphaOfLengthBetween(5, 18);
+        final String aud = randomAlphaOfLengthBetween(5, 18);
+        final String sub = randomAlphaOfLengthBetween(5, 18);
+
+        // Realm 1 will extract the token because the JWT has all iss, sub, aud, principal claims.
+        // Their values do not match what realm 1 expects but that does not matter when extracting the token
+        final SignedJWT signedJWT1 = getSignedJWT(Map.of("iss", iss, "aud", aud, "sub", sub));
+        final ThreadContext threadContext1 = prepareThreadContext(signedJWT1, sharedSecret);
+        final var token1 = (JwtAuthenticationToken) jwtRealm.token(threadContext1);
+        final String principal1 = Strings.format("%s/%s/%s/%s", iss, aud, sub, sub);
+        assertJwtToken(token1, principal1, sharedSecret, signedJWT1);
+
+        // Realm 2 for extracting the token from the following JWT
+        // Because it does not have the sub claim but client_id, which is configured as fallback by realm 2
+        final String appId = randomAlphaOfLengthBetween(5, 18);
+        final SignedJWT signedJWT2 = getSignedJWT(Map.of("iss", iss, "aud", aud, "client_id", sub, "appid", appId));
+        final ThreadContext threadContext2 = prepareThreadContext(signedJWT2, sharedSecret);
+        final var token2 = (JwtAuthenticationToken) jwtRealm.token(threadContext2);
+        final String principal2 = Strings.format("%s/%s/%s/%s", iss, aud, sub, appId);
+        assertJwtToken(token2, principal2, sharedSecret, signedJWT2);
+
+        // Realm 3 will extract the token from the following JWT
+        // Because it has the oid claim which is configured as a fallback by realm 3
+        final String email = randomAlphaOfLengthBetween(5, 18) + "@example.com";
+        final SignedJWT signedJWT3 = getSignedJWT(Map.of("iss", iss, "aud", aud, "oid", sub, "email", email));
+        final ThreadContext threadContext3 = prepareThreadContext(signedJWT3, sharedSecret);
+        final var token3 = (JwtAuthenticationToken) jwtRealm.token(threadContext3);
+        final String principal3 = Strings.format("%s/%s/%s/%s", iss, aud, sub, email);
+        assertJwtToken(token3, principal3, sharedSecret, signedJWT3);
+
+        // The JWT does not match any realm's configuration, a token with generic token principal will be extracted
+        final SignedJWT signedJWT4 = getSignedJWT(Map.of("iss", iss, "aud", aud, "azp", sub, "email", email));
+        final ThreadContext threadContext4 = prepareThreadContext(signedJWT4, sharedSecret);
+        final var token4 = (JwtAuthenticationToken) jwtRealm.token(threadContext4);
+        final String principal4 = Strings.format("<unrecognized-jwt> by %s", iss);
+        assertJwtToken(token4, principal4, sharedSecret, signedJWT4);
+
+        // The JWT does not have an issuer, a token with generic token principal will be extracted
+        final SignedJWT signedJWT5 = getSignedJWT(Map.of("aud", aud, "sub", sub));
+        final ThreadContext threadContext5 = prepareThreadContext(signedJWT5, sharedSecret);
+        final var token5 = (JwtAuthenticationToken) jwtRealm.token(threadContext5);
+        final String principal5 = "<unrecognized-jwt>";
+        assertJwtToken(token5, principal5, sharedSecret, signedJWT5);
+    }
+
+    public void testJwtRealmReturnsNullTokenWhenJwtCredentialIsAbsent() {
+        final List<JwtRealm> jwtRealms = getJwtRealms();
+        final JwtRealm jwtRealm = randomFrom(jwtRealms);
+        final String sharedSecret = randomBoolean() ? randomAlphaOfLengthBetween(10, 20) : null;
+
+        // Authorization header is absent
+        final ThreadContext threadContext1 = prepareThreadContext(null, sharedSecret);
+        assertThat(jwtRealm.token(threadContext1), nullValue());
+
+        // Scheme is not Bearer
+        final ThreadContext threadContext2 = prepareThreadContext(null, sharedSecret);
+        threadContext2.putHeader("Authorization", "Basic foobar");
+        assertThat(jwtRealm.token(threadContext2), nullValue());
+    }
+
+    public void testJwtRealmThrowsErrorOnJwtParsingFailure() throws ParseException {
+        final List<JwtRealm> jwtRealms = getJwtRealms();
+        final JwtRealm jwtRealm = randomFrom(jwtRealms);
+        final String sharedSecret = randomBoolean() ? randomAlphaOfLengthBetween(10, 20) : null;
+
+        // Not a JWT
+        final ThreadContext threadContext1 = prepareThreadContext(null, sharedSecret);
+        threadContext1.putHeader("Authorization", "Bearer " + randomAlphaOfLengthBetween(40, 60));
+        final IllegalArgumentException e1 = expectThrows(IllegalArgumentException.class, () -> jwtRealm.token(threadContext1));
+        assertThat(e1.getMessage(), containsString("Failed to parse JWT bearer token"));
+
+        // Payload is not JSON
+        final SignedJWT signedJWT2 = new SignedJWT(
+            JWSHeader.parse(Map.of("alg", randomAlphaOfLengthBetween(5, 10))).toBase64URL(),
+            Base64URL.encode("payload"),
+            Base64URL.encode("signature")
+        );
+        final ThreadContext threadContext2 = prepareThreadContext(null, sharedSecret);
+        threadContext2.putHeader("Authorization", "Bearer " + signedJWT2.serialize());
+        final IllegalArgumentException e2 = expectThrows(IllegalArgumentException.class, () -> jwtRealm.token(threadContext2));
+        assertThat(e2.getMessage(), containsString("Failed to parse JWT claims set"));
+    }
+
+    private void assertJwtToken(JwtAuthenticationToken token, String tokenPrincipal, String sharedSecret, SignedJWT signedJWT)
+        throws ParseException {
+        assertThat(token.principal(), equalTo(tokenPrincipal));
+        assertThat(token.getClientAuthenticationSharedSecret(), equalTo(sharedSecret));
+        assertThat(token.getJWTClaimsSet(), equalTo(signedJWT.getJWTClaimsSet()));
+        assertThat(token.getSignedJWT().getHeader().toJSONObject(), equalTo(signedJWT.getHeader().toJSONObject()));
+        assertThat(token.getSignedJWT().getSignature(), equalTo(signedJWT.getSignature()));
+        assertThat(token.getSignedJWT().getJWTClaimsSet(), equalTo(token.getJWTClaimsSet()));
+    }
+
+    private List<JwtRealm> getJwtRealms() {
+        final Realms realms = getInstanceFromNode(Realms.class);
+        final List<JwtRealm> jwtRealms = realms.getActiveRealms()
+            .stream()
+            .filter(realm -> realm instanceof JwtRealm)
+            .map(JwtRealm.class::cast)
+            .toList();
+        return jwtRealms;
+    }
+
+    private SignedJWT getSignedJWT(Map<String, Object> m) throws ParseException {
+        final HashMap<String, Object> claimsMap = new HashMap<>(m);
+        final Instant now = Instant.now();
+        // timestamp does not matter for tokenExtraction
+        claimsMap.put("iat", now.minus(randomIntBetween(-1, 1), ChronoUnit.DAYS).getEpochSecond());
+        claimsMap.put("exp", now.plus(randomIntBetween(-1, 1), ChronoUnit.DAYS).getEpochSecond());
+
+        final JWTClaimsSet claimsSet = JWTClaimsSet.parse(claimsMap);
+        final SignedJWT signedJWT = new SignedJWT(
+            JWSHeader.parse(Map.of("alg", randomAlphaOfLengthBetween(5, 10))).toBase64URL(),
+            claimsSet.toPayload().toBase64URL(),
+            Base64URL.encode("signature")
+        );
+        return signedJWT;
+    }
+
+    private ThreadContext prepareThreadContext(SignedJWT signedJWT, String clientSecret) {
+        final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
+        if (signedJWT != null) {
+            threadContext.putHeader("Authorization", "Bearer " + signedJWT.serialize());
+        }
+        if (clientSecret != null) {
+            threadContext.putHeader(
+                JwtRealm.HEADER_CLIENT_AUTHENTICATION,
+                JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME + " " + clientSecret
+            );
+        }
+        return threadContext;
+    }
+}

+ 0 - 2
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java

@@ -160,7 +160,6 @@ import org.elasticsearch.xpack.core.security.authc.InternalRealmsSettings;
 import org.elasticsearch.xpack.core.security.authc.Realm;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
-import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmsServiceSettings;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
 import org.elasticsearch.xpack.core.security.authz.AuthorizationEngine;
 import org.elasticsearch.xpack.core.security.authz.AuthorizationServiceField;
@@ -1094,7 +1093,6 @@ public class Security extends Plugin
         // authentication and authorization settings
         AnonymousUser.addSettings(settingsList);
         settingsList.addAll(InternalRealmsSettings.getSettings());
-        settingsList.addAll(JwtRealmsServiceSettings.getSettings());
         ReservedRealm.addSettings(settingsList);
         AuthenticationService.addSettings(settingsList);
         AuthorizationService.addSettings(settingsList);

+ 2 - 3
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/InternalRealms.java

@@ -32,7 +32,7 @@ import org.elasticsearch.xpack.security.authc.esnative.NativeRealm;
 import org.elasticsearch.xpack.security.authc.esnative.NativeUsersStore;
 import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
 import org.elasticsearch.xpack.security.authc.file.FileRealm;
-import org.elasticsearch.xpack.security.authc.jwt.JwtRealmsService;
+import org.elasticsearch.xpack.security.authc.jwt.JwtRealm;
 import org.elasticsearch.xpack.security.authc.kerberos.KerberosRealm;
 import org.elasticsearch.xpack.security.authc.ldap.LdapRealm;
 import org.elasticsearch.xpack.security.authc.oidc.OpenIdConnectRealm;
@@ -137,7 +137,6 @@ public final class InternalRealms {
         NativeRoleMappingStore nativeRoleMappingStore,
         SecurityIndexManager securityIndex
     ) {
-        final JwtRealmsService jwtRealmsService = new JwtRealmsService(settings); // parse shared settings needed by all JwtRealm instances
         return Map.of(
             // file realm
             FileRealmSettings.TYPE,
@@ -169,7 +168,7 @@ public final class InternalRealms {
             config -> new OpenIdConnectRealm(config, sslService, nativeRoleMappingStore, resourceWatcherService),
             // JWT realm
             JwtRealmSettings.TYPE,
-            config -> jwtRealmsService.createJwtRealm(config, sslService, nativeRoleMappingStore)
+            config -> new JwtRealm(config, sslService, nativeRoleMappingStore)
         );
     }
 

+ 42 - 101
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java

@@ -9,121 +9,51 @@ package org.elasticsearch.xpack.security.authc.jwt;
 import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.SignedJWT;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
-import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmsServiceSettings;
 
 import java.text.ParseException;
-import java.util.List;
-import java.util.Map;
-import java.util.TreeSet;
-import java.util.stream.Collectors;
+import java.util.Arrays;
+import java.util.Objects;
 
 /**
  * An {@link AuthenticationToken} to hold JWT authentication related content.
  */
 public class JwtAuthenticationToken implements AuthenticationToken {
-    private static final Logger LOGGER = LogManager.getLogger(JwtAuthenticationToken.class);
-
-    // Stored members
-    protected SecureString endUserSignedJwt; // required
-    protected SecureString clientAuthenticationSharedSecret; // optional, nullable
-    protected String principal; // Defaults to "iss/aud/sub", with an ordered "aud" list
+    private final String principal;
+    private SignedJWT signedJWT;
+    private final byte[] userCredentialsHash;
+    @Nullable
+    private final SecureString clientAuthenticationSharedSecret;
 
     /**
-     * Store a mandatory JWT and optional Shared Secret. Parse the JWT, and extract the header, claims set, and signature.
-     * Compute a token principal, for use as a realm order cache key. For OIDC ID Tokens, cache key is iss/aud/sub.
-     * For other JWTs, {@link JwtRealmsServiceSettings#PRINCIPAL_CLAIMS_SETTING} supports alternative claims for sub.
-     * Throws IllegalArgumentException if principalClaimNames is empty, JWT is missing, or if JWT parsing fails.
-     * @param principalClaimNames Ordered list of string claims to use for principalClaimValue. The first one found is used (ex: sub).
-     * @param endUserSignedJwt Base64Url-encoded JWT for End-user authentication. Required by all JWT realms.
+     * Store a mandatory JWT and optional Shared Secret.
+     * @param principal The token's principal, useful as a realm order cache key
+     * @param signedJWT The JWT parsed from the end-user credentials
+     * @param userCredentialsHash The hash of the end-user credentials is used to compute the key for user cache at the realm level.
+     *                            See also {@link JwtRealm#authenticate}.
      * @param clientAuthenticationSharedSecret URL-safe Shared Secret for Client authentication. Required by some JWT realms.
      */
     public JwtAuthenticationToken(
-        final List<String> principalClaimNames,
-        final SecureString endUserSignedJwt,
+        String principal,
+        SignedJWT signedJWT,
+        byte[] userCredentialsHash,
         @Nullable final SecureString clientAuthenticationSharedSecret
     ) {
-        if (principalClaimNames.isEmpty()) {
-            throw new IllegalArgumentException("JWT token principal claim names list must be non-empty");
-        } else if (endUserSignedJwt.isEmpty()) {
-            throw new IllegalArgumentException("JWT bearer token must be non-empty");
-        } else if ((clientAuthenticationSharedSecret != null) && (clientAuthenticationSharedSecret.isEmpty())) {
-            throw new IllegalArgumentException("Client shared secret must be non-empty");
-        }
-        this.endUserSignedJwt = endUserSignedJwt; // required
-        this.clientAuthenticationSharedSecret = clientAuthenticationSharedSecret; // optional, nullable
-
-        JWTClaimsSet jwtClaimsSet;
-        try {
-            jwtClaimsSet = SignedJWT.parse(this.endUserSignedJwt.toString()).getJWTClaimsSet();
-        } catch (ParseException e) {
-            throw new IllegalArgumentException("Failed to parse JWT bearer token", e);
-        }
-
-        // get and validate iss and aud claims
-        final String issuer = jwtClaimsSet.getIssuer();
-        final List<String> audiences = jwtClaimsSet.getAudience();
-        if (Strings.hasText(issuer) == false) {
-            throw new IllegalArgumentException("Issuer claim 'iss' is missing.");
-        } else if ((audiences == null) || (audiences.isEmpty())) {
-            throw new IllegalArgumentException("Audiences claim 'aud' is missing.");
-        }
-
-        // get and validate sub claim, or the first configured backup claim (if sub is absent)
-        final String principalClaimValue = this.resolvePrincipalClaimName(jwtClaimsSet, principalClaimNames);
-        this.principal = issuer + "/" + String.join(",", new TreeSet<>(audiences)) + "/" + principalClaimValue;
-    }
+        this.principal = Objects.requireNonNull(principal);
+        this.signedJWT = Objects.requireNonNull(signedJWT);
+        this.userCredentialsHash = Objects.requireNonNull(userCredentialsHash);
 
-    private String resolvePrincipalClaimName(final JWTClaimsSet jwtClaimsSet, final List<String> principalClaimNames) {
-        for (final String principalClaimName : principalClaimNames) {
-            final Object claimValue = jwtClaimsSet.getClaim(principalClaimName);
-            if (claimValue instanceof String principalClaimValue) {
-                // found an allowed string claim name
-                if (principalClaimValue.isEmpty()) {
-                    throw new IllegalArgumentException(
-                        "Allowed principal claim name '"
-                            + principalClaimName
-                            + "' exists but cannot be used because the value of that claim is an empty string"
-                    );
-                }
-                LOGGER.trace("Found allowed principal claim name [{}] with value [{}]", principalClaimName, principalClaimValue);
-                return principalClaimValue;
-            } else if (claimValue != null) {
-                throw new IllegalArgumentException(
-                    "Allowed principal claim name '"
-                        + principalClaimName
-                        + "' exists but cannot be used because the value of that claim must be a string, but instead it was a ["
-                        + claimValue.getClass().getSimpleName()
-                        + "]"
-                );
-            }
+        if ((clientAuthenticationSharedSecret != null) && (clientAuthenticationSharedSecret.isEmpty())) {
+            throw new IllegalArgumentException("Client shared secret must be non-empty");
         }
-
-        // at this point, none of the principalClaimNames were found
-        // throw an exception with a detailed log message about which string claims were available in the JWT
-        final String allClaimNamesWithStringValues = jwtClaimsSet.getClaims()
-            .entrySet()
-            .stream()
-            .filter(e -> e.getValue() instanceof String)
-            .map(Map.Entry::getKey)
-            .collect(Collectors.joining(","));
-        throw new IllegalArgumentException(
-            "None of these configured principal claim names were found in the JWT Claims Set ["
-                + String.join(",", principalClaimNames)
-                + "] - available claims in the JWT with potential compatible string values are ["
-                + allClaimNamesWithStringValues
-                + "]"
-        );
+        this.clientAuthenticationSharedSecret = clientAuthenticationSharedSecret;
     }
 
     @Override
     public String principal() {
-        return this.principal;
+        return principal;
     }
 
     @Override
@@ -131,23 +61,34 @@ public class JwtAuthenticationToken implements AuthenticationToken {
         return null;
     }
 
-    public SecureString getEndUserSignedJwt() {
-        return this.endUserSignedJwt;
+    public SignedJWT getSignedJWT() {
+        return signedJWT;
+    }
+
+    public JWTClaimsSet getJWTClaimsSet() {
+        try {
+            return signedJWT.getJWTClaimsSet();
+        } catch (ParseException e) {
+            assert false : "The JWT claims set should have already been successfully parsed before building the JWT authentication token";
+            throw new IllegalArgumentException(e);
+        }
+    }
+
+    public byte[] getUserCredentialsHash() {
+        return userCredentialsHash;
     }
 
     public SecureString getClientAuthenticationSharedSecret() {
-        return this.clientAuthenticationSharedSecret;
+        return clientAuthenticationSharedSecret;
     }
 
     @Override
     public void clearCredentials() {
-        this.endUserSignedJwt.close();
-        this.endUserSignedJwt = null;
-        if (this.clientAuthenticationSharedSecret != null) {
-            this.clientAuthenticationSharedSecret.close();
-            this.clientAuthenticationSharedSecret = null;
+        signedJWT = null;
+        Arrays.fill(userCredentialsHash, (byte) 0);
+        if (clientAuthenticationSharedSecret != null) {
+            clientAuthenticationSharedSecret.close();
         }
-        this.principal = null;
     }
 
     @Override

+ 3 - 22
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java

@@ -14,7 +14,6 @@ import com.nimbusds.jwt.SignedJWT;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
@@ -22,7 +21,6 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
 import org.elasticsearch.xpack.core.ssl.SSLService;
 
-import java.text.ParseException;
 import java.time.Clock;
 import java.util.ArrayList;
 import java.util.List;
@@ -67,28 +65,11 @@ public class JwtAuthenticator implements Releasable {
 
     public void authenticate(JwtAuthenticationToken jwtAuthenticationToken, ActionListener<JWTClaimsSet> listener) {
         final String tokenPrincipal = jwtAuthenticationToken.principal();
-
         // JWT cache
-        final SecureString serializedJwt = jwtAuthenticationToken.getEndUserSignedJwt();
-        final SignedJWT signedJWT;
-        try {
-            signedJWT = SignedJWT.parse(serializedJwt.toString());
-        } catch (ParseException e) {
-            // TODO: No point to continue to another realm since parsing failed
-            listener.onFailure(e);
-            return;
-        }
-
-        final JWTClaimsSet jwtClaimsSet;
-        try {
-            jwtClaimsSet = signedJWT.getJWTClaimsSet();
-        } catch (ParseException e) {
-            // TODO: No point to continue to another realm since get claimset failed
-            listener.onFailure(e);
-            return;
-        }
-
+        final SignedJWT signedJWT = jwtAuthenticationToken.getSignedJWT();
+        final JWTClaimsSet jwtClaimsSet = jwtAuthenticationToken.getJWTClaimsSet();
         final JWSHeader jwsHeader = signedJWT.getHeader();
+
         if (logger.isDebugEnabled()) {
             logger.debug(
                 "Realm [{}] successfully parsed JWT token [{}] with header [{}] and claimSet [{}]",

+ 113 - 45
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.security.authc.jwt;
 
 import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.SignedJWT;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
@@ -34,6 +35,8 @@ import org.elasticsearch.xpack.core.ssl.SSLService;
 import org.elasticsearch.xpack.security.authc.support.ClaimParser;
 import org.elasticsearch.xpack.security.authc.support.DelegatedAuthorizationSupport;
 
+import java.text.ParseException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Date;
@@ -41,6 +44,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.TreeSet;
+import java.util.function.Function;
 
 import static java.lang.String.join;
 import static org.elasticsearch.core.Strings.format;
@@ -58,7 +63,6 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
 
     private final Cache<BytesArray, ExpiringUser> jwtCache;
     private final CacheIteratorHelper<BytesArray, ExpiringUser> jwtCacheHelper;
-    private final JwtRealmsService jwtRealmsService;
     private final UserRoleMapper userRoleMapper;
     private final Boolean populateUserMetadata;
     private final ClaimParser claimParserPrincipal;
@@ -71,15 +75,11 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     private final JwtAuthenticator jwtAuthenticator;
     private final TimeValue allowedClockSkew;
     DelegatedAuthorizationSupport delegatedAuthorizationSupport = null;
+    private List<Function<JWTClaimsSet, String>> tokenPrincipalFunctions;
 
-    JwtRealm(
-        final RealmConfig realmConfig,
-        final JwtRealmsService jwtRealmsService,
-        final SSLService sslService,
-        final UserRoleMapper userRoleMapper
-    ) throws SettingsException {
+    public JwtRealm(final RealmConfig realmConfig, final SSLService sslService, final UserRoleMapper userRoleMapper)
+        throws SettingsException {
         super(realmConfig);
-        this.jwtRealmsService = jwtRealmsService; // common configuration settings shared by all JwtRealm instances
         this.userRoleMapper = userRoleMapper;
         this.userRoleMapper.refreshRealmOnChange(this);
         this.allowedClockSkew = realmConfig.getSetting(JwtRealmSettings.ALLOWED_CLOCK_SKEW);
@@ -140,6 +140,14 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
         }
         // extract list of realms referenced by config.settings() value for DelegatedAuthorizationSettings.ROLES_REALMS
         delegatedAuthorizationSupport = new DelegatedAuthorizationSupport(allRealms, config, xpackLicenseState);
+
+        final List<Function<JWTClaimsSet, String>> tokenPrincipalFunctions = new ArrayList<>();
+        for (var realm : allRealms) {
+            if (realm instanceof final JwtRealm jwtRealm) {
+                tokenPrincipalFunctions.add(jwtRealm::buildTokenPrincipal);
+            }
+        }
+        this.tokenPrincipalFunctions = List.copyOf(tokenPrincipalFunctions);
     }
 
     /**
@@ -176,10 +184,66 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
     @Override
     public AuthenticationToken token(final ThreadContext threadContext) {
         ensureInitialized();
-        // Token parsing is common code for all realms
-        // First JWT realm will parse in a way that is compatible with all JWT realms,
-        // taking into consideration each JWT realm might have a different principal claim name
-        return jwtRealmsService.token(threadContext);
+
+        final SecureString userCredentials = JwtUtil.getHeaderValue(
+            threadContext,
+            JwtRealm.HEADER_END_USER_AUTHENTICATION,
+            JwtRealm.HEADER_END_USER_AUTHENTICATION_SCHEME,
+            false
+        );
+        if (userCredentials == null) {
+            return null;
+        }
+        if (userCredentials.isEmpty()) {
+            throw new IllegalArgumentException("JWT bearer token must be non-empty");
+        }
+
+        final SecureString clientCredentials = JwtUtil.getHeaderValue(
+            threadContext,
+            JwtRealm.HEADER_CLIENT_AUTHENTICATION,
+            JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME,
+            true
+        );
+
+        // No point to fall through the realm chain if JWT parsing fails, so we throw error here on failure.
+        final SignedJWT signedJWT;
+        try {
+            signedJWT = SignedJWT.parse(userCredentials.toString());
+        } catch (ParseException e) {
+            throw new IllegalArgumentException("Failed to parse JWT bearer token", e);
+        }
+
+        final JWTClaimsSet jwtClaimsSet;
+        try {
+            jwtClaimsSet = signedJWT.getJWTClaimsSet();
+        } catch (ParseException e) {
+            throw new IllegalArgumentException("Failed to parse JWT claims set", e);
+        }
+
+        // If Issuer is not found, still return a JWT token since it is after still a JWT, authentication
+        // will fail later because issuer is mandated
+        final String issuer = jwtClaimsSet.getIssuer();
+        if (Strings.hasText(issuer) == false) {
+            logger.warn("Issuer claim 'iss' is missing.");
+            return new JwtAuthenticationToken("<unrecognized-jwt>", signedJWT, JwtUtil.sha256(userCredentials), clientCredentials);
+        }
+
+        // Try all known extraction functions to build the token principal
+        for (Function<JWTClaimsSet, String> func : tokenPrincipalFunctions) {
+            final String tokenPrincipalSuffix = func.apply(jwtClaimsSet);
+            if (tokenPrincipalSuffix != null) {
+                return new JwtAuthenticationToken(
+                    issuer + "/" + tokenPrincipalSuffix,
+                    signedJWT,
+                    JwtUtil.sha256(userCredentials),
+                    clientCredentials
+                );
+            }
+        }
+
+        // Token principal cannot be extracted even after trying all functions, but this is
+        // still a JWT token so that we should return as one.
+        return new JwtAuthenticationToken("<unrecognized-jwt> by " + issuer, signedJWT, JwtUtil.sha256(userCredentials), clientCredentials);
     }
 
     @Override
@@ -205,8 +269,7 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
                 return; // FAILED (secret is missing or mismatched)
             }
 
-            final SecureString serializedJwt = jwtAuthenticationToken.getEndUserSignedJwt();
-            final BytesArray jwtCacheKey = isCacheEnabled() ? new BytesArray(JwtUtil.sha256(serializedJwt)) : null;
+            final BytesArray jwtCacheKey = isCacheEnabled() ? new BytesArray(jwtAuthenticationToken.getUserCredentialsHash()) : null;
             if (jwtCacheKey != null) {
                 final User cachedUser = tryAuthenticateWithCache(tokenPrincipal, jwtCacheKey);
                 if (cachedUser != null) {
@@ -264,40 +327,21 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
             final Date exp = expiringUser.exp; // claimsSet.getExpirationTime().getTime() + allowedClockSkew.getMillis()
             final String principal = user.principal();
             final Date now = new Date();
-            if (now.getTime() < exp.getTime()) {
-                logger.trace(
-                    "Realm ["
-                        + name()
-                        + "] JWT cache hit token=["
-                        + tokenPrincipal
-                        + "] key=["
-                        + jwtCacheKey
-                        + "] principal=["
-                        + principal
-                        + "] exp=["
-                        + exp
-                        + "] now=["
-                        + now
-                        + "]."
-                );
+            final boolean cacheEntryNotExpired = now.getTime() < exp.getTime();
+            logger.trace(
+                "Realm [{}] JWT cache {} token=[{}] key=[{}] principal=[{}] exp=[{}] now=[{}].",
+                name(),
+                cacheEntryNotExpired ? "hit" : "exp",
+                tokenPrincipal,
+                jwtCacheKey,
+                principal,
+                exp,
+                now
+            );
+            if (cacheEntryNotExpired) {
                 return user;
             }
             // TODO: evict the entry
-            logger.trace(
-                "Realm ["
-                    + name()
-                    + "] JWT cache exp token=["
-                    + tokenPrincipal
-                    + "] key=["
-                    + jwtCacheKey
-                    + "] principal=["
-                    + principal
-                    + "] exp=["
-                    + exp
-                    + "] now=["
-                    + now
-                    + "]."
-            );
         }
         return null;
     }
@@ -410,6 +454,30 @@ public class JwtRealm extends Realm implements CachingRealm, Releasable {
         return Map.copyOf(metadata);
     }
 
+    private String buildTokenPrincipal(JWTClaimsSet jwtClaimsSet) {
+        final Map<String, String> fallbackClaimNames = jwtAuthenticator.getFallbackClaimNames();
+        final FallbackableClaim subClaim = new FallbackableClaim("sub", fallbackClaimNames, jwtClaimsSet);
+        final String subject = subClaim.getStringClaimValue();
+        if (false == Strings.hasText(subject)) {
+            logger.debug("claim [{}] is missing for building token principal for realm [{}]", subClaim, name());
+            return null;
+        }
+
+        final FallbackableClaim audClaim = new FallbackableClaim("aud", fallbackClaimNames, jwtClaimsSet);
+        final List<String> audiences = audClaim.getStringListClaimValue();
+        if (audiences == null || audiences.isEmpty()) {
+            logger.debug("claim [{}] is missing for building token principal for realm [{}]", audClaim, name());
+            return null;
+        }
+
+        final String userPrincipal = claimParserPrincipal.getClaimValue(jwtClaimsSet);
+        if (false == Strings.hasText(userPrincipal)) {
+            logger.debug("No user principal can be extracted with [{}] for realm [{}]", claimParserPrincipal, name());
+            return null;
+        }
+        return String.join(",", new TreeSet<>(audiences)) + "/" + subject + "/" + userPrincipal;
+    }
+
     /**
      * JWTClaimsSet values are only allowed to be String, Boolean, Number, or Collection.
      * Collections are only allowed to contain String, Boolean, or Number.

+ 0 - 88
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmsService.java

@@ -1,88 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-package org.elasticsearch.xpack.security.authc.jwt;
-
-import org.elasticsearch.common.settings.SecureString;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
-import org.elasticsearch.xpack.core.security.authc.RealmConfig;
-import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmsServiceSettings;
-import org.elasticsearch.xpack.core.ssl.SSLService;
-import org.elasticsearch.xpack.security.authc.InternalRealms;
-import org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Parse common settings shared by all JwtRealm instances on behalf of InternalRealms.
- * Construct JwtRealm instances on behalf of lambda defined in InternalRealms.
- * Construct AuthenticationToken instances on behalf of JwtRealm instances.
- * @see InternalRealms
- * @see JwtRealm
- */
-public class JwtRealmsService {
-
-    private final List<String> principalClaimNames;
-
-    /**
-     * Parse all xpack settings passed in from {@link InternalRealms#getFactories}
-     * @param settings All xpack settings
-     */
-    public JwtRealmsService(final Settings settings) {
-        this.principalClaimNames = Collections.unmodifiableList(JwtRealmsServiceSettings.PRINCIPAL_CLAIMS_SETTING.get(settings));
-    }
-
-    /**
-     * Return prioritized list of principal claim names to use for computing realm cache keys for all JWT realms.
-     * @return Prioritized list of principal claim names (ex: sub, oid, client_id, azp, appid, client_id, email).
-     */
-    public List<String> getPrincipalClaimNames() {
-        return this.principalClaimNames;
-    }
-
-    /**
-     * Construct JwtRealm instance using settings passed in via lambda defined in {@link InternalRealms#getFactories}
-     * @param config Realm config
-     * @param sslService SSL service settings
-     * @param nativeRoleMappingStore Native role mapping store
-     */
-    public JwtRealm createJwtRealm(
-        final RealmConfig config,
-        final SSLService sslService,
-        final NativeRoleMappingStore nativeRoleMappingStore
-    ) {
-        return new JwtRealm(config, this, sslService, nativeRoleMappingStore);
-    }
-
-    /**
-     * Construct JwtAuthenticationToken instance using request passed in via JwtRealm.token.
-     * @param threadContext Request headers and parameters
-     * @return JwtAuthenticationToken contains mandatory JWT header, optional client secret, and a realm order cache key
-     */
-    AuthenticationToken token(final ThreadContext threadContext) {
-        // extract value from Authorization header with Bearer scheme prefix
-        final SecureString authenticationParameterValue = JwtUtil.getHeaderValue(
-            threadContext,
-            JwtRealm.HEADER_END_USER_AUTHENTICATION,
-            JwtRealm.HEADER_END_USER_AUTHENTICATION_SCHEME,
-            false
-        );
-        if (authenticationParameterValue == null) {
-            return null;
-        }
-        // extract value from ES-Client-Authentication header with SharedSecret scheme prefix
-        final SecureString clientAuthenticationSharedSecretValue = JwtUtil.getHeaderValue(
-            threadContext,
-            JwtRealm.HEADER_CLIENT_AUTHENTICATION,
-            JwtRealm.HEADER_SHARED_SECRET_AUTHENTICATION_SCHEME,
-            true
-        );
-        return new JwtAuthenticationToken(this.principalClaimNames, authenticationParameterValue, clientAuthenticationSharedSecretValue);
-    }
-}

+ 0 - 97
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationTokenTests.java

@@ -1,97 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-package org.elasticsearch.xpack.security.authc.jwt;
-
-import com.nimbusds.jose.JOSEObjectType;
-import com.nimbusds.jose.jwk.JWK;
-import com.nimbusds.jwt.SignedJWT;
-
-import org.elasticsearch.common.settings.SecureString;
-import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
-import org.junit.Assert;
-
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Date;
-import java.util.List;
-import java.util.Map;
-
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.nullValue;
-
-public class JwtAuthenticationTokenTests extends JwtTestCase {
-
-    public void testJwtAuthenticationTokenParse() throws Exception {
-        final String signatureAlgorithm = randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS);
-        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm, randomBoolean());
-
-        final SecureString jwt = JwtTestCase.randomBespokeJwt(jwk, signatureAlgorithm); // bespoke JWT, not tied to any JWT realm
-        final SecureString clientSharedSecret = randomBoolean() ? null : new SecureString(randomAlphaOfLengthBetween(10, 20).toCharArray());
-
-        final List<String> principalClaimNames = List.of(randomAlphaOfLength(4), "sub", randomAlphaOfLength(4));
-        final JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(principalClaimNames, jwt, clientSharedSecret);
-        final SecureString endUserSignedJwt = jwtAuthenticationToken.getEndUserSignedJwt();
-        final SecureString clientAuthenticationSharedSecret = jwtAuthenticationToken.getClientAuthenticationSharedSecret();
-
-        Assert.assertEquals(jwt, endUserSignedJwt);
-        Assert.assertEquals(clientSharedSecret, clientAuthenticationSharedSecret);
-
-        jwtAuthenticationToken.clearCredentials();
-
-        // verify references to SecureString throw exception when calling their methods
-        final Exception exception1 = expectThrows(IllegalStateException.class, endUserSignedJwt::length);
-        assertThat(exception1.getMessage(), equalTo("SecureString has already been closed"));
-        if (clientAuthenticationSharedSecret != null) {
-            final Exception exception2 = expectThrows(IllegalStateException.class, clientAuthenticationSharedSecret::length);
-            assertThat(exception2.getMessage(), equalTo("SecureString has already been closed"));
-        }
-
-        // verify token returns nulls
-        assertThat(jwtAuthenticationToken.principal(), is(nullValue()));
-        assertThat(jwtAuthenticationToken.credentials(), is(nullValue()));
-        assertThat(jwtAuthenticationToken.getEndUserSignedJwt(), is(nullValue()));
-        assertThat(jwtAuthenticationToken.getClientAuthenticationSharedSecret(), is(nullValue()));
-    }
-
-    public void testPrincipalForJwtWithoutSub() throws Exception {
-        final String issuer = randomAlphaOfLengthBetween(8, 24);
-        final String audience = randomAlphaOfLengthBetween(6, 12);
-
-        final String principalClaimName = randomValueOtherThan("sub", () -> randomAlphaOfLength(3));
-        final String principalClaimValue = randomAlphaOfLengthBetween(8, 32);
-
-        final String signatureAlgorithm = randomFrom(JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS);
-        final JWK jwk = JwtTestCase.randomJwk(signatureAlgorithm, randomBoolean());
-
-        final Instant now = Instant.now().truncatedTo(ChronoUnit.SECONDS);
-        final SignedJWT unsignedJwt = JwtTestCase.buildUnsignedJwt(
-            randomBoolean() ? null : JOSEObjectType.JWT.toString(), // kty
-            randomBoolean() ? null : jwk.getKeyID(), // kid
-            signatureAlgorithm, // alg
-            null, // jwtID
-            issuer, // iss
-            List.of(audience), // aud
-            null, // sub claim value
-            principalClaimName, // principal claim name
-            principalClaimValue, // principal claim value
-            null, // groups claim
-            List.of(), // groups
-            Date.from(now.minusSeconds(60 * randomLongBetween(10, 20))), // auth_time
-            Date.from(now.minusSeconds(randomBoolean() ? 0 : 60 * randomLongBetween(5, 10))), // iat
-            Date.from(now), // nbf
-            Date.from(now.plusSeconds(60 * randomLongBetween(3600, 7200))), // exp
-            null, // nonce
-            Map.of() // other claims
-        );
-        final SecureString jwt = JwtValidateUtil.signJwt(jwk, unsignedJwt);
-
-        final List<String> principalClaimNames = List.of(randomAlphaOfLength(4), "sub", principalClaimName);
-        final JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(principalClaimNames, jwt, null);
-        Assert.assertEquals(issuer + "/" + audience + "/" + principalClaimValue, jwtAuthenticationToken.principal());
-    }
-}

+ 10 - 6
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java

@@ -15,7 +15,6 @@ import com.nimbusds.jwt.SignedJWT;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.MockSecureSettings;
-import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.Nullable;
@@ -102,7 +101,8 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
         );
 
         final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
-        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+        when(jwtAuthenticationToken.getSignedJWT()).thenReturn(signedJWT);
+        when(jwtAuthenticationToken.getJWTClaimsSet()).thenReturn(signedJWT.getJWTClaimsSet());
 
         final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
         final JwtAuthenticator jwtAuthenticator = buildJwtAuthenticator();
@@ -139,7 +139,8 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
         );
 
         final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
-        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+        when(jwtAuthenticationToken.getSignedJWT()).thenReturn(signedJWT);
+        when(jwtAuthenticationToken.getJWTClaimsSet()).thenReturn(signedJWT.getJWTClaimsSet());
 
         final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
         final JwtAuthenticator jwtAuthenticator = buildJwtAuthenticator();
@@ -182,7 +183,8 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
         );
 
         final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
-        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+        when(jwtAuthenticationToken.getSignedJWT()).thenReturn(signedJWT);
+        when(jwtAuthenticationToken.getJWTClaimsSet()).thenReturn(signedJWT.getJWTClaimsSet());
 
         // Required claim is mandatory when configured
         final PlainActionFuture<JWTClaimsSet> future1 = new PlainActionFuture<>();
@@ -204,7 +206,8 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
             Base64URL.encode("signature")
         );
         final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
-        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+        when(jwtAuthenticationToken.getSignedJWT()).thenReturn(signedJWT);
+        when(jwtAuthenticationToken.getJWTClaimsSet()).thenReturn(signedJWT.getJWTClaimsSet());
 
         final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
         jwtAuthenticator.authenticate(jwtAuthenticationToken, future);
@@ -221,7 +224,8 @@ public abstract class JwtAuthenticatorTests extends ESTestCase {
             Base64URL.encode("signature")
         );
         final JwtAuthenticationToken jwtAuthenticationToken = mock(JwtAuthenticationToken.class);
-        when(jwtAuthenticationToken.getEndUserSignedJwt()).thenReturn(new SecureString(signedJWT.serialize().toCharArray()));
+        when(jwtAuthenticationToken.getSignedJWT()).thenReturn(signedJWT);
+        when(jwtAuthenticationToken.getJWTClaimsSet()).thenReturn(signedJWT.getJWTClaimsSet());
 
         final PlainActionFuture<JWTClaimsSet> future = new PlainActionFuture<>();
         jwtAuthenticator.authenticate(jwtAuthenticationToken, future);

+ 4 - 3
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java

@@ -23,7 +23,9 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.test.ESTestCase.randomAlphaOfLengthBetween;
 import static org.elasticsearch.test.ESTestCase.randomBoolean;
+import static org.elasticsearch.test.ESTestCase.randomFrom;
 
 /**
  * Test class with settings for a JWT issuer to sign JWTs for users.
@@ -37,7 +39,7 @@ public class JwtIssuer implements Closeable {
     // input parameters
     final String issuerClaimValue; // claim name is hard-coded to `iss` for OIDC ID Token compatibility
     final List<String> audiencesClaimValue; // claim name is hard-coded to `aud` for OIDC ID Token compatibility
-    final String principalClaimName; // claim name is configurable, EX: Users (sub, oid, email, dn, uid), Clients (azp, appid, client_id)
+    final String principalClaimName;
     final Map<String, User> principals; // principals with roles, for sending encoded JWTs into JWT realms for authc/authz verification
     final JwtIssuerHttpsServer httpsServer;
 
@@ -56,15 +58,14 @@ public class JwtIssuer implements Closeable {
     JwtIssuer(
         final String issuerClaimValue,
         final List<String> audiencesClaimValue,
-        final String principalClaimName,
         final Map<String, User> principals,
         final boolean createHttpsServer
     ) throws Exception {
         this.issuerClaimValue = issuerClaimValue;
         this.audiencesClaimValue = audiencesClaimValue;
-        this.principalClaimName = principalClaimName;
         this.principals = principals;
         this.httpsServer = createHttpsServer ? new JwtIssuerHttpsServer(null) : null;
+        this.principalClaimName = randomFrom("sub", "oid", "client_id", "appid", "azp", "email", randomAlphaOfLengthBetween(12, 18));
     }
 
     // The flag areHmacJwksOidcSafe indicates if all provided HMAC JWKs are UTF8, for HMAC OIDC JWK encoding compatibility.

+ 8 - 3
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateAccessTokenTypeTests.java

@@ -38,7 +38,6 @@ public class JwtRealmAuthenticateAccessTokenTypeTests extends JwtRealmTestCase {
         noFallback();
 
         jwtIssuerAndRealms = generateJwtIssuerRealmPairs(
-            createJwtRealmsSettingsBuilder(),
             randomIntBetween(1, 1), // realms
             randomIntBetween(0, 1), // authz
             randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algorithms
@@ -60,7 +59,6 @@ public class JwtRealmAuthenticateAccessTokenTypeTests extends JwtRealmTestCase {
         randomFallbacks();
 
         jwtIssuerAndRealms = generateJwtIssuerRealmPairs(
-            createJwtRealmsSettingsBuilder(),
             randomIntBetween(1, 1), // realms
             randomIntBetween(0, 1), // authz
             randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algorithms
@@ -122,8 +120,15 @@ public class JwtRealmAuthenticateAccessTokenTypeTests extends JwtRealmTestCase {
                 otherClaims.put(fallbackSub, randomValueOtherThan(subClaimValue, () -> randomAlphaOfLength(15)));
             }
         }
-        // TODO: fallback aud
         List<String> audClaimValue = JwtRealmInspector.getAllowedAudiences(jwtIssuerAndRealm.realm());
+        if (fallbackAud != null) {
+            if (randomBoolean()) {
+                otherClaims.put(fallbackAud, audClaimValue);
+                audClaimValue = null;
+            } else {
+                otherClaims.put(fallbackAud, randomValueOtherThanMany(audClaimValue::contains, () -> randomAlphaOfLength(15)));
+            }
+        }
 
         // A bogus auth_time but access_token type does not check it
         if (randomBoolean()) {

+ 7 - 21
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java

@@ -23,6 +23,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsException;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
+import org.elasticsearch.xpack.core.security.authc.AuthenticationToken;
 import org.elasticsearch.xpack.core.security.authc.Realm;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
@@ -49,7 +50,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
      */
     public void testJwtAuthcRealmAuthcAuthzWithEmptyRoles() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
-            this.createJwtRealmsSettingsBuilder(),
             randomIntBetween(1, 1), // realmsRange
             randomIntBetween(0, 1), // authzRange
             randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
@@ -73,7 +73,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
      */
     public void testJwtAuthcRealmAuthcAuthzWithoutAuthzRealms() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
-            this.createJwtRealmsSettingsBuilder(),
             randomIntBetween(1, 3), // realmsRange
             randomIntBetween(0, 0), // authzRange
             randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
@@ -99,7 +98,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
      */
     public void testJwkSetUpdates() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
-            this.createJwtRealmsSettingsBuilder(),
             randomIntBetween(1, 3), // realmsRange
             randomIntBetween(0, 0), // authzRange
             randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
@@ -255,7 +253,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
      */
     public void testJwtAuthcRealmAuthcAuthzWithAuthzRealms() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
-            this.createJwtRealmsSettingsBuilder(),
             randomIntBetween(1, 3), // realmsRange
             randomIntBetween(1, 3), // authzRange
             randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
@@ -284,11 +281,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             final User otherUser = new User(otherUsername);
             final SecureString otherJwt = this.randomJwt(jwtIssuerAndRealm, otherUser);
 
-            final JwtAuthenticationToken otherToken = new JwtAuthenticationToken(
-                List.of(JwtRealmInspector.getPrincipalClaimName(jwtIssuerAndRealm.realm())),
-                otherJwt,
-                clientSecret
-            );
+            final AuthenticationToken otherToken = jwtIssuerAndRealm.realm().token(createThreadContext(otherJwt, clientSecret));
             final PlainActionFuture<AuthenticationResult<User>> otherFuture = new PlainActionFuture<>();
             jwtIssuerAndRealm.realm().authenticate(otherToken, otherFuture);
             final AuthenticationResult<User> otherResult = otherFuture.actionGet();
@@ -306,12 +299,9 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
      * @throws Exception Unexpected test failure
      */
     public void testPkcJwkSetUrlNotFound() throws Exception {
-        final JwtRealmsService jwtRealmsService = this.generateJwtRealmsService(this.createJwtRealmsSettingsBuilder());
-        final String principalClaimName = randomFrom(jwtRealmsService.getPrincipalClaimNames());
-
         final List<Realm> allRealms = new ArrayList<>(); // authc and authz realms
         final boolean createHttpsServer = true; // force issuer to create HTTPS server for its PKC JWKSet
-        final JwtIssuer jwtIssuer = this.createJwtIssuer(0, principalClaimName, 12, 1, 1, 1, createHttpsServer);
+        final JwtIssuer jwtIssuer = this.createJwtIssuer(0, 12, 1, 1, 1, createHttpsServer);
         assertThat(jwtIssuer.httpsServer, is(notNullValue()));
         try {
             final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = this.createJwtRealmSettingsBuilder(jwtIssuer, 0, 0);
@@ -320,7 +310,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             jwtRealmSettingsBuilder.settingsBuilder().put(configKey, configValue);
             final Exception exception = expectThrows(
                 SettingsException.class,
-                () -> this.createJwtRealm(allRealms, jwtRealmsService, jwtIssuer, jwtRealmSettingsBuilder)
+                () -> this.createJwtRealm(allRealms, jwtIssuer, jwtRealmSettingsBuilder)
             );
             assertThat(exception.getMessage(), equalTo("Can't get contents for setting [" + configKey + "] value [" + configValue + "]."));
             assertThat(exception.getCause().getMessage(), equalTo("Get [" + configValue + "] failed, status [404], reason [Not Found]."));
@@ -335,7 +325,6 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
      */
     public void testJwtValidationFailures() throws Exception {
         this.jwtIssuerAndRealms = this.generateJwtIssuerRealmPairs(
-            this.createJwtRealmsSettingsBuilder(),
             randomIntBetween(1, 1), // realmsRange
             randomIntBetween(0, 0), // authzRange
             randomIntBetween(1, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS.size()), // algsRange
@@ -481,12 +470,9 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
      * @throws Exception Unexpected test failure
      */
     public void testSameIssuerTwoRealmsDifferentClientSecrets() throws Exception {
-        final JwtRealmsService jwtRealmsService = this.generateJwtRealmsService(this.createJwtRealmsSettingsBuilder());
-        final String principalClaimName = randomFrom(jwtRealmsService.getPrincipalClaimNames());
-
         final int realmsCount = 2;
         final List<Realm> allRealms = new ArrayList<>(realmsCount); // two identical realms for same issuer, except different client secret
-        final JwtIssuer jwtIssuer = this.createJwtIssuer(0, principalClaimName, 12, 1, 1, 1, false);
+        final JwtIssuer jwtIssuer = this.createJwtIssuer(0, 12, 1, 1, 1, false);
         super.printJwtIssuer(jwtIssuer);
         this.jwtIssuerAndRealms = new ArrayList<>(realmsCount);
         for (int i = 0; i < realmsCount; i++) {
@@ -501,7 +487,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
                     String.join(",", jwtIssuer.algorithmsAll)
                 )
                 .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_AUDIENCES), jwtIssuer.audiencesClaimValue.get(0))
-                .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLAIMS_PRINCIPAL.getClaim()), principalClaimName)
+                .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLAIMS_PRINCIPAL.getClaim()), jwtIssuer.principalClaimName)
                 .put(
                     RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLIENT_AUTHENTICATION_TYPE),
                     JwtRealmSettings.ClientAuthenticationType.SHARED_SECRET.value()
@@ -532,7 +518,7 @@ public class JwtRealmAuthenticateTests extends JwtRealmTestCase {
             );
             authcSettings.setSecureSettings(secureSettings);
             final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, authcSettings);
-            final JwtRealm jwtRealm = this.createJwtRealm(allRealms, jwtRealmsService, jwtIssuer, jwtRealmSettingsBuilder);
+            final JwtRealm jwtRealm = this.createJwtRealm(allRealms, jwtIssuer, jwtRealmSettingsBuilder);
             jwtRealm.initialize(allRealms, super.licenseState);
             final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
             this.jwtIssuerAndRealms.add(jwtIssuerAndRealm); // add them so the test will clean them up

+ 5 - 26
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java

@@ -25,7 +25,6 @@ import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
-import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmsServiceSettings;
 import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings;
 import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper;
 import org.elasticsearch.xpack.core.security.user.User;
@@ -67,16 +66,11 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         );
 
         final String principalClaimName = "sub";
-        final Settings.Builder jwtRealmsServiceSettings = Settings.builder()
-            .put(this.globalSettings)
-            .put(JwtRealmsServiceSettings.PRINCIPAL_CLAIMS_SETTING.getKey(), String.join(",", principalClaimName));
-        final JwtRealmsService jwtRealmsService = new JwtRealmsService(jwtRealmsServiceSettings.build());
 
         // Create issuer
         final JwtIssuer jwtIssuer = new JwtIssuer(
             "iss8", // iss
             List.of("aud8"), // aud
-            principalClaimName, // sub
             Collections.singletonMap("security_test_user", new User("security_test_user", "security_test_role")), // users
             false // createHttpsServer
         );
@@ -115,7 +109,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final RealmConfig config = super.buildRealmConfig(JwtRealmSettings.TYPE, realmName, configBuilder.build(), 8);
         final SSLService sslService = new SSLService(TestEnvironment.newEnvironment(configBuilder.build()));
         final UserRoleMapper userRoleMapper = super.buildRoleMapper(jwtIssuer.principals);
-        final JwtRealm jwtRealm = new JwtRealm(config, jwtRealmsService, sslService, userRoleMapper);
+        final JwtRealm jwtRealm = new JwtRealm(config, sslService, userRoleMapper);
         jwtRealm.initialize(Collections.singletonList(jwtRealm), super.licenseState);
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
@@ -161,16 +155,11 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final JwtIssuer.AlgJwkPair algJwkPairPkc = new JwtIssuer.AlgJwkPair("RS256", jwk);
 
         final String principalClaimName = "sub";
-        final Settings.Builder jwtRealmsServiceSettings = Settings.builder()
-            .put(this.globalSettings)
-            .put(JwtRealmsServiceSettings.PRINCIPAL_CLAIMS_SETTING.getKey(), String.join(",", principalClaimName));
-        final JwtRealmsService jwtRealmsService = new JwtRealmsService(jwtRealmsServiceSettings.build());
 
         // Create issuer
         final JwtIssuer jwtIssuer = new JwtIssuer(
             "https://issuer.example.com/", // iss claim value
             List.of("https://audience.example.com/"), // aud claim value
-            principalClaimName, // principal claim name
             Collections.singletonMap("user1", new User("user1", "role1")), // users
             false // createHttpsServer
         );
@@ -207,7 +196,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final RealmConfig config = super.buildRealmConfig(JwtRealmSettings.TYPE, realmName, configBuilder.build(), 2);
         final SSLService sslService = new SSLService(TestEnvironment.newEnvironment(configBuilder.build()));
         final UserRoleMapper userRoleMapper = super.buildRoleMapper(jwtIssuer.principals);
-        final JwtRealm jwtRealm = new JwtRealm(config, jwtRealmsService, sslService, userRoleMapper);
+        final JwtRealm jwtRealm = new JwtRealm(config, sslService, userRoleMapper);
         jwtRealm.initialize(Collections.singletonList(jwtRealm), super.licenseState);
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
@@ -256,16 +245,11 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         );
 
         final String principalClaimName = "email";
-        final Settings.Builder jwtRealmsServiceSettings = Settings.builder()
-            .put(this.globalSettings)
-            .put(JwtRealmsServiceSettings.PRINCIPAL_CLAIMS_SETTING.getKey(), String.join(",", principalClaimName));
-        final JwtRealmsService jwtRealmsService = new JwtRealmsService(jwtRealmsServiceSettings.build());
 
         // Create issuer
         final JwtIssuer jwtIssuer = new JwtIssuer(
             "my-issuer", // iss claim value
             List.of("es01", "es02", "es03"), // aud claim value
-            principalClaimName, // principal claim name
             Collections.singletonMap("user2", new User("user2", "role2")), // users
             false // createHttpsServer
         );
@@ -282,7 +266,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
                 String.join(",", jwtIssuer.audiencesClaimValue)
             )
             .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.ALLOWED_SIGNATURE_ALGORITHMS), "HS256,HS384")
-            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLAIMS_PRINCIPAL.getClaim()), "email")
+            .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLAIMS_PRINCIPAL.getClaim()), principalClaimName)
             .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLAIMS_PRINCIPAL.getPattern()), "^(.*)@[^.]*[.]example[.]com$")
             .put(RealmSettings.getFullSettingKey(realmName, JwtRealmSettings.CLAIMS_MAIL.getClaim()), "email")
             .put(
@@ -313,7 +297,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final RealmConfig config = super.buildRealmConfig(JwtRealmSettings.TYPE, realmName, configBuilder.build(), 3);
         final SSLService sslService = new SSLService(TestEnvironment.newEnvironment(configBuilder.build()));
         final UserRoleMapper userRoleMapper = super.buildRoleMapper(Map.of()); // authc realm will not do role mapping
-        final JwtRealm jwtRealm = new JwtRealm(config, jwtRealmsService, sslService, userRoleMapper);
+        final JwtRealm jwtRealm = new JwtRealm(config, sslService, userRoleMapper);
         jwtRealm.initialize(List.of(authzRealm, jwtRealm), super.licenseState);
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);
@@ -361,16 +345,11 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         );
 
         final String principalClaimName = "sub";
-        final Settings.Builder jwtRealmsServiceSettings = Settings.builder()
-            .put(this.globalSettings)
-            .put(JwtRealmsServiceSettings.PRINCIPAL_CLAIMS_SETTING.getKey(), String.join(",", principalClaimName));
-        final JwtRealmsService jwtRealmsService = new JwtRealmsService(jwtRealmsServiceSettings.build());
 
         // Create issuer
         final JwtIssuer jwtIssuer = new JwtIssuer(
             "jwt3-issuer", // iss claim value
             List.of("jwt3-audience"), // aud claim value
-            principalClaimName, // principal claim name
             Collections.singletonMap("user3", new User("user3", "role3")), // users
             false // createHttpsServer
         );
@@ -409,7 +388,7 @@ public class JwtRealmGenerateTests extends JwtRealmTestCase {
         final RealmConfig config = super.buildRealmConfig(JwtRealmSettings.TYPE, realmName, configBuilder.build(), 4);
         final SSLService sslService = new SSLService(TestEnvironment.newEnvironment(configBuilder.build()));
         final UserRoleMapper userRoleMapper = super.buildRoleMapper(jwtIssuer.principals);
-        final JwtRealm jwtRealm = new JwtRealm(config, jwtRealmsService, sslService, userRoleMapper);
+        final JwtRealm jwtRealm = new JwtRealm(config, sslService, userRoleMapper);
         jwtRealm.initialize(List.of(jwtRealm), super.licenseState);
         final JwtRealmSettingsBuilder jwtRealmSettingsBuilder = new JwtRealmSettingsBuilder(realmName, configBuilder);
         final JwtIssuerAndRealm jwtIssuerAndRealm = new JwtIssuerAndRealm(jwtIssuer, jwtRealm, jwtRealmSettingsBuilder);

+ 5 - 44
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java

@@ -31,7 +31,6 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings.ClientAuthenticationType;
-import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmsServiceSettings;
 import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings;
 import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper;
 import org.elasticsearch.xpack.core.security.user.User;
@@ -73,8 +72,6 @@ import static org.mockito.Mockito.when;
 public abstract class JwtRealmTestCase extends JwtTestCase {
     private static final Logger LOGGER = LogManager.getLogger(JwtRealmTestCase.class);
 
-    record JwtRealmsServiceSettingsBuilder(Settings.Builder settingsBuilder) {}
-
     record JwtRealmSettingsBuilder(String name, Settings.Builder settingsBuilder) {}
 
     record JwtIssuerAndRealm(JwtIssuer issuer, JwtRealm realm, JwtRealmSettingsBuilder realmSettingsBuilder) {}
@@ -117,12 +114,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         assertThat(plainActionFuture.get().isAuthenticated(), is(false));
     }
 
-    protected JwtRealmsService generateJwtRealmsService(final JwtRealmsServiceSettingsBuilder jwtRealmRealmsSettingsBuilder) {
-        return new JwtRealmsService(jwtRealmRealmsSettingsBuilder.settingsBuilder.build());
-    }
-
     protected List<JwtIssuerAndRealm> generateJwtIssuerRealmPairs(
-        final JwtRealmsServiceSettingsBuilder jwtRealmsServiceSettingsBuilder,
         final int realmsCount,
         final int authzCount,
         final int algsCount,
@@ -133,20 +125,11 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         final boolean createHttpsServer
     ) throws Exception {
         // Create JWT authc realms and mocked authz realms. Initialize each JWT realm, and test ensureInitialized() before and after.
-        final JwtRealmsService jwtRealmsService = this.generateJwtRealmsService(jwtRealmsServiceSettingsBuilder);
         final List<Realm> allRealms = new ArrayList<>(); // authc and authz realms
         this.jwtIssuerAndRealms = new ArrayList<>(realmsCount);
         for (int i = 0; i < realmsCount; i++) {
 
-            final JwtIssuer jwtIssuer = this.createJwtIssuer(
-                i,
-                randomFrom(jwtRealmsService.getPrincipalClaimNames()),
-                algsCount,
-                audiencesCount,
-                usersCount,
-                rolesCount,
-                createHttpsServer
-            );
+            final JwtIssuer jwtIssuer = this.createJwtIssuer(i, algsCount, audiencesCount, usersCount, rolesCount, createHttpsServer);
             // If HTTPS server was created in JWT issuer, any exception after that point requires closing it to avoid a thread pool leak
             try {
                 final JwtRealmSettingsBuilder realmSettingsBuilder = this.createJwtRealmSettingsBuilder(
@@ -154,7 +137,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
                     authzCount,
                     jwtCacheSize
                 );
-                final JwtRealm jwtRealm = this.createJwtRealm(allRealms, jwtRealmsService, jwtIssuer, realmSettingsBuilder);
+                final JwtRealm jwtRealm = this.createJwtRealm(allRealms, jwtIssuer, realmSettingsBuilder);
 
                 // verify exception before initialize()
                 final Exception exception = expectThrows(IllegalStateException.class, jwtRealm::ensureInitialized);
@@ -174,7 +157,6 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
 
     protected JwtIssuer createJwtIssuer(
         final int i,
-        final String principalClaimName,
         final int algsCount,
         final int audiencesCount,
         final int userCount,
@@ -185,7 +167,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         final List<String> audiences = IntStream.range(0, audiencesCount).mapToObj(j -> issuer + "_aud" + (j + 1)).toList();
         final Map<String, User> users = JwtTestCase.generateTestUsersWithRoles(userCount, roleCount);
         // Allow algorithm repeats, to cover testing of multiple JWKs for same algorithm
-        final JwtIssuer jwtIssuer = new JwtIssuer(issuer, audiences, principalClaimName, users, createHttpsServer);
+        final JwtIssuer jwtIssuer = new JwtIssuer(issuer, audiences, users, createHttpsServer);
         final List<String> algorithms = randomOfMinMaxNonUnique(algsCount, algsCount, JwtRealmSettings.SUPPORTED_SIGNATURE_ALGORITHMS);
         final boolean areHmacJwksOidcSafe = randomBoolean();
         final List<JwtIssuer.AlgJwkPair> algAndJwks = JwtRealmTestCase.randomJwks(algorithms, areHmacJwksOidcSafe);
@@ -204,21 +186,6 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         // TODO If x-pack Security plug-in add supports for reloadable settings, update HMAC JWKSet and HMAC OIDC JWK in ES Keystore
     }
 
-    protected JwtRealmsServiceSettingsBuilder createJwtRealmsSettingsBuilder() throws Exception {
-        final List<String> principalClaimNames = randomBoolean()
-            ? List.of("principalClaim_" + randomAlphaOfLength(6))
-            : randomSubsetOf(randomIntBetween(1, 6), JwtRealmsServiceSettings.DEFAULT_PRINCIPAL_CLAIMS);
-
-        final Settings.Builder jwtRealmsServiceSettings = Settings.builder()
-            .put(this.globalSettings)
-            .put(JwtRealmsServiceSettings.PRINCIPAL_CLAIMS_SETTING.getKey(), String.join(",", principalClaimNames));
-
-        final MockSecureSettings secureSettings = new MockSecureSettings(); // none for now, placeholder for future
-        jwtRealmsServiceSettings.setSecureSettings(secureSettings);
-
-        return new JwtRealmsServiceSettingsBuilder(jwtRealmsServiceSettings);
-    }
-
     protected JwtRealmSettingsBuilder createJwtRealmSettingsBuilder(final JwtIssuer jwtIssuer, final int authzCount, final int jwtCacheSize)
         throws Exception {
         final String authcRealmName = "realm_" + jwtIssuer.issuerClaimValue;
@@ -367,7 +334,6 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
 
     protected JwtRealm createJwtRealm(
         final List<Realm> allRealms, // JWT realms and authz realms
-        final JwtRealmsService jwtRealmsService,
         final JwtIssuer jwtIssuer,
         final JwtRealmSettingsBuilder realmSettingsBuilder
     ) {
@@ -381,7 +347,7 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
         final UserRoleMapper userRoleMapper = super.buildRoleMapper(authzRealmNames.isEmpty() ? jwtIssuer.principals : Map.of());
 
         // If authz names is not set, register the users here in the JWT authc realm.
-        final JwtRealm jwtRealm = new JwtRealm(authcConfig, jwtRealmsService, sslService, userRoleMapper);
+        final JwtRealm jwtRealm = new JwtRealm(authcConfig, sslService, userRoleMapper);
         allRealms.add(jwtRealm);
 
         // If authz names is set, register the users here in one of the authz realms.
@@ -446,20 +412,15 @@ public abstract class JwtRealmTestCase extends JwtTestCase {
             }
             assertThat(jwtAuthenticationToken, is(notNullValue()));
             final String tokenPrincipal = jwtAuthenticationToken.principal();
-            final SecureString tokenJwt = jwtAuthenticationToken.getEndUserSignedJwt();
             final SecureString tokenSecret = jwtAuthenticationToken.getClientAuthenticationSharedSecret();
             assertThat(tokenPrincipal, is(notNullValue()));
-            if (tokenJwt.equals(jwt) == false) {
-                assertThat(tokenJwt, is(equalTo(jwt)));
-            }
-            assertThat(tokenJwt, is(equalTo(jwt)));
             if (tokenSecret != null) {
                 if (tokenSecret.equals(sharedSecret) == false) {
                     assertThat(tokenSecret, is(equalTo(sharedSecret)));
                 }
                 assertThat(tokenSecret, is(equalTo(sharedSecret)));
             }
-            LOGGER.info("GOT TOKEN: principal=[" + tokenPrincipal + "], jwt=[" + tokenJwt + "], secret=[" + tokenSecret + "].");
+            LOGGER.info("GOT TOKEN: principal=[" + tokenPrincipal + "], jwt=[" + jwt + "], secret=[" + tokenSecret + "].");
 
             // Loop through all authc/authz realms. Confirm user is returned with expected principal and roles.
             User authenticatedUser = null;