|
@@ -114,6 +114,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
|
|
|
import org.elasticsearch.client.ml.PutFilterResponse;
|
|
|
import org.elasticsearch.client.ml.PutJobRequest;
|
|
|
import org.elasticsearch.client.ml.PutJobResponse;
|
|
|
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
|
|
|
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
|
|
|
import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
|
|
|
import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
|
|
|
import org.elasticsearch.client.ml.SetUpgradeModeRequest;
|
|
@@ -162,10 +164,14 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
|
|
|
import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
|
|
|
import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
|
|
|
import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
|
|
+import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
|
|
|
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
|
|
|
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
|
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
|
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
|
|
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
|
|
|
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
|
|
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
|
|
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
|
|
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
|
|
|
import org.elasticsearch.client.ml.job.config.DataDescription;
|
|
@@ -186,12 +192,11 @@ import org.elasticsearch.client.ml.job.results.Influencer;
|
|
|
import org.elasticsearch.client.ml.job.results.OverallBucket;
|
|
|
import org.elasticsearch.client.ml.job.stats.JobStats;
|
|
|
import org.elasticsearch.common.bytes.BytesReference;
|
|
|
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
|
|
import org.elasticsearch.common.unit.ByteSizeUnit;
|
|
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
|
|
import org.elasticsearch.common.unit.TimeValue;
|
|
|
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
|
|
-import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.common.xcontent.XContentType;
|
|
|
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilders;
|
|
@@ -202,12 +207,10 @@ import org.elasticsearch.tasks.TaskId;
|
|
|
import org.junit.After;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.io.OutputStream;
|
|
|
import java.nio.charset.StandardCharsets;
|
|
|
import java.nio.file.Files;
|
|
|
import java.nio.file.Path;
|
|
|
import java.util.Arrays;
|
|
|
-import java.util.Base64;
|
|
|
import java.util.Collections;
|
|
|
import java.util.Date;
|
|
|
import java.util.HashMap;
|
|
@@ -216,7 +219,6 @@ import java.util.Map;
|
|
|
import java.util.concurrent.CountDownLatch;
|
|
|
import java.util.concurrent.TimeUnit;
|
|
|
import java.util.stream.Collectors;
|
|
|
-import java.util.zip.GZIPOutputStream;
|
|
|
|
|
|
import static org.hamcrest.Matchers.allOf;
|
|
|
import static org.hamcrest.Matchers.closeTo;
|
|
@@ -3625,6 +3627,79 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testPutTrainedModel() throws Exception {
|
|
|
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
|
|
+ // tag::put-trained-model-config
|
|
|
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
|
|
|
+ .setDefinition(definition) // <1>
|
|
|
+ .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2>
|
|
|
+ .setModelId("my-new-trained-model") // <3>
|
|
|
+ .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4>
|
|
|
+ .setDescription("test model") // <5>
|
|
|
+ .setMetadata(new HashMap<>()) // <6>
|
|
|
+ .setTags("my_regression_models") // <7>
|
|
|
+ .build();
|
|
|
+ // end::put-trained-model-config
|
|
|
+
|
|
|
+ trainedModelConfig = TrainedModelConfig.builder()
|
|
|
+ .setDefinition(definition)
|
|
|
+ .setModelId("my-new-trained-model")
|
|
|
+ .setInput(new TrainedModelInput("col1", "col2", "col3", "col4"))
|
|
|
+ .setDescription("test model")
|
|
|
+ .setMetadata(new HashMap<>())
|
|
|
+ .setTags("my_regression_models")
|
|
|
+ .build();
|
|
|
+
|
|
|
+ RestHighLevelClient client = highLevelClient();
|
|
|
+ {
|
|
|
+ // tag::put-trained-model-request
|
|
|
+ PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); // <1>
|
|
|
+ // end::put-trained-model-request
|
|
|
+
|
|
|
+ // tag::put-trained-model-execute
|
|
|
+ PutTrainedModelResponse response = client.machineLearning().putTrainedModel(request, RequestOptions.DEFAULT);
|
|
|
+ // end::put-trained-model-execute
|
|
|
+
|
|
|
+ // tag::put-trained-model-response
|
|
|
+ TrainedModelConfig model = response.getResponse();
|
|
|
+ // end::put-trained-model-response
|
|
|
+
|
|
|
+ assertThat(model.getModelId(), equalTo(trainedModelConfig.getModelId()));
|
|
|
+ highLevelClient().machineLearning()
|
|
|
+ .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
|
|
|
+ }
|
|
|
+ {
|
|
|
+ PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
|
|
|
+
|
|
|
+ // tag::put-trained-model-execute-listener
|
|
|
+ ActionListener<PutTrainedModelResponse> listener = new ActionListener<>() {
|
|
|
+ @Override
|
|
|
+ public void onResponse(PutTrainedModelResponse response) {
|
|
|
+ // <1>
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onFailure(Exception e) {
|
|
|
+ // <2>
|
|
|
+ }
|
|
|
+ };
|
|
|
+ // end::put-trained-model-execute-listener
|
|
|
+
|
|
|
+ // Replace the empty listener by a blocking listener in test
|
|
|
+ CountDownLatch latch = new CountDownLatch(1);
|
|
|
+ listener = new LatchedActionListener<>(listener, latch);
|
|
|
+
|
|
|
+ // tag::put-trained-model-execute-async
|
|
|
+ client.machineLearning().putTrainedModelAsync(request, RequestOptions.DEFAULT, listener); // <1>
|
|
|
+ // end::put-trained-model-execute-async
|
|
|
+
|
|
|
+ assertTrue(latch.await(30L, TimeUnit.SECONDS));
|
|
|
+
|
|
|
+ highLevelClient().machineLearning()
|
|
|
+ .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testGetTrainedModelsStats() throws Exception {
|
|
|
putTrainedModel("my-trained-model");
|
|
|
RestHighLevelClient client = highLevelClient();
|
|
@@ -4088,57 +4163,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|
|
}
|
|
|
|
|
|
private void putTrainedModel(String modelId) throws IOException {
|
|
|
- TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
|
|
- highLevelClient().index(
|
|
|
- new IndexRequest(".ml-inference-000001")
|
|
|
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
- .source(modelConfigString(modelId), XContentType.JSON)
|
|
|
- .id(modelId),
|
|
|
- RequestOptions.DEFAULT);
|
|
|
-
|
|
|
- highLevelClient().index(
|
|
|
- new IndexRequest(".ml-inference-000001")
|
|
|
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
- .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
|
|
|
- .id("trained_model_definition_doc-" + modelId + "-0"),
|
|
|
- RequestOptions.DEFAULT);
|
|
|
- }
|
|
|
-
|
|
|
- private String compressDefinition(TrainedModelDefinition definition) throws IOException {
|
|
|
- BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
|
|
|
- BytesStreamOutput out = new BytesStreamOutput();
|
|
|
- try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
|
|
|
- reference.writeTo(compressedOutput);
|
|
|
- }
|
|
|
- return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
|
|
|
- }
|
|
|
-
|
|
|
- private static String modelConfigString(String modelId) {
|
|
|
- return "{\n" +
|
|
|
- " \"doc_type\": \"trained_model_config\",\n" +
|
|
|
- " \"model_id\": \"" + modelId + "\",\n" +
|
|
|
- " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
|
|
|
- " \"description\": \"test model for\",\n" +
|
|
|
- " \"version\": \"7.6.0\",\n" +
|
|
|
- " \"license_level\": \"platinum\",\n" +
|
|
|
- " \"created_by\": \"ml_test\",\n" +
|
|
|
- " \"estimated_heap_memory_usage_bytes\": 0," +
|
|
|
- " \"estimated_operations\": 0," +
|
|
|
- " \"created_time\": 0\n" +
|
|
|
- "}";
|
|
|
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
|
|
+ TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
|
|
|
+ .setDefinition(definition)
|
|
|
+ .setModelId(modelId)
|
|
|
+ .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
|
|
|
+ .setDescription("test model")
|
|
|
+ .build();
|
|
|
+ highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
|
|
|
}
|
|
|
|
|
|
- private static String modelDocString(String compressedDefinition, String modelId) {
|
|
|
- return "" +
|
|
|
- "{" +
|
|
|
- "\"model_id\": \"" + modelId + "\",\n" +
|
|
|
- "\"doc_num\": 0,\n" +
|
|
|
- "\"doc_type\": \"trained_model_definition_doc\",\n" +
|
|
|
- " \"compression_version\": " + 1 + ",\n" +
|
|
|
- " \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
|
|
|
- " \"definition_length\": " + compressedDefinition.length() + ",\n" +
|
|
|
- "\"definition\": \"" + compressedDefinition + "\"\n" +
|
|
|
- "}";
|
|
|
+ @Override
|
|
|
+ protected NamedXContentRegistry xContentRegistry() {
|
|
|
+ return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
|
|
}
|
|
|
|
|
|
private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG =
|