Browse Source

[ML] Confirm platinum license for experimental ML aggregations (#89117)

We have 5 experimental aggregations in the ML plugin that
are not confirming a platinum license when they are used.
Before they are made generally available we should plug
this hole.

The 5 aggregations are:

1. change_point
2. bucket_correlation
3. bucket_count_ks_test
4. frequent_items
5. categorize_text

This PR also touches a sixth aggregation, namely inference.
This was already platinum licensed but is given its own
unique feature like the other five aggregations. In the
future this will allow us to tell how popular the different
aggregations are. (The principle of separate features per
licensed aggregation is taken from the spatial plugin.)
David Roberts 3 years ago
parent
commit
80eeca74e4

+ 62 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -53,6 +53,7 @@ import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
 import org.elasticsearch.indices.breaker.BreakerSettings;
 import org.elasticsearch.ingest.Processor;
 import org.elasticsearch.license.License;
+import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.LicensedFeature;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.monitor.jvm.JvmInfo;
@@ -77,6 +78,7 @@ import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.tracing.Tracer;
 import org.elasticsearch.watcher.ResourceWatcherService;
+import org.elasticsearch.xcontent.ContextParser;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderService;
@@ -279,6 +281,7 @@ import org.elasticsearch.xpack.ml.aggs.categorization.CategorizeTextAggregationB
 import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation;
 import org.elasticsearch.xpack.ml.aggs.changepoint.ChangePointAggregationBuilder;
 import org.elasticsearch.xpack.ml.aggs.changepoint.ChangePointNamedContentProvider;
+import org.elasticsearch.xpack.ml.aggs.changepoint.InternalChangePointAggregation;
 import org.elasticsearch.xpack.ml.aggs.correlation.BucketCorrelationAggregationBuilder;
 import org.elasticsearch.xpack.ml.aggs.correlation.CorrelationNamedContentProvider;
 import org.elasticsearch.xpack.ml.aggs.frequentitemsets.FrequentItemSetsAggregationBuilder;
@@ -286,6 +289,7 @@ import org.elasticsearch.xpack.ml.aggs.frequentitemsets.FrequentItemSetsAggregat
 import org.elasticsearch.xpack.ml.aggs.heuristic.PValueScore;
 import org.elasticsearch.xpack.ml.aggs.inference.InferencePipelineAggregationBuilder;
 import org.elasticsearch.xpack.ml.aggs.kstest.BucketCountKSTestAggregationBuilder;
+import org.elasticsearch.xpack.ml.aggs.kstest.InternalKSTestAggregation;
 import org.elasticsearch.xpack.ml.annotations.AnnotationPersister;
 import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingDeciderService;
 import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingNamedWritableProvider;
@@ -501,6 +505,37 @@ public class MachineLearning extends Plugin
         License.OperationMode.PLATINUM
     );
 
+    private static final LicensedFeature.Momentary CATEGORIZE_TEXT_AGG_FEATURE = LicensedFeature.momentary(
+        MachineLearningField.ML_FEATURE_FAMILY,
+        "categorize-text-agg",
+        License.OperationMode.PLATINUM
+    );
+    private static final LicensedFeature.Momentary FREQUENT_ITEM_SETS_AGG_FEATURE = LicensedFeature.momentary(
+        MachineLearningField.ML_FEATURE_FAMILY,
+        "frequent-items-agg",
+        License.OperationMode.PLATINUM
+    );
+    public static final LicensedFeature.Momentary INFERENCE_AGG_FEATURE = LicensedFeature.momentary(
+        MachineLearningField.ML_FEATURE_FAMILY,
+        "inference-agg",
+        License.OperationMode.PLATINUM
+    );
+    private static final LicensedFeature.Momentary CHANGE_POINT_AGG_FEATURE = LicensedFeature.momentary(
+        MachineLearningField.ML_FEATURE_FAMILY,
+        "change-point-agg",
+        License.OperationMode.PLATINUM
+    );
+    private static final LicensedFeature.Momentary BUCKET_CORRELATION_AGG_FEATURE = LicensedFeature.momentary(
+        MachineLearningField.ML_FEATURE_FAMILY,
+        "bucket-correlation-agg",
+        License.OperationMode.PLATINUM
+    );
+    private static final LicensedFeature.Momentary BUCKET_COUNT_KS_TEST_AGG_FEATURE = LicensedFeature.momentary(
+        MachineLearningField.ML_FEATURE_FAMILY,
+        "bucket-count-ks-test-agg",
+        License.OperationMode.PLATINUM
+    );
+
     @Override
     public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
         if (this.enabled == false) {
@@ -1454,11 +1489,23 @@ public class MachineLearning extends Plugin
 
     @Override
     public List<PipelineAggregationSpec> getPipelineAggregations() {
-        return Arrays.asList(
+        return List.of(
             InferencePipelineAggregationBuilder.buildSpec(modelLoadingService, getLicenseState(), settings),
-            BucketCorrelationAggregationBuilder.buildSpec(),
-            BucketCountKSTestAggregationBuilder.buildSpec(),
-            ChangePointAggregationBuilder.buildSpec()
+            new SearchPlugin.PipelineAggregationSpec(
+                BucketCorrelationAggregationBuilder.NAME,
+                BucketCorrelationAggregationBuilder::new,
+                checkAggLicense(BucketCorrelationAggregationBuilder.PARSER, BUCKET_CORRELATION_AGG_FEATURE)
+            ),
+            new SearchPlugin.PipelineAggregationSpec(
+                BucketCountKSTestAggregationBuilder.NAME,
+                BucketCountKSTestAggregationBuilder::new,
+                checkAggLicense(BucketCountKSTestAggregationBuilder.PARSER, BUCKET_COUNT_KS_TEST_AGG_FEATURE)
+            ).addResultReader(InternalKSTestAggregation::new),
+            new SearchPlugin.PipelineAggregationSpec(
+                ChangePointAggregationBuilder.NAME,
+                ChangePointAggregationBuilder::new,
+                checkAggLicense(ChangePointAggregationBuilder.PARSER, CHANGE_POINT_AGG_FEATURE)
+            ).addResultReader(InternalChangePointAggregation::new)
         );
     }
 
@@ -1467,19 +1514,28 @@ public class MachineLearning extends Plugin
         return List.of(new SignificanceHeuristicSpec<>(PValueScore.NAME, PValueScore::new, PValueScore.PARSER));
     }
 
+    private <T> ContextParser<String, T> checkAggLicense(ContextParser<String, T> realParser, LicensedFeature.Momentary feature) {
+        return (parser, name) -> {
+            if (feature.check(getLicenseState()) == false) {
+                throw LicenseUtils.newComplianceException(feature.getName());
+            }
+            return realParser.parse(parser, name);
+        };
+    }
+
     @Override
     public List<AggregationSpec> getAggregations() {
         return List.of(
             new AggregationSpec(
                 CategorizeTextAggregationBuilder.NAME,
                 CategorizeTextAggregationBuilder::new,
-                CategorizeTextAggregationBuilder.PARSER
+                checkAggLicense(CategorizeTextAggregationBuilder.PARSER, CATEGORIZE_TEXT_AGG_FEATURE)
             ).addResultReader(InternalCategorizationAggregation::new)
                 .setAggregatorRegistrar(s -> s.registerUsage(CategorizeTextAggregationBuilder.NAME)),
             new AggregationSpec(
                 FrequentItemSetsAggregationBuilder.NAME,
                 FrequentItemSetsAggregationBuilder::new,
-                FrequentItemSetsAggregationBuilder.PARSER
+                checkAggLicense(FrequentItemSetsAggregationBuilder.PARSER, FREQUENT_ITEM_SETS_AGG_FEATURE)
             ).addResultReader(FrequentItemSetsAggregatorFactory.getResultReader())
                 .setAggregatorRegistrar(FrequentItemSetsAggregationBuilder::registerAggregators)
         );

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -23,7 +23,6 @@ import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.XPackField;
-import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
@@ -31,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
@@ -98,7 +98,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         Response.Builder responseBuilder = Response.builder();
         TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
 
-        if (MachineLearningField.ML_API_FEATURE.check(licenseState)) {
+        if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) {
             responseBuilder.setLicensed(true);
             doInfer(task, request, responseBuilder, parentTaskId, listener);
         } else {

+ 0 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregationBuilder.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.ml.aggs.changepoint;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
 import org.elasticsearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
@@ -53,11 +52,6 @@ public class ChangePointAggregationBuilder extends BucketMetricsPipelineAggregat
         super(in, NAME.getPreferredName());
     }
 
-    public static SearchPlugin.PipelineAggregationSpec buildSpec() {
-        return new SearchPlugin.PipelineAggregationSpec(NAME, ChangePointAggregationBuilder::new, ChangePointAggregationBuilder.PARSER)
-            .addResultReader(InternalChangePointAggregation::new);
-    }
-
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();

+ 0 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilder.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.ml.aggs.correlation;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
 import org.elasticsearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
@@ -59,14 +58,6 @@ public class BucketCorrelationAggregationBuilder extends BucketMetricsPipelineAg
         }, GAP_POLICY, ObjectParser.ValueType.STRING);
     }
 
-    public static SearchPlugin.PipelineAggregationSpec buildSpec() {
-        return new SearchPlugin.PipelineAggregationSpec(
-            NAME,
-            BucketCorrelationAggregationBuilder::new,
-            BucketCorrelationAggregationBuilder.PARSER
-        );
-    }
-
     private final CorrelationFunction correlationFunction;
 
     public BucketCorrelationAggregationBuilder(String name, String bucketsPath, CorrelationFunction correlationFunction) {

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/inference/InferencePipelineAggregationBuilder.java

@@ -30,7 +30,6 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.XPackSettings;
-import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
@@ -42,6 +41,7 @@ import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
 import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesResponse;
 import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
 import org.elasticsearch.xpack.core.security.support.Exceptions;
+import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 
@@ -268,7 +268,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
                 loadedModel.set(localModel);
 
                 boolean isLicensed = localModel.getLicenseLevel() == License.OperationMode.BASIC
-                    || MachineLearningField.ML_API_FEATURE.check(licenseState);
+                    || MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState);
                 if (isLicensed) {
                     delegate.onResponse(null);
                 } else {

+ 0 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/kstest/BucketCountKSTestAggregationBuilder.java

@@ -11,7 +11,6 @@ import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.aggregations.pipeline.BucketHelpers;
 import org.elasticsearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
@@ -124,14 +123,6 @@ public class BucketCountKSTestAggregationBuilder extends BucketMetricsPipelineAg
         this.samplingMethod = SamplingMethod.fromStream(in);
     }
 
-    public static SearchPlugin.PipelineAggregationSpec buildSpec() {
-        return new SearchPlugin.PipelineAggregationSpec(
-            NAME,
-            BucketCountKSTestAggregationBuilder::new,
-            BucketCountKSTestAggregationBuilder.PARSER
-        ).addResultReader(InternalKSTestAggregation::new);
-    }
-
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();

+ 88 - 58
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
 
+import java.io.IOException;
 import java.util.Collections;
 import java.util.Map;
 
@@ -40,7 +41,7 @@ import static org.mockito.Mockito.when;
 public class MachineLearningTests extends ESTestCase {
 
     @SuppressWarnings("unchecked")
-    public void testPrePostSystemIndexUpgrade_givenNotInUpgradeMode() {
+    public void testPrePostSystemIndexUpgrade_givenNotInUpgradeMode() throws IOException {
         ThreadPool threadpool = new TestThreadPool("test");
         ClusterService clusterService = mock(ClusterService.class);
         when(clusterService.state()).thenReturn(ClusterState.EMPTY_STATE);
@@ -52,27 +53,40 @@ public class MachineLearningTests extends ESTestCase {
             return null;
         }).when(client).execute(same(SetUpgradeModeAction.INSTANCE), any(SetUpgradeModeAction.Request.class), any(ActionListener.class));
 
-        MachineLearning machineLearning = createMachineLearning(Settings.EMPTY);
-
-        SetOnce<Map<String, Object>> response = new SetOnce<>();
-        machineLearning.prepareForIndicesMigration(clusterService, client, ActionListener.wrap(response::set, e -> fail(e.getMessage())));
-
-        assertThat(response.get(), equalTo(Collections.singletonMap("already_in_upgrade_mode", false)));
-        verify(client).execute(same(SetUpgradeModeAction.INSTANCE), eq(new SetUpgradeModeAction.Request(true)), any(ActionListener.class));
-
-        machineLearning.indicesMigrationComplete(
-            response.get(),
-            clusterService,
-            client,
-            ActionListener.wrap(ESTestCase::assertTrue, e -> fail(e.getMessage()))
-        );
-
-        verify(client).execute(same(SetUpgradeModeAction.INSTANCE), eq(new SetUpgradeModeAction.Request(false)), any(ActionListener.class));
-
-        threadpool.shutdown();
+        try (MachineLearning machineLearning = createTrialLicensedMachineLearning(Settings.EMPTY)) {
+
+            SetOnce<Map<String, Object>> response = new SetOnce<>();
+            machineLearning.prepareForIndicesMigration(
+                clusterService,
+                client,
+                ActionListener.wrap(response::set, e -> fail(e.getMessage()))
+            );
+
+            assertThat(response.get(), equalTo(Collections.singletonMap("already_in_upgrade_mode", false)));
+            verify(client).execute(
+                same(SetUpgradeModeAction.INSTANCE),
+                eq(new SetUpgradeModeAction.Request(true)),
+                any(ActionListener.class)
+            );
+
+            machineLearning.indicesMigrationComplete(
+                response.get(),
+                clusterService,
+                client,
+                ActionListener.wrap(ESTestCase::assertTrue, e -> fail(e.getMessage()))
+            );
+
+            verify(client).execute(
+                same(SetUpgradeModeAction.INSTANCE),
+                eq(new SetUpgradeModeAction.Request(false)),
+                any(ActionListener.class)
+            );
+        } finally {
+            threadpool.shutdown();
+        }
     }
 
-    public void testPrePostSystemIndexUpgrade_givenAlreadyInUpgradeMode() {
+    public void testPrePostSystemIndexUpgrade_givenAlreadyInUpgradeMode() throws IOException {
         ClusterService clusterService = mock(ClusterService.class);
         when(clusterService.state()).thenReturn(
             ClusterState.builder(ClusterName.DEFAULT)
@@ -81,23 +95,28 @@ public class MachineLearningTests extends ESTestCase {
         );
         Client client = mock(Client.class);
 
-        MachineLearning machineLearning = createMachineLearning(Settings.EMPTY);
+        try (MachineLearning machineLearning = createTrialLicensedMachineLearning(Settings.EMPTY)) {
 
-        SetOnce<Map<String, Object>> response = new SetOnce<>();
-        machineLearning.prepareForIndicesMigration(clusterService, client, ActionListener.wrap(response::set, e -> fail(e.getMessage())));
+            SetOnce<Map<String, Object>> response = new SetOnce<>();
+            machineLearning.prepareForIndicesMigration(
+                clusterService,
+                client,
+                ActionListener.wrap(response::set, e -> fail(e.getMessage()))
+            );
 
-        assertThat(response.get(), equalTo(Collections.singletonMap("already_in_upgrade_mode", true)));
-        verifyNoMoreInteractions(client);
+            assertThat(response.get(), equalTo(Collections.singletonMap("already_in_upgrade_mode", true)));
+            verifyNoMoreInteractions(client);
 
-        machineLearning.indicesMigrationComplete(
-            response.get(),
-            clusterService,
-            client,
-            ActionListener.wrap(ESTestCase::assertTrue, e -> fail(e.getMessage()))
-        );
+            machineLearning.indicesMigrationComplete(
+                response.get(),
+                clusterService,
+                client,
+                ActionListener.wrap(ESTestCase::assertTrue, e -> fail(e.getMessage()))
+            );
 
-        // Neither pre nor post should have called any action
-        verifyNoMoreInteractions(client);
+            // Neither pre nor post should have called any action
+            verifyNoMoreInteractions(client);
+        }
     }
 
     public void testMaxOpenWorkersSetting_givenDefault() {
@@ -141,7 +160,7 @@ public class MachineLearningTests extends ESTestCase {
         );
     }
 
-    public void testNoAttributes_givenNoClash() {
+    public void testNoAttributes_givenNoClash() throws IOException {
         Settings.Builder builder = Settings.builder();
         if (randomBoolean()) {
             builder.put("xpack.ml.enabled", randomBoolean());
@@ -151,11 +170,12 @@ public class MachineLearningTests extends ESTestCase {
         }
         builder.put("node.attr.foo", "abc");
         builder.put("node.attr.ml.bar", "def");
-        MachineLearning machineLearning = createMachineLearning(builder.put("path.home", createTempDir()).build());
-        assertNotNull(machineLearning.additionalSettings());
+        try (MachineLearning machineLearning = createTrialLicensedMachineLearning(builder.put("path.home", createTempDir()).build())) {
+            assertNotNull(machineLearning.additionalSettings());
+        }
     }
 
-    public void testNoAttributes_givenSameAndMlEnabled() {
+    public void testNoAttributes_givenSameAndMlEnabled() throws IOException {
         Settings.Builder builder = Settings.builder();
         if (randomBoolean()) {
             builder.put("xpack.ml.enabled", randomBoolean());
@@ -164,33 +184,43 @@ public class MachineLearningTests extends ESTestCase {
             int maxOpenJobs = randomIntBetween(5, 15);
             builder.put("xpack.ml.max_open_jobs", maxOpenJobs);
         }
-        MachineLearning machineLearning = createMachineLearning(builder.put("path.home", createTempDir()).build());
-        assertNotNull(machineLearning.additionalSettings());
+        try (MachineLearning machineLearning = createTrialLicensedMachineLearning(builder.put("path.home", createTempDir()).build())) {
+            assertNotNull(machineLearning.additionalSettings());
+        }
     }
 
-    public void testNoAttributes_givenClash() {
+    public void testNoAttributes_givenClash() throws IOException {
         Settings.Builder builder = Settings.builder();
         builder.put("node.attr.ml.max_open_jobs", randomIntBetween(13, 15));
-        MachineLearning machineLearning = createMachineLearning(builder.put("path.home", createTempDir()).build());
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, machineLearning::additionalSettings);
-        assertThat(e.getMessage(), startsWith("Directly setting [node.attr.ml."));
-        assertThat(
-            e.getMessage(),
-            containsString(
-                "] is not permitted - "
-                    + "it is reserved for machine learning. If your intention was to customize machine learning, set the [xpack.ml."
-            )
-        );
+        try (MachineLearning machineLearning = createTrialLicensedMachineLearning(builder.put("path.home", createTempDir()).build())) {
+            IllegalArgumentException e = expectThrows(IllegalArgumentException.class, machineLearning::additionalSettings);
+            assertThat(e.getMessage(), startsWith("Directly setting [node.attr.ml."));
+            assertThat(
+                e.getMessage(),
+                containsString(
+                    "] is not permitted - "
+                        + "it is reserved for machine learning. If your intention was to customize machine learning, set the [xpack.ml."
+                )
+            );
+        }
     }
 
-    private MachineLearning createMachineLearning(Settings settings) {
-        XPackLicenseState licenseState = mock(XPackLicenseState.class);
+    public static class TrialLicensedMachineLearning extends MachineLearning {
+
+        // A license state constructed like this is considered a trial license
+        XPackLicenseState licenseState = new XPackLicenseState(() -> 0L);
+
+        public TrialLicensedMachineLearning(Settings settings) {
+            super(settings);
+        }
+
+        @Override
+        protected XPackLicenseState getLicenseState() {
+            return licenseState;
+        }
+    }
 
-        return new MachineLearning(settings) {
-            @Override
-            protected XPackLicenseState getLicenseState() {
-                return licenseState;
-            }
-        };
+    public static MachineLearning createTrialLicensedMachineLearning(Settings settings) {
+        return new TrialLicensedMachineLearning(settings);
     }
 }

+ 3 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/categorization/CategorizeTextAggregationBuilderTests.java

@@ -9,11 +9,11 @@ package org.elasticsearch.xpack.ml.aggs.categorization;
 
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
-import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.MachineLearningTests;
 import org.elasticsearch.xpack.ml.job.config.CategorizationAnalyzerConfigTests;
 
 import java.util.Collection;
-import java.util.Collections;
+import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -21,7 +21,7 @@ public class CategorizeTextAggregationBuilderTests extends BaseAggregationTestCa
 
     @Override
     protected Collection<Class<? extends Plugin>> getExtraPlugins() {
-        return Collections.singletonList(MachineLearning.class);
+        return List.of(MachineLearningTests.TrialLicensedMachineLearning.class);
     }
 
     @Override

+ 2 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/changepoint/ChangePointAggregationBuilderTests.java

@@ -10,15 +10,14 @@ package org.elasticsearch.xpack.ml.aggs.changepoint;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
-import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.MachineLearningTests;
 
-import java.util.Collections;
 import java.util.List;
 
 public class ChangePointAggregationBuilderTests extends BasePipelineAggregationTestCase<ChangePointAggregationBuilder> {
     @Override
     protected List<SearchPlugin> plugins() {
-        return Collections.singletonList(new MachineLearning(Settings.EMPTY));
+        return List.of(MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY));
     }
 
     @Override

+ 2 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/correlation/BucketCorrelationAggregationBuilderTests.java

@@ -16,9 +16,8 @@ import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuil
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.support.ValueType;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
-import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.MachineLearningTests;
 
-import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -31,7 +30,7 @@ public class BucketCorrelationAggregationBuilderTests extends BasePipelineAggreg
 
     @Override
     protected List<SearchPlugin> plugins() {
-        return Collections.singletonList(new MachineLearning(Settings.EMPTY));
+        return List.of(MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY));
     }
 
     @Override

+ 3 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetsTests.java

@@ -9,10 +9,10 @@ package org.elasticsearch.xpack.ml.aggs.frequentitemsets;
 
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
-import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.MachineLearningTests;
 
 import java.util.Collection;
-import java.util.Collections;
+import java.util.List;
 
 import static org.elasticsearch.xpack.ml.aggs.frequentitemsets.FrequentItemSetsAggregationBuilderTests.randomFrequentItemsSetsAggregationBuilder;
 
@@ -20,7 +20,7 @@ public class FrequentItemSetsTests extends BaseAggregationTestCase<FrequentItemS
 
     @Override
     protected Collection<Class<? extends Plugin>> getExtraPlugins() {
-        return Collections.singletonList(MachineLearning.class);
+        return List.of(MachineLearningTests.TrialLicensedMachineLearning.class);
     }
 
     @Override

+ 2 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/kstest/BucketCountKSTestAggregationBuilderTests.java

@@ -15,9 +15,8 @@ import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregationBuil
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.support.ValueType;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.MachineLearningTests;
 
-import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -32,7 +31,7 @@ public class BucketCountKSTestAggregationBuilderTests extends BasePipelineAggreg
 
     @Override
     protected List<SearchPlugin> plugins() {
-        return Collections.singletonList(new MachineLearning(Settings.EMPTY));
+        return List.of(MachineLearningTests.createTrialLicensedMachineLearning(Settings.EMPTY));
     }
 
     @Override