Browse Source

Add ml licence check to the pipeline inference agg. (#59213)

Ensures the licence is sufficient for the model used in inference
David Kyle 5 years ago
parent
commit
3202f46e3b

+ 68 - 8
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java

@@ -5,14 +5,18 @@
  */
 package org.elasticsearch.license;
 
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ElasticsearchSecurityException;
+import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.ingest.PutPipelineAction;
 import org.elasticsearch.action.ingest.PutPipelineRequest;
 import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
 import org.elasticsearch.action.ingest.SimulatePipelineAction;
 import org.elasticsearch.action.ingest.SimulatePipelineRequest;
 import org.elasticsearch.action.ingest.SimulatePipelineResponse;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.common.bytes.BytesArray;
@@ -21,6 +25,8 @@ import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.license.License.OperationMode;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
 import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.ml.MlConfigIndex;
 import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
@@ -46,12 +52,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
+import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
 import org.junit.Before;
 
 import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
@@ -140,7 +149,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         assertNotNull(response2);
     }
 
-    public void testMachineLearningPutDatafeedActionRestricted() throws Exception {
+    public void testMachineLearningPutDatafeedActionRestricted() {
         String jobId = "testmachinelearningputdatafeedactionrestricted";
         String datafeedId = jobId + "-datafeed";
         assertMLAllowed(true);
@@ -431,7 +440,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         }
     }
 
-    public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
+    public void testMachineLearningDeleteJobActionNotRestricted() {
         String jobId = "testmachinelearningclosejobactionnotrestricted";
         assertMLAllowed(true);
         // test that license restricted apis do now work
@@ -449,7 +458,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         listener.actionGet();
     }
 
-    public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception {
+    public void testMachineLearningDeleteDatafeedActionNotRestricted() {
         String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
         String datafeedId = jobId + "-datafeed";
         assertMLAllowed(true);
@@ -474,7 +483,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         listener.actionGet();
     }
 
-    public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception {
+    public void testMachineLearningCreateInferenceProcessorRestricted() {
         String modelId = "modelprocessorlicensetest";
         assertMLAllowed(true);
         putInferenceModel(modelId);
@@ -606,7 +615,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
             .actionGet();
     }
 
-    public void testMachineLearningInferModelRestricted() throws Exception {
+    public void testMachineLearningInferModelRestricted() {
         String modelId = "modelinfermodellicensetest";
         assertMLAllowed(true);
         putInferenceModel(modelId);
@@ -668,6 +677,57 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
     }
 
+    public void testInferenceAggRestricted() {
+        String modelId = "inference-agg-restricted";
+        assertMLAllowed(true);
+        putInferenceModel(modelId);
+
+        // index some data
+        String index = "inference-agg-licence-test";
+        client().admin().indices().prepareCreate(index).setMapping("feature1", "type=double", "feature2", "type=keyword").get();
+        client().prepareBulk(index)
+            .add(new IndexRequest().source("feature1", "10.0", "feature2", "foo"))
+            .add(new IndexRequest().source("feature1", "20.0", "feature2", "foo"))
+            .add(new IndexRequest().source("feature1", "20.0", "feature2", "bar"))
+            .add(new IndexRequest().source("feature1", "20.0", "feature2", "bar"))
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            .get();
+
+        TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2");
+        AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1");
+        termsAgg.subAggregation(avgAgg);
+
+        XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class);
+        ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class);
+
+        Map<String, String> bucketPaths = new HashMap<>();
+        bucketPaths.put("feature1", "avg_feature1");
+        InferencePipelineAggregationBuilder inferenceAgg =
+            new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths);
+        inferenceAgg.setModelId(modelId);
+
+        termsAgg.subAggregation(inferenceAgg);
+
+        SearchRequest search = new SearchRequest(index);
+        search.source().aggregation(termsAgg);
+        client().search(search).actionGet();
+
+        // Pick a license that does not allow machine learning
+        License.OperationMode mode = randomInvalidLicenseType();
+        enableLicensing(mode);
+        assertMLAllowed(false);
+
+        // inferring against a model should now fail
+        SearchRequest invalidSearch = new SearchRequest(index);
+        invalidSearch.source().aggregation(termsAgg);
+        ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
+            () -> client().search(invalidSearch).actionGet());
+
+        assertThat(e.status(), is(RestStatus.FORBIDDEN));
+        assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]"));
+        assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
+    }
+
     private void putInferenceModel(String modelId) {
         TrainedModelConfig config = TrainedModelConfig.builder()
             .setParsedDefinition(
@@ -675,13 +735,13 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
             .setTrainedModel(
             Tree.builder()
                 .setTargetType(TargetType.REGRESSION)
-                .setFeatureNames(Arrays.asList("feature1"))
+                .setFeatureNames(Collections.singletonList("feature1"))
                 .setNodes(TreeNode.builder(0).setLeafValue(1.0))
                 .build())
             .setPreProcessors(Collections.emptyList()))
             .setModelId(modelId)
             .setDescription("test model for classification")
-            .setInput(new TrainedModelInput(Arrays.asList("feature1")))
+            .setInput(new TrainedModelInput(Collections.singletonList("feature1")))
             .setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
             .build();
         client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();

+ 2 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -988,10 +988,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
     @Override
     public List<PipelineAggregationSpec> getPipelineAggregations() {
         PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
-            in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
+            in -> new InferencePipelineAggregationBuilder(in, getLicenseState(), modelLoadingService),
             (ContextParser<String, ? extends PipelineAggregationBuilder>)
-                (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
-                ));
+                (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, getLicenseState(), name, parser));
         spec.addResultReader(InternalInferenceAggregation::new);
 
         return Collections.singletonList(spec);

+ 44 - 15
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java

@@ -10,15 +10,17 @@ import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
@@ -44,10 +46,10 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
     static final String AGGREGATIONS_RESULTS_FIELD = "value";
 
     @SuppressWarnings("unchecked")
-    private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
-        Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
-        NAME, false,
-        (args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
+    private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, ParserSupplement> PARSER =
+        new ConstructingObjectParser<>(NAME, false,
+        (args, context) -> new InferencePipelineAggregationBuilder(context.name, context.modelLoadingService,
+            context.licenseState, (Map<String, String>) args[0])
     );
 
     static {
@@ -60,34 +62,52 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
     private final Map<String, String> bucketPathMap;
     private String modelId;
     private InferenceConfigUpdate inferenceConfig;
+    private final XPackLicenseState licenseState;
     private final SetOnce<ModelLoadingService> modelLoadingService;
     /**
      * The model. Set to a non-null value during the rewrite phase.
      */
     private final Supplier<LocalModel> model;
 
+    private static class ParserSupplement {
+        final XPackLicenseState licenseState;
+        final SetOnce<ModelLoadingService> modelLoadingService;
+        final String name;
+
+        ParserSupplement(String name, XPackLicenseState licenseState, SetOnce<ModelLoadingService> modelLoadingService) {
+            this.name = name;
+            this.licenseState = licenseState;
+            this.modelLoadingService = modelLoadingService;
+        }
+    }
     public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
+                                                            XPackLicenseState licenseState,
                                                             String pipelineAggregatorName,
                                                             XContentParser parser) {
-        Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
-        return PARSER.apply(parser, context);
+        return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, modelLoadingService));
     }
 
-    public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
+    public InferencePipelineAggregationBuilder(String name,
+                                               SetOnce<ModelLoadingService> modelLoadingService,
+                                               XPackLicenseState licenseState,
                                                Map<String, String> bucketsPath) {
         super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
         this.modelLoadingService = modelLoadingService;
         this.bucketPathMap = bucketsPath;
         this.model = null;
+        this.licenseState = licenseState;
     }
 
-    public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
+    public InferencePipelineAggregationBuilder(StreamInput in,
+                                               XPackLicenseState licenseState,
+                                               SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
         super(in, NAME);
         modelId = in.readString();
         bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
         inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
         this.modelLoadingService = modelLoadingService;
         this.model = null;
+        this.licenseState = licenseState;
     }
 
     /**
@@ -98,7 +118,8 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
         Map<String, String> bucketsPath,
         Supplier<LocalModel> model,
         String modelId,
-        InferenceConfigUpdate inferenceConfig
+        InferenceConfigUpdate inferenceConfig,
+        XPackLicenseState licenseState
     ) {
         super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
         modelLoadingService = null;
@@ -113,13 +134,14 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
          */
         this.modelId = modelId;
         this.inferenceConfig = inferenceConfig;
+        this.licenseState = licenseState;
     }
 
-    void setModelId(String modelId) {
+    public void setModelId(String modelId) {
         this.modelId = modelId;
     }
 
-    void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
+    public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
         this.inferenceConfig = inferenceConfig;
     }
 
@@ -160,7 +182,7 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
     }
 
     @Override
-    public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) throws IOException {
+    public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) {
         if (model != null) {
             return this;
         }
@@ -168,10 +190,17 @@ public class InferencePipelineAggregationBuilder extends AbstractPipelineAggrega
         context.registerAsyncAction((client, listener) -> {
             modelLoadingService.get().getModelForSearch(modelId, ActionListener.delegateFailure(listener, (delegate, model) -> {
                 loadedModel.set(model);
-                delegate.onResponse(null);
+
+                boolean isLicensed = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING) ||
+                    licenseState.isAllowedByLicense(model.getLicenseLevel());
+                if (isLicensed) {
+                    delegate.onResponse(null);
+                } else {
+                    delegate.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
+                }
             }));
         });
-        return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig);
+        return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState);
     }
 
     @Override

+ 8 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.inference.loadingservice;
 
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.license.License;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -38,6 +39,7 @@ public class LocalModel {
     private volatile long persistenceQuotient = 100;
     private final LongAdder currentInferenceCount;
     private final InferenceConfig inferenceConfig;
+    private final License.OperationMode licenseLevel;
 
     public LocalModel(String modelId,
                       String nodeId,
@@ -45,6 +47,7 @@ public class LocalModel {
                       TrainedModelInput input,
                       Map<String, String> defaultFieldMap,
                       InferenceConfig modelInferenceConfig,
+                      License.OperationMode licenseLevel,
                       TrainedModelStatsService trainedModelStatsService) {
         this.trainedModelDefinition = trainedModelDefinition;
         this.modelId = modelId;
@@ -56,6 +59,7 @@ public class LocalModel {
         this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
         this.currentInferenceCount = new LongAdder();
         this.inferenceConfig = modelInferenceConfig;
+        this.licenseLevel = licenseLevel;
     }
 
     long ramBytesUsed() {
@@ -66,6 +70,10 @@ public class LocalModel {
         return modelId;
     }
 
+    public License.OperationMode getLicenseLevel() {
+        return licenseLevel;
+    }
+
     public InferenceStats getLatestStatsAndReset() {
         return statsAccumulator.currentStatsAndReset();
     }

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

@@ -309,6 +309,7 @@ public class ModelLoadingService implements ClusterStateListener {
                             trainedModelConfig.getInput(),
                             trainedModelConfig.getDefaultFieldMap(),
                             inferenceConfig,
+                            trainedModelConfig.getLicenseLevel(),
                             modelStatsService));
                     },
                     // Failure getting the definition, remove the initial estimation value
@@ -337,6 +338,7 @@ public class ModelLoadingService implements ClusterStateListener {
             trainedModelConfig.getInput(),
             trainedModelConfig.getDefaultFieldMap(),
             inferenceConfig,
+            trainedModelConfig.getLicenseLevel(),
             modelStatsService);
         synchronized (loadingListeners) {
             listeners = loadingListeners.remove(modelId);

+ 3 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java

@@ -10,6 +10,7 @@ import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
@@ -61,7 +62,8 @@ public class InferencePipelineAggregationBuilderTests extends BasePipelineAggreg
             .collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
 
         InferencePipelineAggregationBuilder builder =
-            new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths);
+            new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)),
+                mock(XPackLicenseState.class), bucketPaths);
         builder.setModelId(randomAlphaOfLength(6));
 
         if (randomBoolean()) {

+ 7 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice;
 
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.ingest.IngestDocument;
+import org.elasticsearch.license.License;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
@@ -73,6 +74,7 @@ public class LocalModelTests extends ESTestCase {
             new TrainedModelInput(inputFields),
             Collections.singletonMap("field.foo", "field.foo.keyword"),
             ClassificationConfig.EMPTY_PARAMS,
+            randomFrom(License.OperationMode.values()),
             modelStatsService);
         Map<String, Object> fields = new HashMap<>() {{
             put("field.foo", 1.0);
@@ -102,6 +104,7 @@ public class LocalModelTests extends ESTestCase {
             new TrainedModelInput(inputFields),
             Collections.singletonMap("field.foo", "field.foo.keyword"),
             ClassificationConfig.EMPTY_PARAMS,
+            License.OperationMode.PLATINUM,
             modelStatsService);
         result = getSingleValue(model, fields, ClassificationConfigUpdate.EMPTY_PARAMS);
         assertThat(result.value(), equalTo(0.0));
@@ -144,6 +147,7 @@ public class LocalModelTests extends ESTestCase {
             new TrainedModelInput(inputFields),
             Collections.singletonMap("field.foo", "field.foo.keyword"),
             ClassificationConfig.EMPTY_PARAMS,
+            License.OperationMode.PLATINUM,
             modelStatsService);
         Map<String, Object> fields = new HashMap<>() {{
             put("field.foo", 1.0);
@@ -199,6 +203,7 @@ public class LocalModelTests extends ESTestCase {
             new TrainedModelInput(inputFields),
             Collections.singletonMap("bar", "bar.keyword"),
             RegressionConfig.EMPTY_PARAMS,
+            License.OperationMode.PLATINUM,
             modelStatsService);
 
         Map<String, Object> fields = new HashMap<>() {{
@@ -226,6 +231,7 @@ public class LocalModelTests extends ESTestCase {
             new TrainedModelInput(inputFields),
             null,
             RegressionConfig.EMPTY_PARAMS,
+            License.OperationMode.PLATINUM,
             modelStatsService);
 
         Map<String, Object> fields = new HashMap<>() {{
@@ -256,6 +262,7 @@ public class LocalModelTests extends ESTestCase {
             new TrainedModelInput(inputFields),
             null,
             ClassificationConfig.EMPTY_PARAMS,
+            License.OperationMode.PLATINUM,
             modelStatsService
         );
         Map<String, Object> fields = new HashMap<>() {{