Преглед на файлове

Improve robustness of JwkSet reloading (#92081)

Currently the reload future object is reset after the listener gets
invoked. Since the reset is done is a separate (http) thread from the
waiting listener, it is possible that the reset can be delayed while the
listener thread is ready to proceed. If the listener thread tries to
reload the JWKs again, it will see the old future object and incorrectly
skip the reload operation. This is mostly a test issue because the
listener thread never tries to reload JWKs again in production. 

There is however another actual production bug in that the new JwkSet is
*not* returned as part of the listener, only the `isUpdated` flag is
returned. When the listener reads the JwkSet from the JwkSetLoader, it
is possible that the JwkSet is changed again thus the `isUpdate` flag
and actual `JwkSet` value can become inconsistent. 

This PR fixes both of the issues by: (1) resetting the future object
before invoking the listener and (2) returning both `isUpdated` flag and
the new JwkSet value to the listener.

Resolves: #90467 Resolves: #89509
Yang Wang преди 2 години
родител
ревизия
0b7d34e75c

+ 5 - 0
docs/changelog/92081.yaml

@@ -0,0 +1,5 @@
+pr: 92081
+summary: Improve robustness of `JwkSet` reloading
+area: Authentication
+type: bug
+issues: []

+ 17 - 14
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java

@@ -19,6 +19,7 @@ import org.elasticsearch.common.hash.MessageDigests;
 import org.elasticsearch.common.util.concurrent.ListenableFuture;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.xpack.core.security.authc.RealmConfig;
 import org.elasticsearch.xpack.core.security.authc.RealmSettings;
 import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings;
@@ -42,7 +43,7 @@ public class JwkSetLoader implements Releasable {
 
     private static final Logger logger = LogManager.getLogger(JwkSetLoader.class);
 
-    private final AtomicReference<ListenableFuture<Boolean>> reloadFutureRef = new AtomicReference<>();
+    private final AtomicReference<ListenableFuture<Tuple<Boolean, JwksAlgs>>> reloadFutureRef = new AtomicReference<>();
     private final RealmConfig realmConfig;
     private final List<String> allowedJwksAlgsPkc;
     private final String jwkSetPath;
@@ -67,10 +68,10 @@ public class JwkSetLoader implements Releasable {
 
         // Any exception during loading requires closing JwkSetLoader's HTTP client to avoid a thread pool leak
         try {
-            final PlainActionFuture<Boolean> future = new PlainActionFuture<>();
+            final PlainActionFuture<Tuple<Boolean, JwksAlgs>> future = new PlainActionFuture<>();
             reload(future);
             // ASSUME: Blocking read operations are OK during startup
-            final Boolean isUpdated = future.actionGet();
+            final Boolean isUpdated = future.actionGet().v1();
             assert isUpdated : "initial reload should have updated the JWK set";
         } catch (Throwable t) {
             close();
@@ -83,8 +84,8 @@ public class JwkSetLoader implements Releasable {
      * they are different. The listener is called with false if the reloaded content is the same
      * as the existing one or true if they are different.
      */
-    void reload(final ActionListener<Boolean> listener) {
-        final ListenableFuture<Boolean> future = this.getFuture();
+    void reload(final ActionListener<Tuple<Boolean, JwksAlgs>> listener) {
+        final ListenableFuture<Tuple<Boolean, JwksAlgs>> future = this.getFuture();
         future.addListener(listener);
     }
 
@@ -92,17 +93,17 @@ public class JwkSetLoader implements Releasable {
         return contentAndJwksAlgs;
     }
 
-    private ListenableFuture<Boolean> getFuture() {
+    private ListenableFuture<Tuple<Boolean, JwksAlgs>> getFuture() {
         for (;;) {
-            final ListenableFuture<Boolean> existingFuture = this.reloadFutureRef.get();
+            final ListenableFuture<Tuple<Boolean, JwksAlgs>> existingFuture = this.reloadFutureRef.get();
             if (existingFuture != null) {
                 return existingFuture;
             }
 
-            final ListenableFuture<Boolean> newFuture = new ListenableFuture<>();
+            final ListenableFuture<Tuple<Boolean, JwksAlgs>> newFuture = new ListenableFuture<>();
             if (this.reloadFutureRef.compareAndSet(null, newFuture)) {
-                loadInternal(ActionListener.runAfter(newFuture, () -> {
-                    final ListenableFuture<Boolean> oldValue = this.reloadFutureRef.getAndSet(null);
+                loadInternal(ActionListener.runBefore(newFuture, () -> {
+                    final ListenableFuture<Tuple<Boolean, JwksAlgs>> oldValue = this.reloadFutureRef.getAndSet(null);
                     assert oldValue == newFuture : "future reference changed unexpectedly";
                 }));
                 return newFuture;
@@ -111,7 +112,7 @@ public class JwkSetLoader implements Releasable {
         }
     }
 
-    private void loadInternal(final ActionListener<Boolean> listener) {
+    private void loadInternal(final ActionListener<Tuple<Boolean, JwksAlgs>> listener) {
         // PKC JWKSet get contents from local file or remote HTTPS URL
         if (httpClient == null) {
             logger.trace("Loading PKC JWKs from path [{}]", jwkSetPath);
@@ -135,16 +136,18 @@ public class JwkSetLoader implements Releasable {
         }
     }
 
-    private Boolean handleReloadedContentAndJwksAlgs(byte[] bytes) {
+    private Tuple<Boolean, JwksAlgs> handleReloadedContentAndJwksAlgs(byte[] bytes) {
         final ContentAndJwksAlgs newContentAndJwksAlgs = parseContent(bytes);
+        final boolean isUpdated;
         if (contentAndJwksAlgs != null && Arrays.equals(contentAndJwksAlgs.sha256, newContentAndJwksAlgs.sha256)) {
             logger.debug("No change in reloaded JWK set");
-            return false;
+            isUpdated = false;
         } else {
             logger.debug("Reloaded JWK set is different from the existing set");
             contentAndJwksAlgs = newContentAndJwksAlgs;
-            return true;
+            isUpdated = true;
         }
+        return new Tuple<>(isUpdated, contentAndJwksAlgs.jwksAlgs);
     }
 
     private ContentAndJwksAlgs parseContent(final byte[] jwkSetContentBytesPkc) {

+ 11 - 8
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java

@@ -254,23 +254,25 @@ public interface JwtSignatureValidator extends Releasable {
 
         public void validate(String tokenPrincipal, SignedJWT signedJWT, ActionListener<Void> listener) {
             // TODO: assert algorithm?
+            final JwkSetLoader.ContentAndJwksAlgs contentAndJwksAlgs = jwkSetLoader.getContentAndJwksAlgs();
+            final JwkSetLoader.JwksAlgs jwksAlgs = contentAndJwksAlgs.jwksAlgs();
             try {
-                JwtValidateUtil.validateSignature(signedJWT, jwkSetLoader.getContentAndJwksAlgs().jwksAlgs().jwks());
+                JwtValidateUtil.validateSignature(signedJWT, jwksAlgs.jwks());
                 listener.onResponse(null);
             } catch (Exception primaryException) {
                 logger.debug(
                     () -> org.elasticsearch.core.Strings.format(
                         "Signature verification failed for JWT [%s] reloading JWKSet (was: #[%s] JWKs, #[%s] algs, sha256=[%s])",
                         tokenPrincipal,
-                        jwkSetLoader.getContentAndJwksAlgs().jwksAlgs().jwks().size(),
-                        jwkSetLoader.getContentAndJwksAlgs().jwksAlgs().algs().size(),
-                        MessageDigests.toHexString(jwkSetLoader.getContentAndJwksAlgs().sha256())
+                        jwksAlgs.jwks().size(),
+                        jwksAlgs.algs().size(),
+                        MessageDigests.toHexString(contentAndJwksAlgs.sha256())
                     ),
                     primaryException
                 );
 
-                jwkSetLoader.reload(ActionListener.wrap(isUpdated -> {
-                    if (false == isUpdated) {
+                jwkSetLoader.reload(ActionListener.wrap(reloadResult -> {
+                    if (false == reloadResult.v1()) {
                         // No change in JWKSet
                         logger.debug("Reloaded same PKC JWKs, can't retry verify JWT token [{}]", tokenPrincipal);
                         listener.onFailure(primaryException);
@@ -281,13 +283,14 @@ public interface JwtSignatureValidator extends Releasable {
                     // Enhancement idea: When some JWKs are retained (ex: rotation), only invalidate for removed JWKs.
                     reloadNotifier.reloaded();
 
-                    if (jwkSetLoader.getContentAndJwksAlgs().jwksAlgs().isEmpty()) {
+                    final JwkSetLoader.JwksAlgs reloadedJwksAlgs = reloadResult.v2();
+                    if (reloadedJwksAlgs.isEmpty()) {
                         logger.debug("Reloaded empty PKC JWKs, signature verification will fail for JWT [{}]", tokenPrincipal);
                         // fall through and let try/catch below handle empty JWKs failure log and response
                     }
 
                     try {
-                        JwtValidateUtil.validateSignature(signedJWT, jwkSetLoader.getContentAndJwksAlgs().jwksAlgs().jwks());
+                        JwtValidateUtil.validateSignature(signedJWT, reloadedJwksAlgs.jwks());
                         listener.onResponse(null);
                     } catch (Exception secondaryException) {
                         logger.debug(