|
@@ -24,6 +24,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
|
|
|
import org.elasticsearch.action.get.GetRequest;
|
|
|
import org.elasticsearch.action.get.GetResponse;
|
|
|
import org.elasticsearch.action.index.IndexRequest;
|
|
|
+import org.elasticsearch.action.ingest.PutPipelineRequest;
|
|
|
import org.elasticsearch.action.support.WriteRequest;
|
|
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
|
|
import org.elasticsearch.action.update.UpdateRequest;
|
|
@@ -77,6 +78,8 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
|
|
|
import org.elasticsearch.client.ml.GetModelSnapshotsResponse;
|
|
|
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
|
|
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
|
|
|
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
|
|
|
+import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
|
|
|
import org.elasticsearch.client.ml.MlInfoRequest;
|
|
|
import org.elasticsearch.client.ml.MlInfoResponse;
|
|
|
import org.elasticsearch.client.ml.OpenJobRequest;
|
|
@@ -148,6 +151,8 @@ import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
|
|
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
|
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
|
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
|
|
+import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
|
|
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
|
|
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
|
|
import org.elasticsearch.client.ml.job.config.DataDescription;
|
|
|
import org.elasticsearch.client.ml.job.config.Detector;
|
|
@@ -157,6 +162,7 @@ import org.elasticsearch.client.ml.job.config.JobUpdate;
|
|
|
import org.elasticsearch.client.ml.job.config.MlFilter;
|
|
|
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
|
|
|
import org.elasticsearch.client.ml.job.stats.JobStats;
|
|
|
+import org.elasticsearch.common.bytes.BytesArray;
|
|
|
import org.elasticsearch.common.bytes.BytesReference;
|
|
|
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
|
|
import org.elasticsearch.common.unit.ByteSizeUnit;
|
|
@@ -2093,6 +2099,67 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testGetTrainedModelsStats() throws Exception {
|
|
|
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
+ String modelIdPrefix = "get-trained-model-stats-";
|
|
|
+ int numberOfModels = 5;
|
|
|
+ for (int i = 0; i < numberOfModels; ++i) {
|
|
|
+ String modelId = modelIdPrefix + i;
|
|
|
+ putTrainedModel(modelId);
|
|
|
+ }
|
|
|
+
|
|
|
+ String regressionPipeline = "{" +
|
|
|
+ " \"processors\": [\n" +
|
|
|
+ " {\n" +
|
|
|
+ " \"inference\": {\n" +
|
|
|
+ " \"target_field\": \"regression_value\",\n" +
|
|
|
+ " \"model_id\": \"" + modelIdPrefix + 0 + "\",\n" +
|
|
|
+ " \"inference_config\": {\"regression\": {}},\n" +
|
|
|
+ " \"field_mappings\": {\n" +
|
|
|
+ " \"col1\": \"col1\",\n" +
|
|
|
+ " \"col2\": \"col2\",\n" +
|
|
|
+ " \"col3\": \"col3\",\n" +
|
|
|
+ " \"col4\": \"col4\"\n" +
|
|
|
+ " }\n" +
|
|
|
+ " }\n" +
|
|
|
+ " }]}\n";
|
|
|
+
|
|
|
+ highLevelClient().ingest().putPipeline(
|
|
|
+ new PutPipelineRequest("regression-stats-pipeline",
|
|
|
+ new BytesArray(regressionPipeline.getBytes(StandardCharsets.UTF_8)),
|
|
|
+ XContentType.JSON),
|
|
|
+ RequestOptions.DEFAULT);
|
|
|
+ {
|
|
|
+ GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
|
|
+ GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(),
|
|
|
+ machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels));
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1));
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
|
|
+ new GetTrainedModelsStatsRequest(modelIdPrefix + 4, modelIdPrefix + 2, modelIdPrefix + 3),
|
|
|
+ machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(3));
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(3L));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
|
|
+ new GetTrainedModelsStatsRequest(modelIdPrefix + "*").setPageParams(new PageParams(1, 2)),
|
|
|
+ machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(2));
|
|
|
+ assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
|
|
|
+ assertThat(
|
|
|
+ getTrainedModelsStatsResponse.getTrainedModelStats()
|
|
|
+ .stream()
|
|
|
+ .map(TrainedModelStats::getModelId)
|
|
|
+ .collect(Collectors.toList()),
|
|
|
+ containsInAnyOrder(modelIdPrefix + 1, modelIdPrefix + 2));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testDeleteTrainedModel() throws Exception {
|
|
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
String modelId = "delete-trained-model-test";
|
|
@@ -2298,7 +2365,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
}
|
|
|
|
|
|
private void putTrainedModel(String modelId) throws IOException {
|
|
|
- TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
|
|
+ TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
|
|
highLevelClient().index(
|
|
|
new IndexRequest(".ml-inference-000001")
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|