|
@@ -5,14 +5,18 @@
|
|
|
*/
|
|
|
package org.elasticsearch.license;
|
|
|
|
|
|
+import org.apache.lucene.util.SetOnce;
|
|
|
import org.elasticsearch.ElasticsearchSecurityException;
|
|
|
+import org.elasticsearch.action.index.IndexRequest;
|
|
|
import org.elasticsearch.action.ingest.PutPipelineAction;
|
|
|
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
|
|
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
|
|
|
import org.elasticsearch.action.ingest.SimulatePipelineAction;
|
|
|
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
|
|
|
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
|
|
|
+import org.elasticsearch.action.search.SearchRequest;
|
|
|
import org.elasticsearch.action.support.PlainActionFuture;
|
|
|
+import org.elasticsearch.action.support.WriteRequest;
|
|
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
|
|
import org.elasticsearch.cluster.ClusterState;
|
|
|
import org.elasticsearch.common.bytes.BytesArray;
|
|
@@ -21,6 +25,8 @@ import org.elasticsearch.common.xcontent.XContentType;
|
|
|
import org.elasticsearch.license.License.OperationMode;
|
|
|
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
|
|
|
+import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
|
|
|
import org.elasticsearch.xpack.core.XPackField;
|
|
|
import org.elasticsearch.xpack.core.ml.MlConfigIndex;
|
|
|
import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
|
|
@@ -46,12 +52,15 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
|
|
|
import org.elasticsearch.xpack.core.ml.job.config.JobState;
|
|
|
+import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
|
|
|
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
|
|
|
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
|
|
|
import org.junit.Before;
|
|
|
|
|
|
import java.nio.charset.StandardCharsets;
|
|
|
-import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.Map;
|
|
|
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.hamcrest.Matchers.empty;
|
|
@@ -140,7 +149,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|
|
assertNotNull(response2);
|
|
|
}
|
|
|
|
|
|
- public void testMachineLearningPutDatafeedActionRestricted() throws Exception {
|
|
|
+ public void testMachineLearningPutDatafeedActionRestricted() {
|
|
|
String jobId = "testmachinelearningputdatafeedactionrestricted";
|
|
|
String datafeedId = jobId + "-datafeed";
|
|
|
assertMLAllowed(true);
|
|
@@ -431,7 +440,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
|
|
|
+ public void testMachineLearningDeleteJobActionNotRestricted() {
|
|
|
String jobId = "testmachinelearningclosejobactionnotrestricted";
|
|
|
assertMLAllowed(true);
|
|
|
// test that license restricted apis do now work
|
|
@@ -449,7 +458,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|
|
listener.actionGet();
|
|
|
}
|
|
|
|
|
|
- public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception {
|
|
|
+ public void testMachineLearningDeleteDatafeedActionNotRestricted() {
|
|
|
String jobId = "testmachinelearningdeletedatafeedactionnotrestricted";
|
|
|
String datafeedId = jobId + "-datafeed";
|
|
|
assertMLAllowed(true);
|
|
@@ -474,7 +483,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|
|
listener.actionGet();
|
|
|
}
|
|
|
|
|
|
- public void testMachineLearningCreateInferenceProcessorRestricted() throws Exception {
|
|
|
+ public void testMachineLearningCreateInferenceProcessorRestricted() {
|
|
|
String modelId = "modelprocessorlicensetest";
|
|
|
assertMLAllowed(true);
|
|
|
putInferenceModel(modelId);
|
|
@@ -606,7 +615,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|
|
.actionGet();
|
|
|
}
|
|
|
|
|
|
- public void testMachineLearningInferModelRestricted() throws Exception {
|
|
|
+ public void testMachineLearningInferModelRestricted() {
|
|
|
String modelId = "modelinfermodellicensetest";
|
|
|
assertMLAllowed(true);
|
|
|
putInferenceModel(modelId);
|
|
@@ -668,6 +677,57 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|
|
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
|
|
|
}
|
|
|
|
|
|
+ public void testInferenceAggRestricted() {
|
|
|
+ String modelId = "inference-agg-restricted";
|
|
|
+ assertMLAllowed(true);
|
|
|
+ putInferenceModel(modelId);
|
|
|
+
|
|
|
+ // index some data
|
|
|
+ String index = "inference-agg-licence-test";
|
|
|
+ client().admin().indices().prepareCreate(index).setMapping("feature1", "type=double", "feature2", "type=keyword").get();
|
|
|
+ client().prepareBulk(index)
|
|
|
+ .add(new IndexRequest().source("feature1", "10.0", "feature2", "foo"))
|
|
|
+ .add(new IndexRequest().source("feature1", "20.0", "feature2", "foo"))
|
|
|
+ .add(new IndexRequest().source("feature1", "20.0", "feature2", "bar"))
|
|
|
+ .add(new IndexRequest().source("feature1", "20.0", "feature2", "bar"))
|
|
|
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
+ .get();
|
|
|
+
|
|
|
+ TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("foobar").field("feature2");
|
|
|
+ AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_feature1").field("feature1");
|
|
|
+ termsAgg.subAggregation(avgAgg);
|
|
|
+
|
|
|
+ XPackLicenseState licenseState = internalCluster().getInstance(XPackLicenseState.class);
|
|
|
+ ModelLoadingService modelLoading = internalCluster().getInstance(ModelLoadingService.class);
|
|
|
+
|
|
|
+ Map<String, String> bucketPaths = new HashMap<>();
|
|
|
+ bucketPaths.put("feature1", "avg_feature1");
|
|
|
+ InferencePipelineAggregationBuilder inferenceAgg =
|
|
|
+ new InferencePipelineAggregationBuilder("infer_agg", new SetOnce<>(modelLoading), licenseState, bucketPaths);
|
|
|
+ inferenceAgg.setModelId(modelId);
|
|
|
+
|
|
|
+ termsAgg.subAggregation(inferenceAgg);
|
|
|
+
|
|
|
+ SearchRequest search = new SearchRequest(index);
|
|
|
+ search.source().aggregation(termsAgg);
|
|
|
+ client().search(search).actionGet();
|
|
|
+
|
|
|
+ // Pick a license that does not allow machine learning
|
|
|
+ License.OperationMode mode = randomInvalidLicenseType();
|
|
|
+ enableLicensing(mode);
|
|
|
+ assertMLAllowed(false);
|
|
|
+
|
|
|
+ // inferring against a model should now fail
|
|
|
+ SearchRequest invalidSearch = new SearchRequest(index);
|
|
|
+ invalidSearch.source().aggregation(termsAgg);
|
|
|
+ ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
|
|
|
+ () -> client().search(invalidSearch).actionGet());
|
|
|
+
|
|
|
+ assertThat(e.status(), is(RestStatus.FORBIDDEN));
|
|
|
+ assertThat(e.getMessage(), containsString("current license is non-compliant for [ml]"));
|
|
|
+ assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
|
|
|
+ }
|
|
|
+
|
|
|
private void putInferenceModel(String modelId) {
|
|
|
TrainedModelConfig config = TrainedModelConfig.builder()
|
|
|
.setParsedDefinition(
|
|
@@ -675,13 +735,13 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
|
|
|
.setTrainedModel(
|
|
|
Tree.builder()
|
|
|
.setTargetType(TargetType.REGRESSION)
|
|
|
- .setFeatureNames(Arrays.asList("feature1"))
|
|
|
+ .setFeatureNames(Collections.singletonList("feature1"))
|
|
|
.setNodes(TreeNode.builder(0).setLeafValue(1.0))
|
|
|
.build())
|
|
|
.setPreProcessors(Collections.emptyList()))
|
|
|
.setModelId(modelId)
|
|
|
.setDescription("test model for classification")
|
|
|
- .setInput(new TrainedModelInput(Arrays.asList("feature1")))
|
|
|
+ .setInput(new TrainedModelInput(Collections.singletonList("feature1")))
|
|
|
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
|
|
|
.build();
|
|
|
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
|