Browse Source

Move SAML credential comparison to metadata resolver (#91172)

Justin Cranford 2 years ago
parent
commit
45a42f1884

+ 41 - 29
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/saml/SamlRealm.java

@@ -74,6 +74,8 @@ import org.opensaml.security.criteria.UsageCriterion;
 import org.opensaml.security.x509.X509Credential;
 import org.opensaml.security.x509.impl.X509KeyManagerX509CredentialAdapter;
 import org.opensaml.xmlsec.keyinfo.impl.BasicProviderKeyInfoCredentialResolver;
+import org.opensaml.xmlsec.keyinfo.impl.KeyInfoProvider;
+import org.opensaml.xmlsec.keyinfo.impl.KeyInfoResolutionContext;
 import org.opensaml.xmlsec.keyinfo.impl.provider.InlineX509DataProvider;
 
 import java.io.IOException;
@@ -95,7 +97,6 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.regex.Matcher;
@@ -305,8 +306,18 @@ public final class SamlRealm extends Realm implements Releasable {
         final PredicateRoleDescriptorResolver roleDescriptorResolver = new PredicateRoleDescriptorResolver(metadataResolver);
         resolver.setRoleDescriptorResolver(roleDescriptorResolver);
 
-        final InlineX509DataProvider keyInfoProvider = new InlineX509DataProvider();
-        resolver.setKeyInfoCredentialResolver(new BasicProviderKeyInfoCredentialResolver(Collections.singletonList(keyInfoProvider)));
+        final List<KeyInfoProvider> keyInfoProviders = Collections.singletonList(new InlineX509DataProvider());
+        final BasicProviderKeyInfoCredentialResolver credentialsResolver = new BasicProviderKeyInfoCredentialResolver(keyInfoProviders) {
+            final AtomicReference<Set<PublicKey>> previousCredentialsRef = new AtomicReference<>();
+
+            @Override
+            protected void postProcess(KeyInfoResolutionContext kiContext, CriteriaSet criteriaSet, List<Credential> credentials)
+                throws ResolverException {
+                SamlRealm.logDiff(credentials, this.previousCredentialsRef);
+                super.postProcess(kiContext, criteriaSet, credentials);
+            }
+        };
+        resolver.setKeyInfoCredentialResolver(credentialsResolver);
 
         try {
             roleDescriptorResolver.initialize();
@@ -315,7 +326,6 @@ public final class SamlRealm extends Realm implements Releasable {
             throw new IllegalStateException("Cannot initialise SAML IDP resolvers for realm " + config.name(), e);
         }
 
-        final Consumer<List<Credential>> diffLogger = SamlRealm.diffLogger();
         final String entityID = idpDescriptor.get().getEntityID();
         return new IdpConfiguration(entityID, () -> {
             try {
@@ -326,38 +336,40 @@ public final class SamlRealm extends Realm implements Releasable {
                         new UsageCriterion(UsageType.SIGNING)
                     )
                 );
-                final List<Credential> list = CollectionUtils.iterableAsArrayList(credentials);
-                diffLogger.accept(list);
-                return list;
+                return CollectionUtils.iterableAsArrayList(credentials);
             } catch (ResolverException e) {
                 throw new IllegalStateException("Cannot resolve SAML IDP credentials resolver for realm " + config.name(), e);
             }
         });
     }
 
-    private static Consumer<List<Credential>> diffLogger() {
-        final AtomicReference<Set<PublicKey>> previousCredentialsRef = new AtomicReference<>(null);
-        return new Consumer<>() {
-            private static final Logger LOGGER = LogManager.getLogger(IdpConfiguration.class);
-
-            @Override
-            public void accept(final List<Credential> newCredentials) {
-                final Set<PublicKey> newPublicKeys = newCredentials.stream().map(cert -> cert.getPublicKey()).collect(Collectors.toSet());
-                final Set<PublicKey> previousPublicKeys = previousCredentialsRef.get();
-                if (previousPublicKeys == null) {
-                    LOGGER.trace("Signing credentials initialized, added: [{}]", newCredentials.size());
-                } else {
-                    final Set<PublicKey> added = Sets.difference(newPublicKeys, previousPublicKeys);
-                    final Set<PublicKey> removed = Sets.difference(previousPublicKeys, newPublicKeys);
-                    if (added.isEmpty() && removed.isEmpty()) {
-                        LOGGER.debug("Signing credentials did not change, current: [{}]", newCredentials.size());
-                    } else {
-                        LOGGER.info("Signing credentials changed, added: [{}], removed: [{}]", added.size(), removed.size());
-                    }
-                }
-                previousCredentialsRef.set(newPublicKeys);
+    private static void logDiff(final List<Credential> newCredentials, final AtomicReference<Set<PublicKey>> previousCredentialsRef) {
+        if (newCredentials == null) {
+            logger.warn("Signing credentials missing, null");
+            return;
+        } else if (newCredentials.isEmpty()) {
+            logger.warn("Signing credentials missing, empty");
+            return;
+        }
+        final Set<PublicKey> newPublicKeys = newCredentials.stream().map(Credential::getPublicKey).collect(Collectors.toSet());
+        final Set<PublicKey> previousPublicKeys = previousCredentialsRef.getAndSet(newPublicKeys);
+        if (previousPublicKeys == null) {
+            logger.trace("Signing credentials initialized, added: [{}]", newCredentials.size());
+        } else {
+            final Set<PublicKey> added = Sets.difference(newPublicKeys, previousPublicKeys);
+            final Set<PublicKey> removed = Sets.difference(previousPublicKeys, newPublicKeys);
+            if (added.isEmpty() && removed.isEmpty()) {
+                logger.debug("Signing credentials did not change, current: [{}]", newCredentials.size());
+            } else {
+                logger.info(
+                    "Signing credentials changed, new: [{}], previous: [{}], added: [{}], removed: [{}]",
+                    newCredentials.size(),
+                    previousPublicKeys.size(),
+                    added.size(),
+                    removed.size()
+                );
             }
-        };
+        }
     }
 
     static SpConfiguration getSpConfiguration(RealmConfig config) throws IOException, GeneralSecurityException {