|
@@ -11,8 +11,11 @@ import org.elasticsearch.action.support.PlainActionFuture;
|
|
|
import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
|
|
+import org.elasticsearch.common.util.set.Sets;
|
|
|
import org.elasticsearch.env.Environment;
|
|
|
import org.elasticsearch.env.TestEnvironment;
|
|
|
+import org.elasticsearch.license.License;
|
|
|
+import org.elasticsearch.license.LicenseStateListener;
|
|
|
import org.elasticsearch.license.MockLicenseState;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.xpack.core.security.authc.AuthenticationResult;
|
|
@@ -29,19 +32,24 @@ import org.elasticsearch.xpack.core.security.user.User;
|
|
|
import org.elasticsearch.xpack.security.Security;
|
|
|
import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
|
|
|
import org.junit.Before;
|
|
|
+import org.mockito.Mockito;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
+import java.util.HashSet;
|
|
|
import java.util.Iterator;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Map.Entry;
|
|
|
+import java.util.Set;
|
|
|
import java.util.TreeMap;
|
|
|
+import java.util.function.Consumer;
|
|
|
import java.util.stream.Collectors;
|
|
|
import java.util.stream.IntStream;
|
|
|
|
|
|
+import static org.hamcrest.Matchers.arrayWithSize;
|
|
|
import static org.hamcrest.Matchers.contains;
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.hamcrest.Matchers.empty;
|
|
@@ -53,8 +61,12 @@ import static org.hamcrest.Matchers.is;
|
|
|
import static org.hamcrest.Matchers.iterableWithSize;
|
|
|
import static org.hamcrest.Matchers.notNullValue;
|
|
|
import static org.hamcrest.Matchers.sameInstance;
|
|
|
+import static org.mockito.Mockito.atLeastOnce;
|
|
|
+import static org.mockito.Mockito.doAnswer;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.times;
|
|
|
import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyNoMoreInteractions;
|
|
|
import static org.mockito.Mockito.when;
|
|
|
|
|
|
public class RealmsTests extends ESTestCase {
|
|
@@ -63,6 +75,7 @@ public class RealmsTests extends ESTestCase {
|
|
|
private ThreadContext threadContext;
|
|
|
private ReservedRealm reservedRealm;
|
|
|
private int randomRealmTypesCount;
|
|
|
+ private List<LicenseStateListener> licenseStateListeners;
|
|
|
|
|
|
@Before
|
|
|
public void init() throws Exception {
|
|
@@ -76,7 +89,17 @@ public class RealmsTests extends ESTestCase {
|
|
|
factories.put(name, config -> new DummyRealm(name, config));
|
|
|
}
|
|
|
licenseState = mock(MockLicenseState.class);
|
|
|
+ licenseStateListeners = new ArrayList<>();
|
|
|
when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState);
|
|
|
+ when(licenseState.getOperationMode()).thenReturn(randomFrom(License.OperationMode.values()));
|
|
|
+ doAnswer(inv -> {
|
|
|
+ assertThat(inv.getArguments(), arrayWithSize(1));
|
|
|
+ Object arg0 = inv.getArguments()[0];
|
|
|
+ assertThat(arg0, instanceOf(LicenseStateListener.class));
|
|
|
+ this.licenseStateListeners.add((LicenseStateListener) arg0);
|
|
|
+ return null;
|
|
|
+ }).when(licenseState).addListener(Mockito.any(LicenseStateListener.class));
|
|
|
+
|
|
|
threadContext = new ThreadContext(Settings.EMPTY);
|
|
|
reservedRealm = mock(ReservedRealm.class);
|
|
|
allowAllRealms();
|
|
@@ -87,16 +110,47 @@ public class RealmsTests extends ESTestCase {
|
|
|
private void allowAllRealms() {
|
|
|
when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(true);
|
|
|
when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(true);
|
|
|
+ licenseStateListeners.forEach(LicenseStateListener::licenseStateChanged);
|
|
|
}
|
|
|
|
|
|
private void allowOnlyStandardRealms() {
|
|
|
when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(false);
|
|
|
when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(true);
|
|
|
+ licenseStateListeners.forEach(LicenseStateListener::licenseStateChanged);
|
|
|
}
|
|
|
|
|
|
private void allowOnlyNativeRealms() {
|
|
|
when(licenseState.isAllowed(Security.ALL_REALMS_FEATURE)).thenReturn(false);
|
|
|
when(licenseState.isAllowed(Security.STANDARD_REALMS_FEATURE)).thenReturn(false);
|
|
|
+ licenseStateListeners.forEach(LicenseStateListener::licenseStateChanged);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testRealmTypeAvailable() {
|
|
|
+ final Set<String> basicRealmTypes = Sets.newHashSet("file", "native", "reserved");
|
|
|
+ final Set<String> goldRealmTypes = Sets.newHashSet("ldap", "active_directory", "pki");
|
|
|
+
|
|
|
+ final Set<String> platinumRealmTypes = new HashSet<>(InternalRealms.getConfigurableRealmsTypes());
|
|
|
+ platinumRealmTypes.addAll(this.factories.keySet());
|
|
|
+ platinumRealmTypes.removeAll(basicRealmTypes);
|
|
|
+ platinumRealmTypes.removeAll(goldRealmTypes);
|
|
|
+
|
|
|
+ Consumer<String> checkAllowed = type -> assertThat("Type: " + type, Realms.isRealmTypeAvailable(licenseState, type), is(true));
|
|
|
+ Consumer<String> checkNotAllowed = type -> assertThat("Type: " + type, Realms.isRealmTypeAvailable(licenseState, type), is(false));
|
|
|
+
|
|
|
+ allowAllRealms();
|
|
|
+ platinumRealmTypes.forEach(checkAllowed);
|
|
|
+ goldRealmTypes.forEach(checkAllowed);
|
|
|
+ basicRealmTypes.forEach(checkAllowed);
|
|
|
+
|
|
|
+ allowOnlyStandardRealms();
|
|
|
+ platinumRealmTypes.forEach(checkNotAllowed);
|
|
|
+ goldRealmTypes.forEach(checkAllowed);
|
|
|
+ basicRealmTypes.forEach(checkAllowed);
|
|
|
+
|
|
|
+ allowOnlyNativeRealms();
|
|
|
+ platinumRealmTypes.forEach(checkNotAllowed);
|
|
|
+ goldRealmTypes.forEach(checkNotAllowed);
|
|
|
+ basicRealmTypes.forEach(checkAllowed);
|
|
|
}
|
|
|
|
|
|
public void testWithSettings() throws Exception {
|
|
@@ -118,6 +172,16 @@ public class RealmsTests extends ESTestCase {
|
|
|
Settings settings = builder.build();
|
|
|
Environment env = TestEnvironment.newEnvironment(settings);
|
|
|
Realms realms = new Realms(settings, env, factories, licenseState, threadContext, reservedRealm);
|
|
|
+ verify(licenseState, times(1)).addListener(Mockito.any(LicenseStateListener.class));
|
|
|
+ verify(licenseState, times(1)).copyCurrentLicenseState();
|
|
|
+ verify(licenseState, times(1)).getOperationMode();
|
|
|
+
|
|
|
+ // Verify that we recorded licensed-feature use for each realm (this is trigger on license load during node startup)
|
|
|
+ verify(licenseState, Mockito.atLeast(randomRealmTypesCount)).isAllowed(Security.ALL_REALMS_FEATURE);
|
|
|
+ for (int i = 0; i < randomRealmTypesCount; i++) {
|
|
|
+ verify(licenseState, atLeastOnce()).enableUsageTracking(Security.ALL_REALMS_FEATURE, "realm_" + i);
|
|
|
+ }
|
|
|
+ verifyNoMoreInteractions(licenseState);
|
|
|
|
|
|
Iterator<Realm> iterator = realms.iterator();
|
|
|
assertThat(iterator.hasNext(), is(true));
|
|
@@ -372,9 +436,23 @@ public class RealmsTests extends ESTestCase {
|
|
|
assertThat(realm.type(), is(type));
|
|
|
assertThat(iter.hasNext(), is(false));
|
|
|
assertThat(realms.getUnlicensedRealms(), empty());
|
|
|
+
|
|
|
+ // during init only
|
|
|
+ verify(licenseState, times(1)).addListener(Mockito.any(LicenseStateListener.class));
|
|
|
+ // each time the license state changes
|
|
|
+ verify(licenseState, times(1)).copyCurrentLicenseState();
|
|
|
+ verify(licenseState, times(1)).getOperationMode();
|
|
|
+
|
|
|
+ // Verify that we recorded licensed-feature use for each licensed realm (this is trigger on license load/change)
|
|
|
+ verify(licenseState, times(1)).isAllowed(Security.STANDARD_REALMS_FEATURE);
|
|
|
verify(licenseState).enableUsageTracking(Security.STANDARD_REALMS_FEATURE, "foo");
|
|
|
+ verifyNoMoreInteractions(licenseState);
|
|
|
|
|
|
allowOnlyNativeRealms();
|
|
|
+ // because the license state changed ...
|
|
|
+ verify(licenseState, times(2)).copyCurrentLicenseState();
|
|
|
+ verify(licenseState, times(2)).getOperationMode();
|
|
|
+
|
|
|
iter = realms.iterator();
|
|
|
assertThat(iter.hasNext(), is(true));
|
|
|
realm = iter.next();
|
|
@@ -389,6 +467,12 @@ public class RealmsTests extends ESTestCase {
|
|
|
assertThat(realm.type(), is(type));
|
|
|
assertThat(iter.hasNext(), is(false));
|
|
|
|
|
|
+ // Verify that we checked (a 2nd time) the license for the non-basic realm
|
|
|
+ verify(licenseState, times(2)).isAllowed(Security.STANDARD_REALMS_FEATURE);
|
|
|
+ // Verify that we stopped tracking use for realms which are no longer licensed
|
|
|
+ verify(licenseState).disableUsageTracking(Security.STANDARD_REALMS_FEATURE, "foo");
|
|
|
+ verifyNoMoreInteractions(licenseState);
|
|
|
+
|
|
|
assertThat(realms.getUnlicensedRealms(), iterableWithSize(1));
|
|
|
realm = realms.getUnlicensedRealms().get(0);
|
|
|
assertThat(realm.type(), equalTo("ldap"));
|
|
@@ -416,7 +500,8 @@ public class RealmsTests extends ESTestCase {
|
|
|
realm = iter.next();
|
|
|
assertThat(realm.type(), is(selectedRealmType));
|
|
|
assertThat(realms.getUnlicensedRealms(), empty());
|
|
|
- verify(licenseState).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
|
|
|
+ verify(licenseState, times(1)).isAllowed(Security.ALL_REALMS_FEATURE);
|
|
|
+ verify(licenseState, times(1)).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
|
|
|
|
|
|
allowOnlyStandardRealms();
|
|
|
iter = realms.iterator();
|
|
@@ -430,6 +515,11 @@ public class RealmsTests extends ESTestCase {
|
|
|
assertThat(realm.type(), equalTo(selectedRealmType));
|
|
|
assertThat(realm.name(), equalTo(realmName));
|
|
|
|
|
|
+ verify(licenseState, times(2)).isAllowed(Security.ALL_REALMS_FEATURE);
|
|
|
+ verify(licenseState, times(1)).disableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
|
|
|
+ // this happened when the realm was allowed. Check it's still only 1 call
|
|
|
+ verify(licenseState, times(1)).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
|
|
|
+
|
|
|
allowOnlyNativeRealms();
|
|
|
iter = realms.iterator();
|
|
|
assertThat(iter.hasNext(), is(true));
|
|
@@ -441,6 +531,12 @@ public class RealmsTests extends ESTestCase {
|
|
|
realm = realms.getUnlicensedRealms().get(0);
|
|
|
assertThat(realm.type(), equalTo(selectedRealmType));
|
|
|
assertThat(realm.name(), equalTo(realmName));
|
|
|
+
|
|
|
+ verify(licenseState, times(3)).isAllowed(Security.ALL_REALMS_FEATURE);
|
|
|
+ // this doesn't get called a second time because it didn't change
|
|
|
+ verify(licenseState, times(1)).disableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
|
|
|
+ // this happened when the realm was allowed. Check it's still only 1 call
|
|
|
+ verify(licenseState, times(1)).enableUsageTracking(Security.ALL_REALMS_FEATURE, realmName);
|
|
|
}
|
|
|
|
|
|
public void testDisabledRealmsAreNotAdded() throws Exception {
|