Преглед изворни кода

Recompute active realms when license changes (#76592)

This commit changes the implementation of the Realms class to listen
for license changes, and recompute the set of actively licensed realms
only when the license changes rather than each time the "asList" method
is called.

This is primarily a performance optimisation, but it also allows us to
turn off the "in use" license tracking for realms when they are
disabled by a change in license.

Relates: #76476
Tim Vernum пре 4 година
родитељ
комит
c1a32447a7

+ 5 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/license/MockLicenseState.java

@@ -25,4 +25,9 @@ public class MockLicenseState extends XPackLicenseState {
     public void enableUsageTracking(LicensedFeature feature, String contextName) {
         super.enableUsageTracking(feature, contextName);
     }
+
+    @Override
+    public void disableUsageTracking(LicensedFeature feature, String contextName) {
+        super.disableUsageTracking(feature, contextName);
+    }
 }

+ 1 - 1
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java

@@ -663,7 +663,7 @@ public class Security extends Plugin implements SystemIndexPlugin, IngestPlugin,
             logger.debug("Using default authentication failure handler");
             Supplier<Map<String, List<String>>> headersSupplier = () -> {
                 final Map<String, List<String>> defaultFailureResponseHeaders = new HashMap<>();
-                realms.asList().stream().forEach((realm) -> {
+                realms.getActiveRealms().stream().forEach((realm) -> {
                     Map<String, List<String>> realmFailureHeaders = realm.getAuthenticationFailureHeaders();
                     realmFailureHeaders.entrySet().stream().forEach((e) -> {
                         String key = e.getKey();

+ 1 - 1
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/AuthenticationService.java

@@ -305,7 +305,7 @@ public class AuthenticationService {
             this.request = auditableRequest;
             this.fallbackUser = fallbackUser;
             this.fallbackToAnonymous = fallbackToAnonymous;
-            this.defaultOrderedRealmList = realms.asList();
+            this.defaultOrderedRealmList = realms.getActiveRealms();
             this.listener = listener;
         }
 

+ 63 - 62
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/Realms.java

@@ -8,7 +8,6 @@ package org.elasticsearch.xpack.security.authc;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.elasticsearch.Assertions;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.MapBuilder;
@@ -30,7 +29,6 @@ import org.elasticsearch.xpack.security.Security;
 import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -59,12 +57,11 @@ public class Realms implements Iterable<Realm> {
     private final ThreadContext threadContext;
     private final ReservedRealm reservedRealm;
 
-    protected List<Realm> realms;
-    // a list of realms that are considered standard in that they are provided by x-pack and
-    // interact with a 3rd party source on a limited basis
-    List<Realm> standardRealmsOnly;
-    // a list of realms that are considered native, that is they only interact with x-pack and no 3rd party auth sources
-    List<Realm> nativeRealmsOnly;
+    // All realms that were configured from the node settings, some of these may not be enabled due to licensing
+    private final List<Realm> allConfiguredRealms;
+
+    // the realms in current use. This list will change dynamically as the license changes
+    private volatile List<Realm> activeRealms;
 
     public Realms(Settings settings, Environment env, Map<String, Realm.Factory> factories, XPackLicenseState licenseState,
                   ThreadContext threadContext, ReservedRealm reservedRealm) throws Exception {
@@ -74,77 +71,78 @@ public class Realms implements Iterable<Realm> {
         this.licenseState = licenseState;
         this.threadContext = threadContext;
         this.reservedRealm = reservedRealm;
+
         assert XPackSettings.SECURITY_ENABLED.get(settings) : "security must be enabled";
         assert factories.get(ReservedRealm.TYPE) == null;
+
         final List<RealmConfig> realmConfigs = buildRealmConfigs();
-        this.realms = initRealms(realmConfigs);
-        assert realms.get(0) == reservedRealm : "the first realm must be reserved realm";
-        // pre-computing a list of internal only realms allows us to have much cheaper iteration than a custom iterator
-        // and is also simpler in terms of logic. These lists are small, so the duplication should not be a real issue here
-        List<Realm> standardRealms = new ArrayList<>(List.of(reservedRealm));
-        List<Realm> basicRealms = new ArrayList<>(List.of(reservedRealm));
-        for (Realm realm : realms) {
-            // don't add the reserved realm here otherwise we end up with only this realm...
-            if (InternalRealms.isStandardRealm(realm.type())) {
-                standardRealms.add(realm);
-            }
+        this.allConfiguredRealms = initRealms(realmConfigs);
+        this.allConfiguredRealms.forEach(r -> r.initialize(allConfiguredRealms, licenseState));
+        assert allConfiguredRealms.get(0) == reservedRealm : "the first realm must be reserved realm";
 
-            if (InternalRealms.isBuiltinRealm(realm.type())) {
-                basicRealms.add(realm);
-            }
-        }
+        recomputeActiveRealms();
+        licenseState.addListener(this::recomputeActiveRealms);
+    }
 
-        if (Assertions.ENABLED) {
-            for (List<Realm> realmList : Arrays.asList(standardRealms, basicRealms)) {
-                assert realmList.get(0) == reservedRealm : "the first realm must be reserved realm";
-            }
+    protected void recomputeActiveRealms() {
+        final XPackLicenseState licenseStateSnapshot = licenseState.copyCurrentLicenseState();
+        final List<Realm> licensedRealms = calculateLicensedRealms(licenseStateSnapshot);
+        logger.info(
+            "license mode is [{}], currently licensed security realms are [{}]",
+            licenseStateSnapshot.getOperationMode().description(),
+            Strings.collectionToCommaDelimitedString(licensedRealms)
+        );
+
+        // Stop license-tracking for any previously-active realms that are no longer allowed
+        if (activeRealms != null) {
+            activeRealms.stream().filter(r -> licensedRealms.contains(r) == false).forEach(realm -> {
+                if (InternalRealms.isStandardRealm(realm.type())) {
+                    Security.STANDARD_REALMS_FEATURE.stopTracking(licenseStateSnapshot, realm.name());
+                } else {
+                    Security.ALL_REALMS_FEATURE.stopTracking(licenseStateSnapshot, realm.name());
+                }
+            });
         }
 
-        this.standardRealmsOnly = Collections.unmodifiableList(standardRealms);
-        this.nativeRealmsOnly = Collections.unmodifiableList(basicRealms);
-        realms.forEach(r -> r.initialize(this, licenseState));
+        activeRealms = licensedRealms;
     }
 
     @Override
     public Iterator<Realm> iterator() {
-        return asList().iterator();
+        return getActiveRealms().iterator();
     }
 
     /**
      * Returns a list of realms that are configured, but are not permitted under the current license.
      */
     public List<Realm> getUnlicensedRealms() {
-        final XPackLicenseState licenseStateSnapshot = licenseState.copyCurrentLicenseState();
-
-        // If all realms are allowed, then nothing is unlicensed
-        if (Security.ALL_REALMS_FEATURE.checkWithoutTracking(licenseStateSnapshot)) {
-            return Collections.emptyList();
-        }
-
-        final List<Realm> allowedRealms = this.asList();
-        // Shortcut for the typical case, all the configured realms are allowed
-        if (allowedRealms.equals(this.realms)) {
+        final List<Realm> activeSnapshot = activeRealms;
+        if (activeSnapshot.equals(allConfiguredRealms)) {
             return Collections.emptyList();
         }
 
         // Otherwise, we return anything in "all realms" that is not in the allowed realm list
-        return realms.stream().filter(r -> allowedRealms.contains(r) == false).collect(Collectors.toUnmodifiableList());
+        return allConfiguredRealms.stream().filter(r -> activeSnapshot.contains(r) == false).collect(Collectors.toUnmodifiableList());
     }
 
     public Stream<Realm> stream() {
         return StreamSupport.stream(this.spliterator(), false);
     }
 
-    public List<Realm> asList() {
-        // TODO : Recalculate this when the license changes rather than on every call
-        return realms.stream().filter(r -> checkLicense(r, licenseState)).collect(Collectors.toUnmodifiableList());
+    public List<Realm> getActiveRealms() {
+        assert activeRealms != null : "Active realms not configured";
+        return activeRealms;
+    }
+
+    // Protected for testing
+    protected List<Realm> calculateLicensedRealms(XPackLicenseState licenseStateSnapshot) {
+        return allConfiguredRealms.stream()
+            .filter(r -> checkLicense(r, licenseStateSnapshot))
+            .collect(Collectors.toUnmodifiableList());
     }
 
     private static boolean checkLicense(Realm realm, XPackLicenseState licenseState) {
-        if (ReservedRealm.TYPE.equals(realm.type())) {
-            return true;
-        }
-        if (InternalRealms.isBuiltinRealm(realm.type())) {
+        if (isBasicLicensedRealm(realm.type())) {
             return true;
         }
         if (InternalRealms.isStandardRealm(realm.type())) {
@@ -153,8 +151,22 @@ public class Realms implements Iterable<Realm> {
         return Security.ALL_REALMS_FEATURE.checkAndStartTracking(licenseState, realm.name());
     }
 
+    public static boolean isRealmTypeAvailable(XPackLicenseState licenseState, String type) {
+        if (Security.ALL_REALMS_FEATURE.checkWithoutTracking(licenseState)) {
+            return true;
+        } else if (Security.STANDARD_REALMS_FEATURE.checkWithoutTracking(licenseState)) {
+            return InternalRealms.isStandardRealm(type) || ReservedRealm.TYPE.equals(type);
+        } else {
+            return isBasicLicensedRealm(type);
+        }
+    }
+
+    private static boolean isBasicLicensedRealm(String type) {
+        return ReservedRealm.TYPE.equals(type) || InternalRealms.isBuiltinRealm(type);
+    }
+
     public Realm realm(String name) {
-        for (Realm realm : realms) {
+        for (Realm realm : activeRealms) {
             if (name.equals(realm.name())) {
                 return realm;
             }
@@ -213,7 +225,7 @@ public class Realms implements Iterable<Realm> {
         final XPackLicenseState licenseStateSnapshot = licenseState.copyCurrentLicenseState();
         Map<String, Object> realmMap = new HashMap<>();
         final AtomicBoolean failed = new AtomicBoolean(false);
-        final List<Realm> realmList = asList().stream()
+        final List<Realm> realmList = getActiveRealms().stream()
             .filter(r -> ReservedRealm.TYPE.equals(r.type()) == false)
             .collect(Collectors.toList());
         final Set<String> realmTypes = realmList.stream().map(Realm::type).collect(Collectors.toSet());
@@ -382,15 +394,4 @@ public class Realms implements Iterable<Realm> {
         }
         return converted;
     }
-
-    public static boolean isRealmTypeAvailable(XPackLicenseState licenseState, String type) {
-        if (Security.ALL_REALMS_FEATURE.checkWithoutTracking(licenseState)) {
-            return true;
-        } else if (Security.STANDARD_REALMS_FEATURE.checkWithoutTracking(licenseState)) {
-            return InternalRealms.isStandardRealm(type) || ReservedRealm.TYPE.equals(type);
-        } else {
-            return InternalRealms.isBuiltinRealm(type);
-        }
-    }
-
 }

+ 52 - 22
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java

@@ -51,6 +51,7 @@ import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.index.get.GetResult;
 import org.elasticsearch.index.seqno.SequenceNumbers;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.license.License;
 import org.elasticsearch.license.MockLicenseState;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.license.XPackLicenseState.Feature;
@@ -220,16 +221,23 @@ public class AuthenticationServiceTests extends ESTestCase {
             .build();
         MockLicenseState licenseState = mock(MockLicenseState.class);
         when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(true);
+        when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(true);
         when(licenseState.checkFeature(Feature.SECURITY_TOKEN_SERVICE)).thenReturn(true);
         when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState);
         when(licenseState.checkFeature(Feature.SECURITY_AUDITING)).thenReturn(true);
+        when(licenseState.getOperationMode()).thenReturn(randomFrom(License.OperationMode.ENTERPRISE, License.OperationMode.PLATINUM));
+
         ReservedRealm reservedRealm = mock(ReservedRealm.class);
         when(reservedRealm.type()).thenReturn("reserved");
         when(reservedRealm.name()).thenReturn("reserved_realm");
         realms = spy(new TestRealms(Settings.EMPTY, TestEnvironment.newEnvironment(settings),
             Map.of(FileRealmSettings.TYPE, config -> mock(FileRealm.class), NativeRealmSettings.TYPE, config -> mock(NativeRealm.class)),
             licenseState, threadContext, reservedRealm, Arrays.asList(firstRealm, secondRealm),
-            Collections.singletonList(firstRealm)));
+            Arrays.asList(firstRealm)));
+
+        // Needed because this is calculated in the constructor, which means the override doesn't get called correctly
+        realms.recomputeActiveRealms();
+        assertThat(realms.getActiveRealms(), contains(firstRealm, secondRealm));
 
         auditTrail = mock(AuditTrail.class);
         auditTrailService = new AuditTrailService(Collections.singletonList(auditTrail), licenseState);
@@ -325,7 +333,7 @@ public class AuthenticationServiceTests extends ESTestCase {
             ));
 
             Mockito.doReturn(List.of(secondRealm)).when(realms).getUnlicensedRealms();
-            Mockito.doReturn(List.of(firstRealm)).when(realms).asList();
+            Mockito.doReturn(List.of(firstRealm)).when(realms).getActiveRealms();
             boolean requestIdAlreadyPresent = randomBoolean();
             SetOnce<String> reqId = new SetOnce<>();
             if (requestIdAlreadyPresent) {
@@ -388,7 +396,10 @@ public class AuthenticationServiceTests extends ESTestCase {
         }, this::logAndFail));
         assertTrue(completed.get());
         verify(auditTrail).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest);
-        verify(realms).asList();
+        verify(realms, atLeastOnce()).recomputeActiveRealms();
+        verify(realms, atLeastOnce()).calculateLicensedRealms(any(XPackLicenseState.class));
+        verify(realms, atLeastOnce()).getActiveRealms();
+        // ^^ We don't care how many times these methods are called, we just check it here so that we can verify no more interactions below.
         verifyNoMoreInteractions(realms);
     }
 
@@ -447,9 +458,8 @@ public class AuthenticationServiceTests extends ESTestCase {
 
         verify(auditTrail).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest);
         verify(firstRealm, times(2)).name(); // used above one time
-        verify(firstRealm, atLeastOnce()).type();
-        verify(secondRealm, Mockito.atLeast(3)).name(); // also used in license tracking
-        verify(secondRealm, Mockito.atLeast(3)).type(); // used to create realm ref, and license tracking
+        verify(secondRealm, Mockito.atLeast(2)).name(); // also used in license tracking
+        verify(secondRealm, Mockito.atLeast(2)).type(); // used to create realm ref, and license tracking
         verify(firstRealm, times(2)).token(threadContext);
         verify(secondRealm, times(2)).token(threadContext);
         verify(firstRealm).supports(token);
@@ -573,9 +583,8 @@ public class AuthenticationServiceTests extends ESTestCase {
         }, this::logAndFail));
         verify(auditTrail, times(2)).authenticationFailed(reqId.get(), firstRealm.name(), token, "_action", transportRequest);
         verify(firstRealm, times(3)).name(); // used above one time
-        verify(firstRealm, atLeastOnce()).type();
-        verify(secondRealm, Mockito.atLeast(3)).name();
-        verify(secondRealm, Mockito.atLeast(3)).type(); // used to create realm ref
+        verify(secondRealm, Mockito.atLeast(2)).name();
+        verify(secondRealm, Mockito.atLeast(2)).type(); // used to create realm ref
         verify(firstRealm, times(2)).token(threadContext);
         verify(secondRealm, times(2)).token(threadContext);
         verify(firstRealm, times(2)).supports(token);
@@ -638,10 +647,7 @@ public class AuthenticationServiceTests extends ESTestCase {
         assertThat(result.v1(), is(authentication));
         assertThat(result.v1().getAuthenticationType(), is(AuthenticationType.REALM));
         verifyZeroInteractions(auditTrail);
-        verify(firstRealm, atLeastOnce()).type();
-        verify(secondRealm, atLeastOnce()).type();
-        verify(secondRealm, atLeastOnce()).name(); // This realm is license-tracked, which uses the name
-        verifyNoMoreInteractions(firstRealm, secondRealm);
+        verifyZeroInteractions(firstRealm, secondRealm);
         verifyZeroInteractions(operatorPrivilegesService);
     }
 
@@ -920,8 +926,6 @@ public class AuthenticationServiceTests extends ESTestCase {
                     verifyZeroInteractions(operatorPrivilegesService);
                 }, this::logAndFail));
             assertTrue(completed.compareAndSet(true, false));
-            verify(firstRealm, atLeastOnce()).type();
-            verify(firstRealm, atLeastOnce()).name();
             verifyNoMoreInteractions(firstRealm);
             reset(firstRealm);
         } finally {
@@ -970,8 +974,6 @@ public class AuthenticationServiceTests extends ESTestCase {
                     verifyZeroInteractions(operatorPrivilegesService);
                 }, this::logAndFail));
             assertTrue(completed.get());
-            verify(firstRealm, atLeastOnce()).type();
-            verify(firstRealm, atLeastOnce()).name();
             verifyNoMoreInteractions(firstRealm);
         } finally {
             terminate(threadPool2);
@@ -2106,12 +2108,40 @@ public class AuthenticationServiceTests extends ESTestCase {
 
     static class TestRealms extends Realms {
 
-        TestRealms(Settings settings, Environment env, Map<String, Factory> factories, XPackLicenseState licenseState,
-                   ThreadContext threadContext, ReservedRealm reservedRealm, List<Realm> realms, List<Realm> internalRealms)
-                throws Exception {
+        private final List<Realm> allRealms;
+        private final List<Realm> internalRealms;
+
+        TestRealms(
+            Settings settings,
+            Environment env,
+            Map<String, Factory> factories,
+            XPackLicenseState licenseState,
+            ThreadContext threadContext,
+            ReservedRealm reservedRealm,
+            List<Realm> realms,
+            List<Realm> internalRealms
+        ) throws Exception {
             super(settings, env, factories, licenseState, threadContext, reservedRealm);
-            this.realms = realms;
-            this.standardRealmsOnly = internalRealms;
+            this.allRealms = realms;
+            this.internalRealms = internalRealms;
+        }
+
+        @Override
+        protected List<Realm> calculateLicensedRealms(XPackLicenseState licenseState) {
+            if (allRealms == null) {
+                // This can happen because the realms are recalculated during construction
+                return super.calculateLicensedRealms(licenseState);
+            }
+            if (Security.STANDARD_REALMS_FEATURE.checkWithoutTracking(licenseState)) {
+                return allRealms;
+            } else {
+                return internalRealms;
+            }
+        }
+
+        // Make public for testing
+        public void recomputeActiveRealms() {
+            super.recomputeActiveRealms();
         }
     }
 

+ 97 - 1
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/RealmsTests.java

@@ -11,8 +11,11 @@ import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.TestEnvironment;
+import org.elasticsearch.license.License;
+import org.elasticsearch.license.LicenseStateListener;
 import org.elasticsearch.license.MockLicenseState;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
@@ -29,19 +32,24 @@ import org.elasticsearch.xpack.core.security.user.User;
 import org.elasticsearch.xpack.security.Security;
 import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
 import org.junit.Before;
+import org.mockito.Mockito;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Set;
 import java.util.TreeMap;
+import java.util.function.Consumer;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
@@ -53,8 +61,12 @@ import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.iterableWithSize;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.sameInstance;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 public class RealmsTests extends ESTestCase {
@@ -63,6 +75,7 @@ public class RealmsTests extends ESTestCase {
     private ThreadContext threadContext;
     private ReservedRealm reservedRealm;
     private int randomRealmTypesCount;
+    private List<LicenseStateListener> licenseStateListeners;
 
     @Before
     public void init() throws Exception {
@@ -76,7 +89,17 @@ public class RealmsTests extends ESTestCase {
             factories.put(name, config -> new DummyRealm(name, config));
         }
         licenseState = mock(MockLicenseState.class);
+        licenseStateListeners = new ArrayList<>();
         when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState);
+        when(licenseState.getOperationMode()).thenReturn(randomFrom(License.OperationMode.values()));
+        doAnswer(inv -> {
+            assertThat(inv.getArguments(), arrayWithSize(1));
+            Object arg0 = inv.getArguments()[0];
+            assertThat(arg0, instanceOf(LicenseStateListener.class));
+            this.licenseStateListeners.add((LicenseStateListener) arg0);
+            return null;
+        }).when(licenseState).addListener(Mockito.any(LicenseStateListener.class));
+
         threadContext = new ThreadContext(Settings.EMPTY);
         reservedRealm = mock(ReservedRealm.class);
         allowAllRealms();
@@ -87,16 +110,47 @@ public class RealmsTests extends ESTestCase {
     private void allowAllRealms() {
         when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(true);
         when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(true);
+        licenseStateListeners.forEach(LicenseStateListener::licenseStateChanged);
     }
 
     private void allowOnlyStandardRealms() {
         when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(false);
         when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(true);
+        licenseStateListeners.forEach(LicenseStateListener::licenseStateChanged);
     }
 
     private void allowOnlyNativeRealms() {
         when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(false);
         when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(false);
+        licenseStateListeners.forEach(LicenseStateListener::licenseStateChanged);
+    }
+
+    public void testRealmTypeAvailable() {
+        final Set<String> basicRealmTypes = Sets.newHashSet("file", "native", "reserved");
+        final Set<String> goldRealmTypes = Sets.newHashSet("ldap", "active_directory", "pki");
+
+        final Set<String> platinumRealmTypes = new HashSet<>(InternalRealms.getConfigurableRealmsTypes());
+        platinumRealmTypes.addAll(this.factories.keySet());
+        platinumRealmTypes.removeAll(basicRealmTypes);
+        platinumRealmTypes.removeAll(goldRealmTypes);
+
+        Consumer<String> checkAllowed = type -> assertThat("Type: " + type, Realms.isRealmTypeAvailable(licenseState, type), is(true));
+        Consumer<String> checkNotAllowed = type -> assertThat("Type: " + type, Realms.isRealmTypeAvailable(licenseState, type), is(false));
+
+        allowAllRealms();
+        platinumRealmTypes.forEach(checkAllowed);
+        goldRealmTypes.forEach(checkAllowed);
+        basicRealmTypes.forEach(checkAllowed);
+
+        allowOnlyStandardRealms();
+        platinumRealmTypes.forEach(checkNotAllowed);
+        goldRealmTypes.forEach(checkAllowed);
+        basicRealmTypes.forEach(checkAllowed);
+
+        allowOnlyNativeRealms();
+        platinumRealmTypes.forEach(checkNotAllowed);
+        goldRealmTypes.forEach(checkNotAllowed);
+        basicRealmTypes.forEach(checkAllowed);
     }
 
     public void testWithSettings() throws Exception {
@@ -118,6 +172,16 @@ public class RealmsTests extends ESTestCase {
         Settings settings = builder.build();
         Environment env = TestEnvironment.newEnvironment(settings);
         Realms realms = new Realms(settings, env, factories, licenseState, threadContext, reservedRealm);
+        verify(licenseState, times(1)).addListener(Mockito.any(LicenseStateListener.class));
+        verify(licenseState, times(1)).copyCurrentLicenseState();
+        verify(licenseState, times(1)).getOperationMode();
+
+        // Verify that we recorded licensed-feature use for each realm (this is trigger on license load during node startup)
+        verify(licenseState, Mockito.atLeast(randomRealmTypesCount)).isAllowed(Security.ALL_REALMS_FEATURE);
+        for (int i = 0; i < randomRealmTypesCount; i++) {
+            verify(licenseState, atLeastOnce()).enableUsageTracking(Security.ALL_REALMS_FEATURE, "realm_" + i);
+        }
+        verifyNoMoreInteractions(licenseState);
 
         Iterator<Realm> iterator = realms.iterator();
         assertThat(iterator.hasNext(), is(true));
@@ -372,9 +436,23 @@ public class RealmsTests extends ESTestCase {
         assertThat(realm.type(), is(type));
         assertThat(iter.hasNext(), is(false));
         assertThat(realms.getUnlicensedRealms(), empty());
+
+        // during init only
+        verify(licenseState, times(1)).addListener(Mockito.any(LicenseStateListener.class));
+        // each time the license state changes
+        verify(licenseState, times(1)).copyCurrentLicenseState();
+        verify(licenseState, times(1)).getOperationMode();
+
+        // Verify that we recorded licensed-feature use for each licensed realm (this is trigger on license load/change)
+        verify(licenseState, times(1)).isAllowed(Security.STANDARD_REALMS_FEATURE);
         verify(licenseState).enableUsageTracking(Security.STANDARD_REALMS_FEATURE, "foo");
+        verifyNoMoreInteractions(licenseState);
 
         allowOnlyNativeRealms();
+        // because the license state changed ...
+        verify(licenseState, times(2)).copyCurrentLicenseState();
+        verify(licenseState, times(2)).getOperationMode();
+
         iter = realms.iterator();
         assertThat(iter.hasNext(), is(true));
         realm = iter.next();
@@ -389,6 +467,12 @@ public class RealmsTests extends ESTestCase {
         assertThat(realm.type(), is(type));
         assertThat(iter.hasNext(), is(false));
 
+        // Verify that we checked (a 2nd time) the license for the non-basic realm
+        verify(licenseState, times(2)).isAllowed(Security.STANDARD_REALMS_FEATURE);
+        // Verify that we stopped tracking  use for realms which are no longer licensed
+        verify(licenseState).disableUsageTracking(Security.STANDARD_REALMS_FEATURE, "foo");
+        verifyNoMoreInteractions(licenseState);
+
         assertThat(realms.getUnlicensedRealms(), iterableWithSize(1));
         realm = realms.getUnlicensedRealms().get(0);
         assertThat(realm.type(), equalTo("ldap"));
@@ -416,7 +500,8 @@ public class RealmsTests extends ESTestCase {
         realm = iter.next();
         assertThat(realm.type(), is(selectedRealmType));
         assertThat(realms.getUnlicensedRealms(), empty());
-        verify(licenseState).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
+        verify(licenseState, times(1)).isAllowed(Security.ALL_REALMS_FEATURE);
+        verify(licenseState, times(1)).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
 
         allowOnlyStandardRealms();
         iter = realms.iterator();
@@ -430,6 +515,11 @@ public class RealmsTests extends ESTestCase {
         assertThat(realm.type(), equalTo(selectedRealmType));
         assertThat(realm.name(), equalTo(realmName));
 
+        verify(licenseState, times(2)).isAllowed(Security.ALL_REALMS_FEATURE);
+        verify(licenseState, times(1)).disableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
+        // this happened when the realm was allowed. Check it's still only 1 call
+        verify(licenseState, times(1)).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
+
         allowOnlyNativeRealms();
         iter = realms.iterator();
         assertThat(iter.hasNext(), is(true));
@@ -441,6 +531,12 @@ public class RealmsTests extends ESTestCase {
         realm = realms.getUnlicensedRealms().get(0);
         assertThat(realm.type(), equalTo(selectedRealmType));
         assertThat(realm.name(), equalTo(realmName));
+
+        verify(licenseState, times(3)).isAllowed(Security.ALL_REALMS_FEATURE);
+        // this doesn't get called a second time because it didn't change
+        verify(licenseState, times(1)).disableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
+        // this happened when the realm was allowed. Check it's still only 1 call
+        verify(licenseState, times(1)).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
     }
 
     public void testDisabledRealmsAreNotAdded() throws Exception {

+ 1 - 1
x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/SecondaryAuthenticatorTests.java

@@ -101,7 +101,7 @@ public class SecondaryAuthenticatorTests extends ESTestCase {
         final Environment env = TestEnvironment.newEnvironment(settings);
 
         realm = new DummyUsernamePasswordRealm(new RealmConfig(new RealmIdentifier("dummy", "test_realm"), settings, env, threadContext));
-        when(realms.asList()).thenReturn(List.of(realm));
+        when(realms.getActiveRealms()).thenReturn(List.of(realm));
         when(realms.getUnlicensedRealms()).thenReturn(List.of());
 
         final AuditTrailService auditTrail = new AuditTrailService(Collections.emptyList(), null);