瀏覽代碼

Consistent response for starting basic license (#86272)

The start basic license action returns an inconsistent response when a
basic license exists but a new basic license with different maxNodes
or expiryDate fields is started. In this case, the new license is
created but the response indicates that it was not. This PR makes the
response consistent.

Closes #86244
Nikolaj Volgushev 3 年之前
父節點
當前提交
76d6cfc278

+ 6 - 0
docs/changelog/86272.yaml

@@ -0,0 +1,6 @@
+pr: 86272
+summary: Consistent response for starting basic license
+area: License
+type: bug
+issues:
+ - 86244

+ 4 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/license/StartBasicClusterTask.java

@@ -92,9 +92,10 @@ public class StartBasicClusterTask implements ClusterStateTaskListener {
         } else {
         } else {
             updatedLicensesMetadata = currentLicensesMetadata;
             updatedLicensesMetadata = currentLicensesMetadata;
         }
         }
-        final var responseStatus = currentLicense != null && License.LicenseType.isBasic(currentLicense.type())
-            ? PostStartBasicResponse.Status.ALREADY_USING_BASIC
-            : PostStartBasicResponse.Status.GENERATED_BASIC;
+        final var newLicenseGenerated = updatedLicensesMetadata != currentLicensesMetadata;
+        final var responseStatus = newLicenseGenerated
+            ? PostStartBasicResponse.Status.GENERATED_BASIC
+            : PostStartBasicResponse.Status.ALREADY_USING_BASIC;
         taskContext.success(listener.delegateFailure((l, s) -> l.onResponse(new PostStartBasicResponse(responseStatus))));
         taskContext.success(listener.delegateFailure((l, s) -> l.onResponse(new PostStartBasicResponse(responseStatus))));
         return updatedLicensesMetadata;
         return updatedLicensesMetadata;
     }
     }

+ 101 - 7
x-pack/plugin/core/src/test/java/org/elasticsearch/license/LicenseServiceTests.java

@@ -6,9 +6,12 @@
  */
  */
 package org.elasticsearch.license;
 package org.elasticsearch.license;
 
 
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
 import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
+import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterStateUpdateTask;
 import org.elasticsearch.cluster.ClusterStateUpdateTask;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -29,8 +32,12 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.XContentType;
 import org.hamcrest.Matchers;
 import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
 import org.mockito.ArgumentCaptor;
 import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
 
 
 import java.io.IOException;
 import java.io.IOException;
 import java.nio.file.Path;
 import java.nio.file.Path;
@@ -52,9 +59,11 @@ import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 import static org.hamcrest.Matchers.nullValue;
 import static org.hamcrest.Matchers.startsWith;
 import static org.hamcrest.Matchers.startsWith;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 
 /**
 /**
  * Due to changes in JDK9 where locale data is used from CLDR, the licence message will differ in jdk 8 and jdk9+
  * Due to changes in JDK9 where locale data is used from CLDR, the licence message will differ in jdk 8 and jdk9+
@@ -63,6 +72,22 @@ import static org.mockito.Mockito.verify;
  */
  */
 public class LicenseServiceTests extends ESTestCase {
 public class LicenseServiceTests extends ESTestCase {
 
 
+    // must use member mock for generic
+    @Mock
+    private ClusterStateTaskExecutor.TaskContext<StartBasicClusterTask> taskContext;
+
+    private AutoCloseable closeable;
+
+    @Before
+    public void init() {
+        closeable = MockitoAnnotations.openMocks(this);
+    }
+
+    @After
+    public void releaseMocks() throws Exception {
+        closeable.close();
+    }
+
     public void testLogExpirationWarning() {
     public void testLogExpirationWarning() {
         long time = LocalDate.of(2018, 11, 15).atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli();
         long time = LocalDate.of(2018, 11, 15).atStartOfDay(ZoneOffset.UTC).toInstant().toEpochMilli();
         final boolean expired = randomBoolean();
         final boolean expired = randomBoolean();
@@ -141,6 +166,70 @@ public class LicenseServiceTests extends ESTestCase {
         });
         });
     }
     }
 
 
+    public void testStartBasicStartsNewLicenseIfFieldsDifferent() throws Exception {
+        final Settings settings = Settings.builder()
+            .put("path.home", createTempDir())
+            .put(DISCOVERY_TYPE_SETTING.getKey(), SINGLE_NODE_DISCOVERY_TYPE) // So we skip TLS checks
+            .build();
+
+        final ClusterService clusterService = mockDefaultClusterService();
+        final Clock clock = randomBoolean() ? Clock.systemUTC() : Clock.systemDefaultZone();
+        final LicenseService service = new LicenseService(
+            settings,
+            mock(ThreadPool.class),
+            clusterService,
+            clock,
+            TestEnvironment.newEnvironment(settings),
+            mock(ResourceWatcherService.class),
+            mock(XPackLicenseState.class)
+        );
+
+        final Consumer<PlainActionFuture<PostStartBasicResponse>> assertion = future -> {
+            PostStartBasicResponse response = future.actionGet();
+            assertThat(response.getStatus(), equalTo(PostStartBasicResponse.Status.GENERATED_BASIC));
+        };
+        final PlainActionFuture<PostStartBasicResponse> future = new PlainActionFuture<>();
+        service.startBasicLicense(new PostStartBasicRequest(), future);
+
+        if (future.isDone()) {
+            // If validation failed, the future might be done without calling the updater task.
+            assertion.accept(future);
+        } else {
+            final var taskCaptor = ArgumentCaptor.forClass(StartBasicClusterTask.class);
+            final var taskExecutorCaptor = ArgumentCaptor.forClass(StartBasicClusterTask.Executor.class);
+            @SuppressWarnings("unchecked")
+            final ArgumentCaptor<ActionListener<ClusterState>> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
+            doNothing().when(taskContext).success(listenerCaptor.capture());
+            verify(clusterService).submitStateUpdateTask(any(), taskCaptor.capture(), any(), taskExecutorCaptor.capture());
+            when(taskContext.getTask()).thenReturn(taskCaptor.getValue());
+
+            int maxNodes = randomValueOtherThan(
+                LicenseService.SELF_GENERATED_LICENSE_MAX_NODES,
+                () -> randomIntBetween(1, LicenseService.SELF_GENERATED_LICENSE_MAX_NODES)
+            );
+            License oldLicense = sign(buildLicense(License.LicenseType.BASIC, TimeValue.timeValueDays(randomIntBetween(1, 100)), maxNodes));
+            ClusterState oldState = ClusterState.EMPTY_STATE.copyAndUpdateMetadata(
+                m -> m.putCustom(LicensesMetadata.TYPE, new LicensesMetadata(oldLicense, null))
+            );
+
+            ClusterState updatedState = taskExecutorCaptor.getValue().execute(oldState, List.of(taskContext));
+            // Pass updated state to listener to trigger onResponse call to wrapped `future`
+            listenerCaptor.getValue().onResponse(updatedState);
+            assertion.accept(future);
+        }
+    }
+
+    private ClusterService mockDefaultClusterService() {
+        final ClusterState clusterState = mock(ClusterState.class);
+        Mockito.when(clusterState.metadata()).thenReturn(Metadata.EMPTY_METADATA);
+        Mockito.when(clusterState.getClusterName()).thenReturn(ClusterName.DEFAULT);
+
+        final ClusterService clusterService = mock(ClusterService.class);
+        Mockito.when(clusterService.state()).thenReturn(clusterState);
+        Mockito.when(clusterService.getClusterName()).thenReturn(ClusterName.DEFAULT);
+        return clusterService;
+    }
+
     private void assertRegisterValidLicense(Settings baseSettings, License.LicenseType licenseType) throws IOException {
     private void assertRegisterValidLicense(Settings baseSettings, License.LicenseType licenseType) throws IOException {
         tryRegisterLicense(baseSettings, licenseType, future -> assertThat(future.actionGet().status(), equalTo(LicensesStatus.VALID)));
         tryRegisterLicense(baseSettings, licenseType, future -> assertThat(future.actionGet().status(), equalTo(LicensesStatus.VALID)));
     }
     }
@@ -173,12 +262,7 @@ public class LicenseServiceTests extends ESTestCase {
             .put(DISCOVERY_TYPE_SETTING.getKey(), SINGLE_NODE_DISCOVERY_TYPE) // So we skip TLS checks
             .put(DISCOVERY_TYPE_SETTING.getKey(), SINGLE_NODE_DISCOVERY_TYPE) // So we skip TLS checks
             .build();
             .build();
 
 
-        final ClusterState clusterState = mock(ClusterState.class);
-        Mockito.when(clusterState.metadata()).thenReturn(Metadata.EMPTY_METADATA);
-
-        final ClusterService clusterService = mock(ClusterService.class);
-        Mockito.when(clusterService.state()).thenReturn(clusterState);
-
+        final ClusterService clusterService = mockDefaultClusterService();
         final Clock clock = randomBoolean() ? Clock.systemUTC() : Clock.systemDefaultZone();
         final Clock clock = randomBoolean() ? Clock.systemUTC() : Clock.systemDefaultZone();
         final Environment env = TestEnvironment.newEnvironment(settings);
         final Environment env = TestEnvironment.newEnvironment(settings);
         final ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class);
         final ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class);
@@ -233,11 +317,21 @@ public class LicenseServiceTests extends ESTestCase {
         return signer.sign(license);
         return signer.sign(license);
     }
     }
 
 
+    private License buildLicense(License.LicenseType type, TimeValue expires, int maxNodes) {
+        return buildLicense(new UUID(randomLong(), randomLong()), type, expires.millis());
+    }
+
     private License buildLicense(License.LicenseType type, TimeValue expires) {
     private License buildLicense(License.LicenseType type, TimeValue expires) {
         return buildLicense(new UUID(randomLong(), randomLong()), type, expires.millis());
         return buildLicense(new UUID(randomLong(), randomLong()), type, expires.millis());
     }
     }
 
 
     private License buildLicense(UUID licenseId, License.LicenseType type, long expires) {
     private License buildLicense(UUID licenseId, License.LicenseType type, long expires) {
+        int maxNodes = type == License.LicenseType.ENTERPRISE ? -1 : randomIntBetween(1, 500);
+        return buildLicense(licenseId, type, expires, maxNodes);
+    }
+
+    private License buildLicense(UUID licenseId, License.LicenseType type, long expires, int maxNodes) {
+        assert (type == License.LicenseType.ENTERPRISE && maxNodes != -1) == false : "enterprise license must have unlimited nodes";
         return License.builder()
         return License.builder()
             .uid(licenseId.toString())
             .uid(licenseId.toString())
             .type(type)
             .type(type)
@@ -245,7 +339,7 @@ public class LicenseServiceTests extends ESTestCase {
             .issuer(randomAlphaOfLengthBetween(5, 60))
             .issuer(randomAlphaOfLengthBetween(5, 60))
             .issuedTo(randomAlphaOfLengthBetween(5, 60))
             .issuedTo(randomAlphaOfLengthBetween(5, 60))
             .issueDate(System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(randomLongBetween(1, 5000)))
             .issueDate(System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(randomLongBetween(1, 5000)))
-            .maxNodes(type == License.LicenseType.ENTERPRISE ? -1 : randomIntBetween(1, 500))
+            .maxNodes(maxNodes)
             .maxResourceUnits(type == License.LicenseType.ENTERPRISE ? randomIntBetween(10, 500) : -1)
             .maxResourceUnits(type == License.LicenseType.ENTERPRISE ? randomIntBetween(10, 500) : -1)
             .signature(null)
             .signature(null)
             .build();
             .build();