浏览代码

[ML][Inference][HLRC] add GET _stats (#49562)

Benjamin Trent 5 年之前
父节点
当前提交
ba914453be
共有 15 个文件被更改,包括 722 次插入8 次删除
  1. 26 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
  2. 45 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
  3. 103 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsStatsRequest.java
  4. 86 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsStatsResponse.java
  5. 123 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelStats.java
  6. 21 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
  7. 68 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  8. 55 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  9. 39 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsStatsRequestTests.java
  10. 7 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java
  11. 96 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelStatsTests.java
  12. 6 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
  13. 3 3
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
  14. 42 0
      docs/java-rest/high-level/ml/get-trained-models-stats.asciidoc
  15. 2 0
      docs/java-rest/high-level/supported-apis.asciidoc

+ 26 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

@@ -61,6 +61,7 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
 import org.elasticsearch.client.ml.GetOverallBucketsRequest;
 import org.elasticsearch.client.ml.GetRecordsRequest;
 import org.elasticsearch.client.ml.GetTrainedModelsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
 import org.elasticsearch.client.ml.MlInfoRequest;
 import org.elasticsearch.client.ml.OpenJobRequest;
 import org.elasticsearch.client.ml.PostCalendarEventRequest;
@@ -749,6 +750,31 @@ final class MLRequestConverters {
         return request;
     }
 
+    static Request getTrainedModelsStats(GetTrainedModelsStatsRequest getTrainedModelsStatsRequest) {
+        String endpoint = new EndpointBuilder()
+            .addPathPartAsIs("_ml", "inference")
+            .addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsStatsRequest.getIds()))
+            .addPathPart("_stats")
+            .build();
+        RequestConverters.Params params = new RequestConverters.Params();
+        if (getTrainedModelsStatsRequest.getPageParams() != null) {
+            PageParams pageParams = getTrainedModelsStatsRequest.getPageParams();
+            if (pageParams.getFrom() != null) {
+                params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString());
+            }
+            if (pageParams.getSize() != null) {
+                params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString());
+            }
+        }
+        if (getTrainedModelsStatsRequest.getAllowNoMatch() != null) {
+            params.putParam(GetTrainedModelsStatsRequest.ALLOW_NO_MATCH,
+                Boolean.toString(getTrainedModelsStatsRequest.getAllowNoMatch()));
+        }
+        Request request = new Request(HttpGet.METHOD_NAME, endpoint);
+        request.addParameters(params.asMap());
+        return request;
+    }
+
     static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) {
         String endpoint = new EndpointBuilder()
             .addPathPartAsIs("_ml", "inference")

+ 45 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java

@@ -77,6 +77,8 @@ import org.elasticsearch.client.ml.GetRecordsRequest;
 import org.elasticsearch.client.ml.GetRecordsResponse;
 import org.elasticsearch.client.ml.GetTrainedModelsRequest;
 import org.elasticsearch.client.ml.GetTrainedModelsResponse;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
 import org.elasticsearch.client.ml.MlInfoRequest;
 import org.elasticsearch.client.ml.MlInfoResponse;
 import org.elasticsearch.client.ml.OpenJobRequest;
@@ -2338,6 +2340,49 @@ public final class MachineLearningClient {
             Collections.emptySet());
     }
 
+    /**
+     * Gets trained model stats
+     * <p>
+     * For additional info
+     * see <a href="TODO">
+     *     GET Trained Model Stats documentation</a>
+     *
+     * @param request The {@link GetTrainedModelsStatsRequest}
+     * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+     * @return {@link GetTrainedModelsStatsResponse} response object
+     */
+    public GetTrainedModelsStatsResponse getTrainedModelsStats(GetTrainedModelsStatsRequest request,
+                                                               RequestOptions options) throws IOException {
+        return restHighLevelClient.performRequestAndParseEntity(request,
+            MLRequestConverters::getTrainedModelsStats,
+            options,
+            GetTrainedModelsStatsResponse::fromXContent,
+            Collections.emptySet());
+    }
+
+    /**
+     * Gets trained model stats asynchronously and notifies listener upon completion
+     * <p>
+     * For additional info
+     * see <a href="TODO">
+     *     GET Trained Model Stats documentation</a>
+     *
+     * @param request The {@link GetTrainedModelsStatsRequest}
+     * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+     * @param listener Listener to be notified upon request completion
+     * @return cancellable that may be used to cancel the request
+     */
+    public Cancellable getTrainedModelsStatsAsync(GetTrainedModelsStatsRequest request,
+                                                  RequestOptions options,
+                                                  ActionListener<GetTrainedModelsStatsResponse> listener) {
+        return restHighLevelClient.performRequestAsyncAndParseEntity(request,
+            MLRequestConverters::getTrainedModelsStats,
+            options,
+            GetTrainedModelsStatsResponse::fromXContent,
+            listener,
+            Collections.emptySet());
+    }
+
     /**
      * Deletes the given Trained Model
      * <p>

+ 103 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsStatsRequest.java

@@ -0,0 +1,103 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.client.core.PageParams;
+import org.elasticsearch.common.Nullable;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+public class GetTrainedModelsStatsRequest implements Validatable {
+
+    public static final String ALLOW_NO_MATCH = "allow_no_match";
+
+    private final List<String> ids;
+    private Boolean allowNoMatch;
+    private PageParams pageParams;
+
+    /**
+     * Helper method to create a request that will get ALL TrainedModelStats
+     * @return new {@link GetTrainedModelsStatsRequest} object for the id "_all"
+     */
+    public static GetTrainedModelsStatsRequest getAllTrainedModelStatsRequest() {
+        return new GetTrainedModelsStatsRequest("_all");
+    }
+
+    public GetTrainedModelsStatsRequest(String... ids) {
+        this.ids = Arrays.asList(ids);
+    }
+
+    public List<String> getIds() {
+        return ids;
+    }
+
+    public Boolean getAllowNoMatch() {
+        return allowNoMatch;
+    }
+
+    /**
+     * Whether to ignore if a wildcard expression matches no trained models.
+     *
+     * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all})
+     *                    does not match any trained models
+     */
+    public GetTrainedModelsStatsRequest setAllowNoMatch(boolean allowNoMatch) {
+        this.allowNoMatch = allowNoMatch;
+        return this;
+    }
+
+    public PageParams getPageParams() {
+        return pageParams;
+    }
+
+    public GetTrainedModelsStatsRequest setPageParams(@Nullable PageParams pageParams) {
+        this.pageParams = pageParams;
+        return this;
+    }
+
+    @Override
+    public Optional<ValidationException> validate() {
+        if (ids == null || ids.isEmpty()) {
+            return Optional.of(ValidationException.withError("trained model id must not be null"));
+        }
+        return Optional.empty();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        GetTrainedModelsStatsRequest other = (GetTrainedModelsStatsRequest) o;
+        return Objects.equals(ids, other.ids)
+            && Objects.equals(allowNoMatch, other.allowNoMatch)
+            && Objects.equals(pageParams, other.pageParams);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(ids, allowNoMatch, pageParams);
+    }
+}

+ 86 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsStatsResponse.java

@@ -0,0 +1,86 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.TrainedModelStats;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class GetTrainedModelsStatsResponse {
+
+    public static final ParseField TRAINED_MODEL_STATS = new ParseField("trained_model_stats");
+    public static final ParseField COUNT = new ParseField("count");
+
+    @SuppressWarnings("unchecked")
+    static final ConstructingObjectParser<GetTrainedModelsStatsResponse, Void> PARSER =
+        new ConstructingObjectParser<>(
+            "get_trained_model_stats",
+            true,
+            args -> new GetTrainedModelsStatsResponse((List<TrainedModelStats>) args[0], (Long) args[1]));
+
+    static {
+        PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelStats.fromXContent(p), TRAINED_MODEL_STATS);
+        PARSER.declareLong(constructorArg(), COUNT);
+    }
+
+    public static GetTrainedModelsStatsResponse fromXContent(final XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final List<TrainedModelStats> trainedModelStats;
+    private final Long count;
+
+
+    public GetTrainedModelsStatsResponse(List<TrainedModelStats> trainedModelStats, Long count) {
+        this.trainedModelStats = trainedModelStats;
+        this.count = count;
+    }
+
+    public List<TrainedModelStats> getTrainedModelStats() {
+        return trainedModelStats;
+    }
+
+    /**
+     * @return The total count of the trained models that matched the ID pattern.
+     */
+    public Long getCount() {
+        return count;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        GetTrainedModelsStatsResponse other = (GetTrainedModelsStatsResponse) o;
+        return Objects.equals(this.trainedModelStats, other.trainedModelStats) && Objects.equals(this.count, other.count);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(trainedModelStats, count);
+    }
+}

+ 123 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelStats.java

@@ -0,0 +1,123 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.ingest.IngestStats;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+public class TrainedModelStats implements ToXContentObject {
+
+    public static final ParseField MODEL_ID = new ParseField("model_id");
+    public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
+    public static final ParseField INGEST_STATS = new ParseField("ingest");
+
+    private final String modelId;
+    private final Map<String, Object> ingestStats;
+    private final int pipelineCount;
+
+    @SuppressWarnings("unchecked")
+    static final ConstructingObjectParser<TrainedModelStats, Void> PARSER =
+        new ConstructingObjectParser<>(
+            "trained_model_stats",
+            true,
+            args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2]));
+
+    static {
+        PARSER.declareString(constructorArg(), MODEL_ID);
+        PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), INGEST_STATS);
+        PARSER.declareInt(constructorArg(), PIPELINE_COUNT);
+    }
+
+    public static TrainedModelStats fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount) {
+        this.modelId = modelId;
+        this.ingestStats = ingestStats;
+        this.pipelineCount = pipelineCount;
+    }
+
+    /**
+     * The model id for which the stats apply
+     */
+    public String getModelId() {
+        return modelId;
+    }
+
+    /**
+     * Ingest level statistics. See {@link IngestStats#toXContent(XContentBuilder, Params)} for fields and format
+     * If there are no ingest pipelines referencing the model, then the ingest statistics could be null.
+     */
+    @Nullable
+    public Map<String, Object> getIngestStats() {
+        return ingestStats;
+    }
+
+    /**
+     * The total number of pipelines that reference the trained model
+     */
+    public int getPipelineCount() {
+        return pipelineCount;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(MODEL_ID.getPreferredName(), modelId);
+        builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount);
+        if (ingestStats != null) {
+            builder.field(INGEST_STATS.getPreferredName(), ingestStats);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelId, ingestStats, pipelineCount);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (obj == null) {
+            return false;
+        }
+        if (getClass() != obj.getClass()) {
+            return false;
+        }
+        TrainedModelStats other = (TrainedModelStats) obj;
+        return Objects.equals(this.modelId, other.modelId)
+            && Objects.equals(this.ingestStats, other.ingestStats)
+            && Objects.equals(this.pipelineCount, other.pipelineCount);
+    }
+
+}

+ 21 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

@@ -59,6 +59,7 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
 import org.elasticsearch.client.ml.GetOverallBucketsRequest;
 import org.elasticsearch.client.ml.GetRecordsRequest;
 import org.elasticsearch.client.ml.GetTrainedModelsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
 import org.elasticsearch.client.ml.MlInfoRequest;
 import org.elasticsearch.client.ml.OpenJobRequest;
 import org.elasticsearch.client.ml.PostCalendarEventRequest;
@@ -825,7 +826,6 @@ public class MLRequestConvertersTests extends ESTestCase {
         Request request = MLRequestConverters.getTrainedModels(getRequest);
         assertEquals(HttpGet.METHOD_NAME, request.getMethod());
         assertEquals("/_ml/inference/" + modelId1 + "," + modelId2 + "," + modelId3, request.getEndpoint());
-        assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false")));
         assertThat(request.getParameters(),
             allOf(
                 hasEntry("from", "100"),
@@ -837,6 +837,26 @@ public class MLRequestConvertersTests extends ESTestCase {
         assertNull(request.getEntity());
     }
 
+    public void testGetTrainedModelsStats() {
+        String modelId1 = randomAlphaOfLength(10);
+        String modelId2 = randomAlphaOfLength(10);
+        String modelId3 = randomAlphaOfLength(10);
+        GetTrainedModelsStatsRequest getRequest = new GetTrainedModelsStatsRequest(modelId1, modelId2, modelId3)
+            .setAllowNoMatch(false)
+            .setPageParams(new PageParams(100, 300));
+
+        Request request = MLRequestConverters.getTrainedModelsStats(getRequest);
+        assertEquals(HttpGet.METHOD_NAME, request.getMethod());
+        assertEquals("/_ml/inference/" + modelId1 + "," + modelId2 + "," + modelId3 + "/_stats", request.getEndpoint());
+        assertThat(request.getParameters(),
+            allOf(
+                hasEntry("from", "100"),
+                hasEntry("size", "300"),
+                hasEntry("allow_no_match", "false")
+            ));
+        assertNull(request.getEntity());
+    }
+
     public void testDeleteTrainedModel() {
         DeleteTrainedModelRequest deleteRequest = new DeleteTrainedModelRequest(randomAlphaOfLength(10));
         Request request = MLRequestConverters.deleteTrainedModel(deleteRequest);

+ 68 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -24,6 +24,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.get.GetRequest;
 import org.elasticsearch.action.get.GetResponse;
 import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.ingest.PutPipelineRequest;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.update.UpdateRequest;
@@ -77,6 +78,8 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
 import org.elasticsearch.client.ml.GetModelSnapshotsResponse;
 import org.elasticsearch.client.ml.GetTrainedModelsRequest;
 import org.elasticsearch.client.ml.GetTrainedModelsResponse;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
 import org.elasticsearch.client.ml.MlInfoRequest;
 import org.elasticsearch.client.ml.MlInfoResponse;
 import org.elasticsearch.client.ml.OpenJobRequest;
@@ -148,6 +151,8 @@ import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
 import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelStats;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.client.ml.job.config.AnalysisConfig;
 import org.elasticsearch.client.ml.job.config.DataDescription;
 import org.elasticsearch.client.ml.job.config.Detector;
@@ -157,6 +162,7 @@ import org.elasticsearch.client.ml.job.config.JobUpdate;
 import org.elasticsearch.client.ml.job.config.MlFilter;
 import org.elasticsearch.client.ml.job.process.ModelSnapshot;
 import org.elasticsearch.client.ml.job.stats.JobStats;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.unit.ByteSizeUnit;
@@ -2093,6 +2099,67 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         }
     }
 
+    public void testGetTrainedModelsStats() throws Exception {
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+        String modelIdPrefix = "get-trained-model-stats-";
+        int numberOfModels = 5;
+        for (int i = 0; i < numberOfModels; ++i) {
+            String modelId = modelIdPrefix + i;
+            putTrainedModel(modelId);
+        }
+
+        String regressionPipeline = "{" +
+            "    \"processors\": [\n" +
+            "      {\n" +
+            "        \"inference\": {\n" +
+            "          \"target_field\": \"regression_value\",\n" +
+            "          \"model_id\": \"" + modelIdPrefix + 0 + "\",\n" +
+            "          \"inference_config\": {\"regression\": {}},\n" +
+            "          \"field_mappings\": {\n" +
+            "            \"col1\": \"col1\",\n" +
+            "            \"col2\": \"col2\",\n" +
+            "            \"col3\": \"col3\",\n" +
+            "            \"col4\": \"col4\"\n" +
+            "          }\n" +
+            "        }\n" +
+            "      }]}\n";
+
+        highLevelClient().ingest().putPipeline(
+            new PutPipelineRequest("regression-stats-pipeline",
+                new BytesArray(regressionPipeline.getBytes(StandardCharsets.UTF_8)),
+                XContentType.JSON),
+            RequestOptions.DEFAULT);
+        {
+            GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
+                GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(),
+                machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
+            assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels));
+            assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
+            assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1));
+            assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0));
+        }
+        {
+            GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
+                new GetTrainedModelsStatsRequest(modelIdPrefix + 4, modelIdPrefix + 2, modelIdPrefix + 3),
+                machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
+            assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(3));
+            assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(3L));
+        }
+        {
+            GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
+                new GetTrainedModelsStatsRequest(modelIdPrefix + "*").setPageParams(new PageParams(1, 2)),
+                machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
+            assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(2));
+            assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
+            assertThat(
+                getTrainedModelsStatsResponse.getTrainedModelStats()
+                    .stream()
+                    .map(TrainedModelStats::getModelId)
+                    .collect(Collectors.toList()),
+                containsInAnyOrder(modelIdPrefix + 1, modelIdPrefix + 2));
+        }
+    }
+
     public void testDeleteTrainedModel() throws Exception {
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
         String modelId = "delete-trained-model-test";
@@ -2298,7 +2365,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
     }
 
     private void putTrainedModel(String modelId) throws IOException {
-        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
         highLevelClient().index(
             new IndexRequest(".ml-inference-000001")
                 .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)

+ 55 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -91,6 +91,8 @@ import org.elasticsearch.client.ml.GetRecordsRequest;
 import org.elasticsearch.client.ml.GetRecordsResponse;
 import org.elasticsearch.client.ml.GetTrainedModelsRequest;
 import org.elasticsearch.client.ml.GetTrainedModelsResponse;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
+import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
 import org.elasticsearch.client.ml.MlInfoRequest;
 import org.elasticsearch.client.ml.MlInfoResponse;
 import org.elasticsearch.client.ml.OpenJobRequest;
@@ -163,6 +165,7 @@ import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
 import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelStats;
 import org.elasticsearch.client.ml.job.config.AnalysisConfig;
 import org.elasticsearch.client.ml.job.config.AnalysisLimits;
 import org.elasticsearch.client.ml.job.config.DataDescription;
@@ -3593,6 +3596,58 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
         }
     }
 
+    public void testGetTrainedModelsStats() throws Exception {
+        putTrainedModel("my-trained-model");
+        RestHighLevelClient client = highLevelClient();
+        {
+            // tag::get-trained-models-stats-request
+            GetTrainedModelsStatsRequest request =
+                new GetTrainedModelsStatsRequest("my-trained-model") // <1>
+                    .setPageParams(new PageParams(0, 1)) // <2>
+                    .setAllowNoMatch(true); // <3>
+            // end::get-trained-models-stats-request
+
+            // tag::get-trained-models-stats-execute
+            GetTrainedModelsStatsResponse response =
+                client.machineLearning().getTrainedModelsStats(request, RequestOptions.DEFAULT);
+            // end::get-trained-models-stats-execute
+
+            // tag::get-trained-models-stats-response
+            List<TrainedModelStats> models = response.getTrainedModelStats();
+            // end::get-trained-models-stats-response
+
+            assertThat(models, hasSize(1));
+        }
+        {
+            GetTrainedModelsStatsRequest request = new GetTrainedModelsStatsRequest("my-trained-model");
+
+            // tag::get-trained-models-stats-execute-listener
+            ActionListener<GetTrainedModelsStatsResponse> listener = new ActionListener<>() {
+                @Override
+                public void onResponse(GetTrainedModelsStatsResponse response) {
+                    // <1>
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    // <2>
+                }
+            };
+            // end::get-trained-models-stats-execute-listener
+
+            // Replace the empty listener by a blocking listener in test
+            CountDownLatch latch = new CountDownLatch(1);
+            listener = new LatchedActionListener<>(listener, latch);
+
+            // tag::get-trained-models-stats-execute-async
+            client.machineLearning()
+                .getTrainedModelsStatsAsync(request, RequestOptions.DEFAULT, listener); // <1>
+            // end::get-trained-models-stats-execute-async
+
+            assertTrue(latch.await(30L, TimeUnit.SECONDS));
+        }
+    }
+
     public void testDeleteTrainedModel() throws Exception {
         RestHighLevelClient client = highLevelClient();
         {

+ 39 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetTrainedModelsStatsRequestTests.java

@@ -0,0 +1,39 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Optional;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class GetTrainedModelsStatsRequestTests extends ESTestCase {
+
+    public void testValidate_Ok() {
+        assertEquals(Optional.empty(), new GetTrainedModelsStatsRequest("valid-id").validate());
+        assertEquals(Optional.empty(), new GetTrainedModelsStatsRequest("").validate());
+    }
+
+    public void testValidate_Failure() {
+        assertThat(new GetTrainedModelsStatsRequest(new String[0]).validate().get().getMessage(),
+            containsString("trained model id must not be null"));
+    }
+}

+ 7 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java

@@ -21,6 +21,7 @@ package org.elasticsearch.client.ml.inference;
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
 import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.EnsembleTests;
 import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
 import org.elasticsearch.common.settings.Settings;
@@ -56,6 +57,10 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
     }
 
     public static TrainedModelDefinition.Builder createRandomBuilder() {
+        return createRandomBuilder(randomFrom(TargetType.values()));
+    }
+
+    public static TrainedModelDefinition.Builder createRandomBuilder(TargetType targetType) {
         int numberOfProcessors = randomIntBetween(1, 10);
         return new TrainedModelDefinition.Builder()
             .setPreProcessors(
@@ -65,7 +70,8 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
                         TargetMeanEncodingTests.createRandom()))
                         .limit(numberOfProcessors)
                         .collect(Collectors.toList()))
-            .setTrainedModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
+            .setTrainedModel(randomFrom(TreeTests.buildRandomTree(Collections.emptyList(), 6, targetType),
+                EnsembleTests.createRandom(targetType)));
     }
 
     @Override

+ 96 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelStatsTests.java

@@ -0,0 +1,96 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.ingest.IngestStats;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class TrainedModelStatsTests extends AbstractXContentTestCase<TrainedModelStats> {
+
+    @Override
+    protected TrainedModelStats doParseInstance(XContentParser parser) throws IOException {
+        return TrainedModelStats.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> !field.isEmpty();
+    }
+
+    @Override
+    protected TrainedModelStats createTestInstance() {
+        return new TrainedModelStats(
+            randomAlphaOfLength(10),
+            randomBoolean() ? null : randomIngestStats(),
+            randomInt());
+    }
+
+    private Map<String, Object> randomIngestStats() {
+        try {
+            List<String> pipelineIds = Stream.generate(()-> randomAlphaOfLength(10))
+                .limit(randomIntBetween(0, 10))
+                .collect(Collectors.toList());
+            IngestStats stats = new IngestStats(
+                new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
+                pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
+                pipelineIds.stream().collect(Collectors.toMap(Function.identity(), (v) -> randomProcessorStats())));
+            try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+                builder.startObject();
+                stats.toXContent(builder, ToXContent.EMPTY_PARAMS);
+                builder.endObject();
+                return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
+            }
+        } catch (IOException ex) {
+            fail(ex.getMessage());
+            return null;
+        }
+    }
+
+    private IngestStats.Stats randomStats(){
+        return new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong());
+    }
+
+    private List<IngestStats.ProcessorStat> randomProcessorStats() {
+        return Stream.generate(() -> randomAlphaOfLength(10))
+            .limit(randomIntBetween(0, 10))
+            .map(name -> new IngestStats.ProcessorStat(name, "inference", randomStats()))
+            .collect(Collectors.toList());
+    }
+
+}

+ 6 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -57,12 +57,16 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
     }
 
     public static Ensemble createRandom() {
+        return createRandom(randomFrom(TargetType.values()));
+    }
+
+    public static Ensemble createRandom(TargetType targetType) {
         int numberOfFeatures = randomIntBetween(1, 10);
         List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10))
             .limit(numberOfFeatures)
             .collect(Collectors.toList());
         int numberOfModels = randomIntBetween(1, 10);
-        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
+        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
             .limit(numberOfFeatures)
             .collect(Collectors.toList());
         OutputAggregator outputAggregator = null;
@@ -77,7 +81,7 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
         return new Ensemble(featureNames,
             models,
             outputAggregator,
-            randomFrom(TargetType.values()),
+            targetType,
             categoryLabels);
     }
 

+ 3 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java

@@ -57,10 +57,10 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
         for (int i = 0; i < numberOfFeatures; i++) {
             featureNames.add(randomAlphaOfLength(10));
         }
-        return buildRandomTree(featureNames,  6);
+        return buildRandomTree(featureNames,  6, randomFrom(TargetType.values()));
     }
 
-    public static Tree buildRandomTree(List<String> featureNames, int depth) {
+    public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
         int numFeatures = featureNames.size();
         Tree.Builder builder = Tree.builder();
         builder.setFeatureNames(featureNames);
@@ -88,7 +88,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
             categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
         }
         return builder.setClassificationLabels(categoryLabels)
-            .setTargetType(randomFrom(TargetType.values()))
+            .setTargetType(targetType)
             .build();
     }
 

+ 42 - 0
docs/java-rest/high-level/ml/get-trained-models-stats.asciidoc

@@ -0,0 +1,42 @@
+--
+:api: get-trained-models-stats
+:request: GetTrainedModelsStatsRequest
+:response: GetTrainedModelsStatsResponse
+--
+[role="xpack"]
+[id="{upid}-{api}"]
+=== Get Trained Models Stats API
+
+experimental[]
+
+Retrieves one or more Trained Model statistics.
+The API accepts a +{request}+ object and returns a +{response}+.
+
+[id="{upid}-{api}-request"]
+==== Get Trained Models Stats request
+
+A +{request}+ requires either a Trained Model ID, a comma-separated list of
+IDs, or the special wildcard `_all` to get stats for all Trained Models.
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-request]
+--------------------------------------------------
+<1> Constructing a new GET request referencing an existing Trained Model
+<2> Set the paging parameters
+<3> Allow empty response if no Trained Models match the provided ID patterns.
+    If false, an error will be thrown if no Trained Models match the
+    ID patterns.
+
+include::../execution.asciidoc[]
+
+[id="{upid}-{api}-response"]
+==== Response
+
+The returned +{response}+ contains the statistics
+for the requested Trained Model.
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-response]
+--------------------------------------------------

+ 2 - 0
docs/java-rest/high-level/supported-apis.asciidoc

@@ -302,6 +302,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
 * <<{upid}-evaluate-data-frame>>
 * <<{upid}-explain-data-frame-analytics>>
 * <<{upid}-get-trained-models>>
+* <<{upid}-get-trained-models-stats>>
 * <<{upid}-delete-trained-model>>
 * <<{upid}-put-filter>>
 * <<{upid}-get-filters>>
@@ -356,6 +357,7 @@ include::ml/stop-data-frame-analytics.asciidoc[]
 include::ml/evaluate-data-frame.asciidoc[]
 include::ml/explain-data-frame-analytics.asciidoc[]
 include::ml/get-trained-models.asciidoc[]
+include::ml/get-trained-models-stats.asciidoc[]
 include::ml/delete-trained-model.asciidoc[]
 include::ml/put-filter.asciidoc[]
 include::ml/get-filters.asciidoc[]