Browse Source

[ML] add deployment_stats to trained model stats (#80531)

This commit adds a new field deployment_stats that is optionally set for models that are deployed.

If a model does not have a deployment, it will be null.

Also, removes the get deployment stats API and makes the deployment stats action internal only.
Benjamin Trent 3 years ago
parent
commit
cf5f521fac
18 changed files with 745 additions and 656 deletions
  1. 128 3
      docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc
  2. 0 1
      docs/reference/ml/df-analytics/apis/index.asciidoc
  3. 0 1
      docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc
  4. 0 30
      rest-api-spec/src/main/resources/rest-api-spec/api/ml.get_trained_model_deployment_stats.json
  5. 4 353
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java
  6. 49 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java
  7. 349 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStats.java
  8. 5 49
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsActionResponseTests.java
  9. 24 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java
  10. 74 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatsTests.java
  11. 42 80
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  12. 0 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  13. 23 43
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java
  14. 22 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java
  15. 0 53
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelDeploymentStatsAction.java
  16. 7 12
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java
  17. 17 16
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java
  18. 1 1
      x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java

+ 128 - 3
docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc

@@ -26,7 +26,7 @@ Retrieves usage information for trained models.
 [[ml-get-trained-models-stats-prereq]]
 == {api-prereq-title}
 
-Requires the `monitor_ml` cluster privilege. This privilege is included in the 
+Requires the `monitor_ml` cluster privilege. This privilege is included in the
 `machine_learning_user` built-in role.
 
 
@@ -78,13 +78,131 @@ in ascending order.
 .Properties of trained model stats
 [%collapsible%open]
 ====
+`deployment_stats`:::
+(list)
+A collection of deployment stats if one of the provided `model_id` values
+is deployed
++
+.Properties of deployment stats
+[%collapsible%open]
+=====
+`allocation_status`:::
+(object)
+The detailed allocation status given the deployment configuration.
++
+.Properties of allocation stats
+[%collapsible%open]
+======
+`allocation_count`:::
+(integer)
+The current number of nodes where the model is allocated.
+
+`state`:::
+(string)
+The detailed allocation state related to the nodes.
++
+--
+* `starting`: Allocations are being attempted but no node currently has the model allocated.
+* `started`: At least one node has the model allocated.
+* `fully_allocated`: The deployment is fully allocated and satisfies the `target_allocation_count`.
+--
+
+`target_allocation_count`:::
+(integer)
+The desired number of nodes for model allocation.
+======
+
 `model_id`:::
 (string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
-`pipeline_count`:::
+`model_size`:::
+(<<byte-units,byte value>>)
+The size of the loaded model in bytes.
+
+`nodes`:::
+(array of objects)
+The deployment stats for each node that currently has the model allocated.
++
+.Properties of node stats
+[%collapsible%open]
+======
+`average_inference_time_ms`:::
+(double)
+The average time for each inference call to complete on this node.
+
+`inference_count`:::
 (integer)
-The number of ingest pipelines that currently refer to the model.
+The total number of inference calls made against this node for this model.
+
+`last_access`:::
+(long)
+The epoch time stamp of the last inference call for the model on this node.
+
+`node`:::
+(object)
+Information pertaining to the node.
++
+.Properties of node
+[%collapsible%open]
+========
+`attributes`:::
+(object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-attributes]
+
+`ephemeral_id`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-ephemeral-id]
+
+`id`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-id]
+
+`name`:::
+(string) The node name.
+
+`transport_address`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-transport-address]
+========
+
+`reason`:::
+(string)
+The reason for the current state. Usually only populated when the `routing_state` is `failed`.
+
+`routing_state`:::
+(object)
+The current routing state and reason for the current routing state for this allocation.
++
+--
+* `starting`: The model is attempting to allocate on this model, inference calls are not yet accepted.
+* `started`: The model is allocated and ready to accept inference requests.
+* `stopping`: The model is being deallocated from this node.
+* `stopped`: The model is fully deallocated from this node.
+* `failed`: The allocation attempt failed, see `reason` field for the potential cause.
+--
+
+`start_time`:::
+(long)
+The epoch timestamp when the allocation started.
+
+======
+
+`start_time`:::
+(long)
+The epoch timestamp when the deployment started.
+
+`state`:::
+(string)
+The overall state of the deployment. The values may be:
++
+--
+* `starting`: The deployment has recently started but is not yet usable as the model is not allocated on any nodes.
+* `started`: The deployment is usable as at least one node has the model allocated.
+* `stopping`: The deployment is preparing to stop and deallocate the model from the relevant nodes.
+--
+
+=====
 
 `inference_stats`:::
 (object)
@@ -127,6 +245,13 @@ A collection of ingest stats for the model across all nodes. The values are
 summations of the individual node statistics. The format matches the `ingest`
 section in <<cluster-nodes-stats>>.
 
+`model_id`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+
+`pipeline_count`:::
+(integer)
+The number of ingest pipelines that currently refer to the model.
 ====
 
 [[ml-get-trained-models-stats-response-codes]]

+ 0 - 1
docs/reference/ml/df-analytics/apis/index.asciidoc

@@ -19,7 +19,6 @@ include::explain-dfanalytics.asciidoc[leveloffset=+2]
 include::get-dfanalytics.asciidoc[leveloffset=+2]
 include::get-dfanalytics-stats.asciidoc[leveloffset=+2]
 include::get-trained-models.asciidoc[leveloffset=+2]
-include::get-trained-model-deployment-stats.asciidoc[leveloffset=+2]
 include::get-trained-models-stats.asciidoc[leveloffset=+2]
 //INFER
 include::infer-trained-model-deployment.asciidoc[leveloffset=+2]

+ 0 - 1
docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc

@@ -25,7 +25,6 @@ You can use the following APIs to perform {infer} operations:
 * <<delete-trained-models-aliases>>
 * <<get-trained-models>>
 * <<get-trained-models-stats>>
-* <<get-trained-model-deployment-stats>>
 
 You can deploy a trained model to make predictions in an ingest pipeline or in
 an aggregation. Refer to the following documentation to learn more:

+ 0 - 30
rest-api-spec/src/main/resources/rest-api-spec/api/ml.get_trained_model_deployment_stats.json

@@ -1,30 +0,0 @@
-{
-  "ml.get_trained_model_deployment_stats":{
-    "documentation":{
-      "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-get-trained-model-deployment-stats.html",
-      "description":"Get information about trained model deployments."
-    },
-    "stability":"stable",
-    "visibility":"public",
-    "headers":{
-      "accept": [ "application/json"],
-      "content_type": ["application/json"]
-    },
-    "url":{
-      "paths":[
-        {
-          "path":"/_ml/trained_models/{model_id}/deployment/_stats",
-          "methods":[
-            "GET"
-          ],
-          "parts":{
-            "model_id":{
-              "type":"string",
-              "description":"The ID of the trained model deployment stats to fetch"
-            }
-          }
-        }
-      ]
-    }
-  }
-}

+ 4 - 353
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java

@@ -12,34 +12,25 @@ import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.tasks.BaseTasksRequest;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
-import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
-import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
-import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
-import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
-import java.time.Instant;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
-import java.util.Optional;
 
 public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsAction.Response> {
 
     public static final GetDeploymentStatsAction INSTANCE = new GetDeploymentStatsAction();
-    public static final String NAME = "cluster:monitor/xpack/ml/trained_models/deployments/stats/get";
+    public static final String NAME = "cluster:internal/xpack/ml/trained_models/deployments/stats/get";
 
     private GetDeploymentStatsAction() {
         super(NAME, GetDeploymentStatsAction.Response::new);
@@ -47,10 +38,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
 
     public static class Request extends BaseTasksRequest<GetDeploymentStatsAction.Request> {
 
-        public static final String ALLOW_NO_MATCH = "allow_no_match";
-
         private final String deploymentId;
-        private boolean allowNoMatch = true;
         // used internally this should not be set by the REST request
         private List<String> expandedIds;
 
@@ -62,7 +50,6 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
         public Request(StreamInput in) throws IOException {
             super(in);
             this.deploymentId = in.readString();
-            this.allowNoMatch = in.readBoolean();
             this.expandedIds = in.readStringList();
         }
 
@@ -70,7 +57,6 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
             out.writeString(deploymentId);
-            out.writeBoolean(allowNoMatch);
             out.writeStringCollection(expandedIds);
         }
 
@@ -82,14 +68,6 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
             this.expandedIds = expandedIds;
         }
 
-        public void setAllowNoMatch(boolean allowNoMatch) {
-            this.allowNoMatch = allowNoMatch;
-        }
-
-        public boolean isAllowNoMatch() {
-            return allowNoMatch;
-        }
-
         @Override
         public boolean match(Task task) {
             return expandedIds.stream().anyMatch(taskId -> StartTrainedModelDeploymentAction.TaskMatcher.match(task, taskId));
@@ -100,14 +78,12 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
             Request request = (Request) o;
-            return Objects.equals(deploymentId, request.deploymentId)
-                && this.allowNoMatch == request.allowNoMatch
-                && Objects.equals(expandedIds, request.expandedIds);
+            return Objects.equals(deploymentId, request.deploymentId) && Objects.equals(expandedIds, request.expandedIds);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(deploymentId, allowNoMatch, expandedIds);
+            return Objects.hash(deploymentId, expandedIds);
         }
     }
 
@@ -115,331 +91,6 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
 
         public static final ParseField DEPLOYMENT_STATS = new ParseField("deployment_stats");
 
-        public static class AllocationStats implements ToXContentObject, Writeable {
-
-            public static class NodeStats implements ToXContentObject, Writeable {
-                private final DiscoveryNode node;
-                private final Long inferenceCount;
-                private final Double avgInferenceTime;
-                private final Instant lastAccess;
-                private final Integer pendingCount;
-                private final RoutingStateAndReason routingState;
-                private final Instant startTime;
-
-                public static NodeStats forStartedState(
-                    DiscoveryNode node,
-                    long inferenceCount,
-                    Double avgInferenceTime,
-                    int pendingCount,
-                    Instant lastAccess,
-                    Instant startTime
-                ) {
-                    return new NodeStats(
-                        node,
-                        inferenceCount,
-                        avgInferenceTime,
-                        lastAccess,
-                        pendingCount,
-                        new RoutingStateAndReason(RoutingState.STARTED, null),
-                        Objects.requireNonNull(startTime)
-                    );
-                }
-
-                public static NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
-                    return new NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null);
-                }
-
-                private NodeStats(
-                    DiscoveryNode node,
-                    Long inferenceCount,
-                    Double avgInferenceTime,
-                    Instant lastAccess,
-                    Integer pendingCount,
-                    RoutingStateAndReason routingState,
-                    @Nullable Instant startTime
-                ) {
-                    this.node = node;
-                    this.inferenceCount = inferenceCount;
-                    this.avgInferenceTime = avgInferenceTime;
-                    this.lastAccess = lastAccess;
-                    this.pendingCount = pendingCount;
-                    this.routingState = routingState;
-                    this.startTime = startTime;
-
-                    // if lastAccess time is null there have been no inferences
-                    assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
-                }
-
-                public NodeStats(StreamInput in) throws IOException {
-                    this.node = in.readOptionalWriteable(DiscoveryNode::new);
-                    this.inferenceCount = in.readOptionalLong();
-                    this.avgInferenceTime = in.readOptionalDouble();
-                    this.lastAccess = in.readOptionalInstant();
-                    this.pendingCount = in.readOptionalVInt();
-                    this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
-                    this.startTime = in.readOptionalInstant();
-                }
-
-                public DiscoveryNode getNode() {
-                    return node;
-                }
-
-                public RoutingStateAndReason getRoutingState() {
-                    return routingState;
-                }
-
-                public Optional<Long> getInferenceCount() {
-                    return Optional.ofNullable(inferenceCount);
-                }
-
-                public Optional<Double> getAvgInferenceTime() {
-                    return Optional.ofNullable(avgInferenceTime);
-                }
-
-                @Override
-                public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-                    builder.startObject();
-                    if (node != null) {
-                        builder.startObject("node");
-                        node.toXContent(builder, params);
-                        builder.endObject();
-                    }
-                    builder.field("routing_state", routingState);
-                    if (inferenceCount != null) {
-                        builder.field("inference_count", inferenceCount);
-                    }
-                    if (avgInferenceTime != null) {
-                        builder.field("average_inference_time_ms", avgInferenceTime);
-                    }
-                    if (lastAccess != null) {
-                        builder.timeField("last_access", "last_access_string", lastAccess.toEpochMilli());
-                    }
-                    if (pendingCount != null) {
-                        builder.field("number_of_pending_requests", pendingCount);
-                    }
-                    if (startTime != null) {
-                        builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
-                    }
-                    builder.endObject();
-                    return builder;
-                }
-
-                @Override
-                public void writeTo(StreamOutput out) throws IOException {
-                    out.writeOptionalWriteable(node);
-                    out.writeOptionalLong(inferenceCount);
-                    out.writeOptionalDouble(avgInferenceTime);
-                    out.writeOptionalInstant(lastAccess);
-                    out.writeOptionalVInt(pendingCount);
-                    out.writeOptionalWriteable(routingState);
-                    out.writeOptionalInstant(startTime);
-                }
-
-                @Override
-                public boolean equals(Object o) {
-                    if (this == o) return true;
-                    if (o == null || getClass() != o.getClass()) return false;
-                    NodeStats that = (NodeStats) o;
-                    return Objects.equals(inferenceCount, that.inferenceCount)
-                        && Objects.equals(that.avgInferenceTime, avgInferenceTime)
-                        && Objects.equals(node, that.node)
-                        && Objects.equals(lastAccess, that.lastAccess)
-                        && Objects.equals(pendingCount, that.pendingCount)
-                        && Objects.equals(routingState, that.routingState)
-                        && Objects.equals(startTime, that.startTime);
-                }
-
-                @Override
-                public int hashCode() {
-                    return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState, startTime);
-                }
-            }
-
-            private final String modelId;
-            private AllocationState state;
-            private AllocationStatus allocationStatus;
-            private String reason;
-            @Nullable
-            private final ByteSizeValue modelSize;
-            @Nullable
-            private final Integer inferenceThreads;
-            @Nullable
-            private final Integer modelThreads;
-            @Nullable
-            private final Integer queueCapacity;
-            private final Instant startTime;
-            private final List<NodeStats> nodeStats;
-
-            public AllocationStats(
-                String modelId,
-                @Nullable ByteSizeValue modelSize,
-                @Nullable Integer inferenceThreads,
-                @Nullable Integer modelThreads,
-                @Nullable Integer queueCapacity,
-                Instant startTime,
-                List<NodeStats> nodeStats
-            ) {
-                this.modelId = modelId;
-                this.modelSize = modelSize;
-                this.inferenceThreads = inferenceThreads;
-                this.modelThreads = modelThreads;
-                this.queueCapacity = queueCapacity;
-                this.startTime = Objects.requireNonNull(startTime);
-                this.nodeStats = nodeStats;
-                this.state = null;
-                this.reason = null;
-            }
-
-            public AllocationStats(StreamInput in) throws IOException {
-                modelId = in.readString();
-                modelSize = in.readOptionalWriteable(ByteSizeValue::new);
-                inferenceThreads = in.readOptionalVInt();
-                modelThreads = in.readOptionalVInt();
-                queueCapacity = in.readOptionalVInt();
-                startTime = in.readInstant();
-                nodeStats = in.readList(NodeStats::new);
-                state = in.readOptionalEnum(AllocationState.class);
-                reason = in.readOptionalString();
-                allocationStatus = in.readOptionalWriteable(AllocationStatus::new);
-            }
-
-            public String getModelId() {
-                return modelId;
-            }
-
-            public ByteSizeValue getModelSize() {
-                return modelSize;
-            }
-
-            @Nullable
-            public Integer getInferenceThreads() {
-                return inferenceThreads;
-            }
-
-            @Nullable
-            public Integer getModelThreads() {
-                return modelThreads;
-            }
-
-            @Nullable
-            public Integer getQueueCapacity() {
-                return queueCapacity;
-            }
-
-            public Instant getStartTime() {
-                return startTime;
-            }
-
-            public List<NodeStats> getNodeStats() {
-                return nodeStats;
-            }
-
-            public AllocationState getState() {
-                return state;
-            }
-
-            public AllocationStats setState(AllocationState state) {
-                this.state = state;
-                return this;
-            }
-
-            public AllocationStats setAllocationStatus(AllocationStatus allocationStatus) {
-                this.allocationStatus = allocationStatus;
-                return this;
-            }
-
-            public String getReason() {
-                return reason;
-            }
-
-            public AllocationStats setReason(String reason) {
-                this.reason = reason;
-                return this;
-            }
-
-            @Override
-            public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-                builder.startObject();
-                builder.field("model_id", modelId);
-                if (modelSize != null) {
-                    builder.humanReadableField("model_size_bytes", "model_size", modelSize);
-                }
-                if (inferenceThreads != null) {
-                    builder.field(StartTrainedModelDeploymentAction.TaskParams.INFERENCE_THREADS.getPreferredName(), inferenceThreads);
-                }
-                if (modelThreads != null) {
-                    builder.field(StartTrainedModelDeploymentAction.TaskParams.MODEL_THREADS.getPreferredName(), modelThreads);
-                }
-                if (queueCapacity != null) {
-                    builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity);
-                }
-                if (state != null) {
-                    builder.field("state", state);
-                }
-                if (reason != null) {
-                    builder.field("reason", reason);
-                }
-                if (allocationStatus != null) {
-                    builder.field("allocation_status", allocationStatus);
-                }
-                builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
-                builder.startArray("nodes");
-                for (NodeStats nodeStat : nodeStats) {
-                    nodeStat.toXContent(builder, params);
-                }
-                builder.endArray();
-                builder.endObject();
-                return builder;
-            }
-
-            @Override
-            public void writeTo(StreamOutput out) throws IOException {
-                out.writeString(modelId);
-                out.writeOptionalWriteable(modelSize);
-                out.writeOptionalVInt(inferenceThreads);
-                out.writeOptionalVInt(modelThreads);
-                out.writeOptionalVInt(queueCapacity);
-                out.writeInstant(startTime);
-                out.writeList(nodeStats);
-                out.writeOptionalEnum(state);
-                out.writeOptionalString(reason);
-                out.writeOptionalWriteable(allocationStatus);
-            }
-
-            @Override
-            public boolean equals(Object o) {
-                if (this == o) return true;
-                if (o == null || getClass() != o.getClass()) return false;
-                AllocationStats that = (AllocationStats) o;
-                return Objects.equals(modelId, that.modelId)
-                    && Objects.equals(modelSize, that.modelSize)
-                    && Objects.equals(inferenceThreads, that.inferenceThreads)
-                    && Objects.equals(modelThreads, that.modelThreads)
-                    && Objects.equals(queueCapacity, that.queueCapacity)
-                    && Objects.equals(startTime, that.startTime)
-                    && Objects.equals(state, that.state)
-                    && Objects.equals(reason, that.reason)
-                    && Objects.equals(allocationStatus, that.allocationStatus)
-                    && Objects.equals(nodeStats, that.nodeStats);
-            }
-
-            @Override
-            public int hashCode() {
-                return Objects.hash(
-                    modelId,
-                    modelSize,
-                    inferenceThreads,
-                    modelThreads,
-                    queueCapacity,
-                    startTime,
-                    nodeStats,
-                    state,
-                    reason,
-                    allocationStatus
-                );
-            }
-        }
-
         private final QueryPage<AllocationStats> stats;
 
         public Response(

+ 49 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

@@ -7,10 +7,12 @@
 package org.elasticsearch.xpack.core.ml.action;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.RestApiVersion;
 import org.elasticsearch.ingest.IngestStats;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -19,6 +21,7 @@ import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 
 import java.io.IOException;
@@ -30,6 +33,8 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
+import static org.elasticsearch.core.RestApiVersion.onOrAfter;
+
 public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStatsAction.Response> {
 
     public static final GetTrainedModelsStatsAction INSTANCE = new GetTrainedModelsStatsAction();
@@ -38,6 +43,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
     public static final ParseField MODEL_ID = new ParseField("model_id");
     public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
     public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
+    public static final ParseField DEPLOYMENT_STATS = new ParseField("deployment_stats");
 
     private GetTrainedModelsStatsAction() {
         super(NAME, GetTrainedModelsStatsAction.Response::new);
@@ -73,6 +79,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             private final String modelId;
             private final IngestStats ingestStats;
             private final InferenceStats inferenceStats;
+            private final AllocationStats deploymentStats;
             private final int pipelineCount;
 
             private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(
@@ -81,7 +88,13 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 Collections.emptyMap()
             );
 
-            public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount, InferenceStats inferenceStats) {
+            public TrainedModelStats(
+                String modelId,
+                IngestStats ingestStats,
+                int pipelineCount,
+                InferenceStats inferenceStats,
+                AllocationStats deploymentStats
+            ) {
                 this.modelId = Objects.requireNonNull(modelId);
                 this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
                 if (pipelineCount < 0) {
@@ -89,13 +102,19 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 }
                 this.pipelineCount = pipelineCount;
                 this.inferenceStats = inferenceStats;
+                this.deploymentStats = deploymentStats;
             }
 
             public TrainedModelStats(StreamInput in) throws IOException {
                 modelId = in.readString();
                 ingestStats = new IngestStats(in);
                 pipelineCount = in.readVInt();
-                this.inferenceStats = in.readOptionalWriteable(InferenceStats::new);
+                inferenceStats = in.readOptionalWriteable(InferenceStats::new);
+                if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+                    this.deploymentStats = in.readOptionalWriteable(AllocationStats::new);
+                } else {
+                    this.deploymentStats = null;
+                }
             }
 
             public String getModelId() {
@@ -110,6 +129,14 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 return pipelineCount;
             }
 
+            public InferenceStats getInferenceStats() {
+                return inferenceStats;
+            }
+
+            public AllocationStats getDeploymentStats() {
+                return deploymentStats;
+            }
+
             @Override
             public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
                 builder.startObject();
@@ -122,6 +149,9 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 if (this.inferenceStats != null) {
                     builder.field(INFERENCE_STATS.getPreferredName(), this.inferenceStats);
                 }
+                if (deploymentStats != null && builder.getRestApiVersion().matches(onOrAfter(RestApiVersion.V_8))) {
+                    builder.field(DEPLOYMENT_STATS.getPreferredName(), this.deploymentStats);
+                }
                 builder.endObject();
                 return builder;
             }
@@ -131,12 +161,15 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 out.writeString(modelId);
                 ingestStats.writeTo(out);
                 out.writeVInt(pipelineCount);
-                out.writeOptionalWriteable(this.inferenceStats);
+                out.writeOptionalWriteable(inferenceStats);
+                if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+                    out.writeOptionalWriteable(deploymentStats);
+                }
             }
 
             @Override
             public int hashCode() {
-                return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
+                return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats, deploymentStats);
             }
 
             @Override
@@ -151,6 +184,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 return Objects.equals(this.modelId, other.modelId)
                     && Objects.equals(this.ingestStats, other.ingestStats)
                     && Objects.equals(this.pipelineCount, other.pipelineCount)
+                    && Objects.equals(this.deploymentStats, other.deploymentStats)
                     && Objects.equals(this.inferenceStats, other.inferenceStats);
             }
         }
@@ -176,6 +210,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             private Map<String, Set<String>> expandedIdsWithAliases;
             private Map<String, IngestStats> ingestStatsMap;
             private Map<String, InferenceStats> inferenceStatsMap;
+            private Map<String, AllocationStats> allocationStatsMap;
 
             public Builder setTotalModelCount(long totalModelCount) {
                 this.totalModelCount = totalModelCount;
@@ -196,8 +231,13 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 return this;
             }
 
-            public Builder setInferenceStatsByModelId(Map<String, InferenceStats> infereceStatsByModelId) {
-                this.inferenceStatsMap = infereceStatsByModelId;
+            public Builder setInferenceStatsByModelId(Map<String, InferenceStats> inferenceStatsByModelId) {
+                this.inferenceStatsMap = inferenceStatsByModelId;
+                return this;
+            }
+
+            public Builder setDeploymentStatsByModelId(Map<String, AllocationStats> allocationStatsByModelId) {
+                this.allocationStatsMap = allocationStatsByModelId;
                 return this;
             }
 
@@ -206,12 +246,14 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 expandedIdsWithAliases.keySet().forEach(id -> {
                     IngestStats ingestStats = ingestStatsMap.get(id);
                     InferenceStats inferenceStats = inferenceStatsMap.get(id);
+                    AllocationStats allocationStats = allocationStatsMap.get(id);
                     trainedModelStats.add(
                         new TrainedModelStats(
                             id,
                             ingestStats,
                             ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
-                            inferenceStats
+                            inferenceStats,
+                            allocationStats
                         )
                     );
                 });

+ 349 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStats.java

@@ -0,0 +1,349 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+public class AllocationStats implements ToXContentObject, Writeable {
+
+    public static class NodeStats implements ToXContentObject, Writeable {
+        private final DiscoveryNode node;
+        private final Long inferenceCount;
+        private final Double avgInferenceTime;
+        private final Instant lastAccess;
+        private final Integer pendingCount;
+        private final RoutingStateAndReason routingState;
+        private final Instant startTime;
+
+        public static AllocationStats.NodeStats forStartedState(
+            DiscoveryNode node,
+            long inferenceCount,
+            Double avgInferenceTime,
+            int pendingCount,
+            Instant lastAccess,
+            Instant startTime
+        ) {
+            return new AllocationStats.NodeStats(
+                node,
+                inferenceCount,
+                avgInferenceTime,
+                lastAccess,
+                pendingCount,
+                new RoutingStateAndReason(RoutingState.STARTED, null),
+                Objects.requireNonNull(startTime)
+            );
+        }
+
+        public static AllocationStats.NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
+            return new AllocationStats.NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null);
+        }
+
+        private NodeStats(
+            DiscoveryNode node,
+            Long inferenceCount,
+            Double avgInferenceTime,
+            Instant lastAccess,
+            Integer pendingCount,
+            RoutingStateAndReason routingState,
+            @Nullable Instant startTime
+        ) {
+            this.node = node;
+            this.inferenceCount = inferenceCount;
+            this.avgInferenceTime = avgInferenceTime;
+            this.lastAccess = lastAccess;
+            this.pendingCount = pendingCount;
+            this.routingState = routingState;
+            this.startTime = startTime;
+
+            // if lastAccess time is null there have been no inferences
+            assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
+        }
+
+        public NodeStats(StreamInput in) throws IOException {
+            this.node = in.readOptionalWriteable(DiscoveryNode::new);
+            this.inferenceCount = in.readOptionalLong();
+            this.avgInferenceTime = in.readOptionalDouble();
+            this.lastAccess = in.readOptionalInstant();
+            this.pendingCount = in.readOptionalVInt();
+            this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
+            this.startTime = in.readOptionalInstant();
+        }
+
+        public DiscoveryNode getNode() {
+            return node;
+        }
+
+        public RoutingStateAndReason getRoutingState() {
+            return routingState;
+        }
+
+        public Optional<Long> getInferenceCount() {
+            return Optional.ofNullable(inferenceCount);
+        }
+
+        public Optional<Double> getAvgInferenceTime() {
+            return Optional.ofNullable(avgInferenceTime);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            if (node != null) {
+                builder.startObject("node");
+                node.toXContent(builder, params);
+                builder.endObject();
+            }
+            builder.field("routing_state", routingState);
+            if (inferenceCount != null) {
+                builder.field("inference_count", inferenceCount);
+            }
+            if (avgInferenceTime != null) {
+                builder.field("average_inference_time_ms", avgInferenceTime);
+            }
+            if (lastAccess != null) {
+                builder.timeField("last_access", "last_access_string", lastAccess.toEpochMilli());
+            }
+            if (pendingCount != null) {
+                builder.field("number_of_pending_requests", pendingCount);
+            }
+            if (startTime != null) {
+                builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
+            }
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeOptionalWriteable(node);
+            out.writeOptionalLong(inferenceCount);
+            out.writeOptionalDouble(avgInferenceTime);
+            out.writeOptionalInstant(lastAccess);
+            out.writeOptionalVInt(pendingCount);
+            out.writeOptionalWriteable(routingState);
+            out.writeOptionalInstant(startTime);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            AllocationStats.NodeStats that = (AllocationStats.NodeStats) o;
+            return Objects.equals(inferenceCount, that.inferenceCount)
+                && Objects.equals(that.avgInferenceTime, avgInferenceTime)
+                && Objects.equals(node, that.node)
+                && Objects.equals(lastAccess, that.lastAccess)
+                && Objects.equals(pendingCount, that.pendingCount)
+                && Objects.equals(routingState, that.routingState)
+                && Objects.equals(startTime, that.startTime);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState, startTime);
+        }
+    }
+
+    private final String modelId;
+    private AllocationState state;
+    private AllocationStatus allocationStatus;
+    private String reason;
+    @Nullable
+    private final ByteSizeValue modelSize;
+    @Nullable
+    private final Integer inferenceThreads;
+    @Nullable
+    private final Integer modelThreads;
+    @Nullable
+    private final Integer queueCapacity;
+    private final Instant startTime;
+    private final List<AllocationStats.NodeStats> nodeStats;
+
+    public AllocationStats(
+        String modelId,
+        @Nullable ByteSizeValue modelSize,
+        @Nullable Integer inferenceThreads,
+        @Nullable Integer modelThreads,
+        @Nullable Integer queueCapacity,
+        Instant startTime,
+        List<AllocationStats.NodeStats> nodeStats
+    ) {
+        this.modelId = modelId;
+        this.modelSize = modelSize;
+        this.inferenceThreads = inferenceThreads;
+        this.modelThreads = modelThreads;
+        this.queueCapacity = queueCapacity;
+        this.startTime = Objects.requireNonNull(startTime);
+        this.nodeStats = nodeStats;
+        this.state = null;
+        this.reason = null;
+    }
+
+    public AllocationStats(StreamInput in) throws IOException {
+        modelId = in.readString();
+        modelSize = in.readOptionalWriteable(ByteSizeValue::new);
+        inferenceThreads = in.readOptionalVInt();
+        modelThreads = in.readOptionalVInt();
+        queueCapacity = in.readOptionalVInt();
+        startTime = in.readInstant();
+        nodeStats = in.readList(AllocationStats.NodeStats::new);
+        state = in.readOptionalEnum(AllocationState.class);
+        reason = in.readOptionalString();
+        allocationStatus = in.readOptionalWriteable(AllocationStatus::new);
+    }
+
+    public String getModelId() {
+        return modelId;
+    }
+
+    public ByteSizeValue getModelSize() {
+        return modelSize;
+    }
+
+    @Nullable
+    public Integer getInferenceThreads() {
+        return inferenceThreads;
+    }
+
+    @Nullable
+    public Integer getModelThreads() {
+        return modelThreads;
+    }
+
+    @Nullable
+    public Integer getQueueCapacity() {
+        return queueCapacity;
+    }
+
+    public Instant getStartTime() {
+        return startTime;
+    }
+
+    public List<AllocationStats.NodeStats> getNodeStats() {
+        return nodeStats;
+    }
+
+    public AllocationState getState() {
+        return state;
+    }
+
+    public AllocationStats setState(AllocationState state) {
+        this.state = state;
+        return this;
+    }
+
+    public AllocationStats setAllocationStatus(AllocationStatus allocationStatus) {
+        this.allocationStatus = allocationStatus;
+        return this;
+    }
+
+    public String getReason() {
+        return reason;
+    }
+
+    public AllocationStats setReason(String reason) {
+        this.reason = reason;
+        return this;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field("model_id", modelId);
+        if (modelSize != null) {
+            builder.humanReadableField("model_size_bytes", "model_size", modelSize);
+        }
+        if (inferenceThreads != null) {
+            builder.field(StartTrainedModelDeploymentAction.TaskParams.INFERENCE_THREADS.getPreferredName(), inferenceThreads);
+        }
+        if (modelThreads != null) {
+            builder.field(StartTrainedModelDeploymentAction.TaskParams.MODEL_THREADS.getPreferredName(), modelThreads);
+        }
+        if (queueCapacity != null) {
+            builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity);
+        }
+        if (state != null) {
+            builder.field("state", state);
+        }
+        if (reason != null) {
+            builder.field("reason", reason);
+        }
+        if (allocationStatus != null) {
+            builder.field("allocation_status", allocationStatus);
+        }
+        builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
+        builder.startArray("nodes");
+        for (AllocationStats.NodeStats nodeStat : nodeStats) {
+            nodeStat.toXContent(builder, params);
+        }
+        builder.endArray();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        out.writeOptionalWriteable(modelSize);
+        out.writeOptionalVInt(inferenceThreads);
+        out.writeOptionalVInt(modelThreads);
+        out.writeOptionalVInt(queueCapacity);
+        out.writeInstant(startTime);
+        out.writeList(nodeStats);
+        out.writeOptionalEnum(state);
+        out.writeOptionalString(reason);
+        out.writeOptionalWriteable(allocationStatus);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AllocationStats that = (AllocationStats) o;
+        return Objects.equals(modelId, that.modelId)
+            && Objects.equals(modelSize, that.modelSize)
+            && Objects.equals(inferenceThreads, that.inferenceThreads)
+            && Objects.equals(modelThreads, that.modelThreads)
+            && Objects.equals(queueCapacity, that.queueCapacity)
+            && Objects.equals(startTime, that.startTime)
+            && Objects.equals(state, that.state)
+            && Objects.equals(reason, that.reason)
+            && Objects.equals(allocationStatus, that.allocationStatus)
+            && Objects.equals(nodeStats, that.nodeStats);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(
+            modelId,
+            modelSize,
+            inferenceThreads,
+            modelThreads,
+            queueCapacity,
+            startTime,
+            nodeStats,
+            state,
+            reason,
+            allocationStatus
+        );
+    }
+}

+ 5 - 49
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsActionResponseTests.java

@@ -7,20 +7,15 @@
 
 package org.elasticsearch.xpack.core.ml.action;
 
-import org.elasticsearch.Version;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.transport.TransportAddress;
-import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 
-import java.net.InetAddress;
-import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.List;
+
+import static org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatsTests.randomDeploymentStats;
 
 public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializingTestCase<GetDeploymentStatsAction.Response> {
     @Override
@@ -35,51 +30,12 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
 
     public static GetDeploymentStatsAction.Response createRandom() {
         int numStats = randomIntBetween(0, 2);
-        var stats = new ArrayList<GetDeploymentStatsAction.Response.AllocationStats>(numStats);
+        var stats = new ArrayList<AllocationStats>(numStats);
         for (var i = 0; i < numStats; i++) {
             stats.add(randomDeploymentStats());
         }
-        stats.sort(Comparator.comparing(GetDeploymentStatsAction.Response.AllocationStats::getModelId));
+        stats.sort(Comparator.comparing(AllocationStats::getModelId));
         return new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), stats, stats.size());
     }
 
-    private static GetDeploymentStatsAction.Response.AllocationStats randomDeploymentStats() {
-        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
-        int numNodes = randomIntBetween(1, 4);
-        for (int i = 0; i < numNodes; i++) {
-            var node = new DiscoveryNode("node_" + i, new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT);
-            if (randomBoolean()) {
-                nodeStatsList.add(
-                    GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
-                        node,
-                        randomNonNegativeLong(),
-                        randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
-                        randomIntBetween(0, 100),
-                        Instant.now(),
-                        Instant.now()
-                    )
-                );
-            } else {
-                nodeStatsList.add(
-                    GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
-                        node,
-                        randomFrom(RoutingState.values()),
-                        randomBoolean() ? null : "a good reason"
-                    )
-                );
-            }
-        }
-
-        nodeStatsList.sort(Comparator.comparing(n -> n.getNode().getId()));
-
-        return new GetDeploymentStatsAction.Response.AllocationStats(
-            randomAlphaOfLength(5),
-            ByteSizeValue.ofBytes(randomNonNegativeLong()),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 10000),
-            Instant.now(),
-            nodeStatsList
-        );
-    }
 }

+ 24 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.ingest.IngestStats;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatsTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStatsTests;
 
 import java.util.List;
@@ -33,7 +34,8 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                     id,
                     randomBoolean() ? randomIngestStats() : null,
                     randomIntBetween(0, 10),
-                    randomBoolean() ? InferenceStatsTests.createTestInstance(id, null) : null
+                    randomBoolean() ? InferenceStatsTests.createTestInstance(id, null) : null,
+                    randomBoolean() ? AllocationStatsTests.randomDeploymentStats() : null
                 )
             )
             .collect(Collectors.toList());
@@ -69,6 +71,27 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
 
     @Override
     protected Response mutateInstanceForVersion(Response instance, Version version) {
+        if (version.before(Version.V_8_0_0)) {
+            return new Response(
+                new QueryPage<>(
+                    instance.getResources()
+                        .results()
+                        .stream()
+                        .map(
+                            stats -> new Response.TrainedModelStats(
+                                stats.getModelId(),
+                                stats.getIngestStats(),
+                                stats.getPipelineCount(),
+                                stats.getInferenceStats(),
+                                null
+                            )
+                        )
+                        .collect(Collectors.toList()),
+                    instance.getResources().count(),
+                    RESULTS_FIELD
+                )
+            );
+        }
         return instance;
     }
 

+ 74 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatsTests.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.net.InetAddress;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+public class AllocationStatsTests extends AbstractWireSerializingTestCase<AllocationStats> {
+
+    public static AllocationStats randomDeploymentStats() {
+        List<AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
+        int numNodes = randomIntBetween(1, 4);
+        for (int i = 0; i < numNodes; i++) {
+            var node = new DiscoveryNode("node_" + i, new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT);
+            if (randomBoolean()) {
+                nodeStatsList.add(
+                    AllocationStats.NodeStats.forStartedState(
+                        node,
+                        randomNonNegativeLong(),
+                        randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
+                        randomIntBetween(0, 100),
+                        Instant.now(),
+                        Instant.now()
+                    )
+                );
+            } else {
+                nodeStatsList.add(
+                    AllocationStats.NodeStats.forNotStartedState(
+                        node,
+                        randomFrom(RoutingState.values()),
+                        randomBoolean() ? null : "a good reason"
+                    )
+                );
+            }
+        }
+
+        nodeStatsList.sort(Comparator.comparing(n -> n.getNode().getId()));
+
+        return new AllocationStats(
+            randomAlphaOfLength(5),
+            ByteSizeValue.ofBytes(randomNonNegativeLong()),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 10000),
+            Instant.now(),
+            nodeStatsList
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<AllocationStats> instanceReader() {
+        return AllocationStats::new;
+    }
+
+    @Override
+    protected AllocationStats createTestInstance() {
+        return randomDeploymentStats();
+    }
+}

+ 42 - 80
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -20,7 +20,6 @@ import org.elasticsearch.test.SecuritySettingsSourceField;
 import org.elasticsearch.test.rest.ESRestTestCase;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
-import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.junit.After;
@@ -42,9 +41,9 @@ import java.util.stream.Collectors;
 import static org.elasticsearch.xpack.ml.integration.InferenceIngestIT.putPipeline;
 import static org.elasticsearch.xpack.ml.integration.InferenceIngestIT.simulateRequest;
 import static org.hamcrest.Matchers.containsString;
-import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
@@ -234,22 +233,20 @@ public class PyTorchModelIT extends ESRestTestCase {
 
         CheckedBiConsumer<String, AllocationStatus.State, IOException> assertAtLeast = (modelId, state) -> {
             startDeployment(modelId, state.toString());
-            Response response = getDeploymentStats(modelId);
-            List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(response).get("deployment_stats");
+            Response response = getTrainedModelStats(modelId);
+            List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(response).get("trained_model_stats");
             assertThat(stats, hasSize(1));
-            String statusState = (String) XContentMapValues.extractValue("allocation_status.state", stats.get(0));
+            String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
             assertThat(stats.toString(), statusState, is(not(nullValue())));
             assertThat(AllocationStatus.State.fromString(statusState), greaterThanOrEqualTo(state));
-            Integer byteSize = (Integer) XContentMapValues.extractValue("model_size_bytes", stats.get(0));
+            Integer byteSize = (Integer) XContentMapValues.extractValue("deployment_stats.model_size_bytes", stats.get(0));
             assertThat(byteSize, is(not(nullValue())));
             assertThat(byteSize, equalTo((int) RAW_MODEL_SIZE));
 
-            Response humanResponse = client().performRequest(
-                new Request("GET", "/_ml/trained_models/" + modelId + "/deployment/_stats?human")
-            );
-            stats = (List<Map<String, Object>>) entityAsMap(humanResponse).get("deployment_stats");
+            Response humanResponse = client().performRequest(new Request("GET", "/_ml/trained_models/" + modelId + "/_stats?human"));
+            stats = (List<Map<String, Object>>) entityAsMap(humanResponse).get("trained_model_stats");
             assertThat(stats, hasSize(1));
-            String stringBytes = (String) XContentMapValues.extractValue("model_size", stats.get(0));
+            String stringBytes = (String) XContentMapValues.extractValue("deployment_stats.model_size", stats.get(0));
             assertThat(stringBytes, is(not(nullValue())));
             assertThat(stringBytes, equalTo("1.5kb"));
             stopDeployment(model);
@@ -270,12 +267,15 @@ public class PyTorchModelIT extends ESRestTestCase {
         startDeployment(modelA, AllocationStatus.State.FULLY_ALLOCATED.toString());
         infer("once", modelA);
         infer("twice", modelA);
-        Response response = getDeploymentStats(modelA);
-        List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(response).get("deployment_stats");
+        Response response = getTrainedModelStats(modelA);
+        List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(response).get("trained_model_stats");
         assertThat(stats, hasSize(1));
-        assertThat(stats.get(0).get("model_id"), equalTo(modelA));
-        assertThat(stats.get(0).get("model_size_bytes"), equalTo((int) RAW_MODEL_SIZE));
-        List<Map<String, Object>> nodes = (List<Map<String, Object>>) stats.get(0).get("nodes");
+        assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelA));
+        assertThat(XContentMapValues.extractValue("deployment_stats.model_size_bytes", stats.get(0)), equalTo((int) RAW_MODEL_SIZE));
+        List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
+            "deployment_stats.nodes",
+            stats.get(0)
+        );
         // 2 of the 3 nodes in the cluster are ML nodes
         assertThat(nodes, hasSize(2));
         int inferenceCount = sumInferenceCountOnNodes(nodes);
@@ -288,13 +288,6 @@ public class PyTorchModelIT extends ESRestTestCase {
 
     @SuppressWarnings("unchecked")
     public void testGetDeploymentStats_WithWildcard() throws IOException {
-
-        {
-            // No deployments is an error when allow_no_match == false
-            expectThrows(ResponseException.class, () -> getDeploymentStats("*", false));
-            getDeploymentStats("*", true);
-        }
-
         String modelFoo = "foo";
         createTrainedModel(modelFoo);
         putVocabulary(List.of("once", "twice"), modelFoo);
@@ -310,54 +303,25 @@ public class PyTorchModelIT extends ESRestTestCase {
         infer("once", modelFoo);
         infer("once", modelBar);
         {
-            Response response = getDeploymentStats("*");
+            Response response = getTrainedModelStats("f*");
             Map<String, Object> map = entityAsMap(response);
-            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("deployment_stats");
-            assertThat(stats, hasSize(2));
-            assertThat(stats.get(0).get("model_id"), equalTo(modelBar));
-            assertThat(stats.get(1).get("model_id"), equalTo(modelFoo));
-            List<Map<String, Object>> barNodes = (List<Map<String, Object>>) stats.get(0).get("nodes");
-            // 2 of the 3 nodes in the cluster are ML nodes
-            assertThat(barNodes, hasSize(2));
-            assertThat(sumInferenceCountOnNodes(barNodes), equalTo(1));
-            List<Map<String, Object>> fooNodes = (List<Map<String, Object>>) stats.get(0).get("nodes");
-            assertThat(fooNodes, hasSize(2));
-            assertThat(sumInferenceCountOnNodes(fooNodes), equalTo(1));
-        }
-        {
-            Response response = getDeploymentStats("f*");
-            Map<String, Object> map = entityAsMap(response);
-            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("deployment_stats");
+            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
             assertThat(stats, hasSize(1));
-            assertThat(stats.get(0).get("model_id"), equalTo(modelFoo));
+            assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelFoo));
         }
         {
-            Response response = getDeploymentStats("bar");
+            Response response = getTrainedModelStats("bar");
             Map<String, Object> map = entityAsMap(response);
-            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("deployment_stats");
+            List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
             assertThat(stats, hasSize(1));
-            assertThat(stats.get(0).get("model_id"), equalTo(modelBar));
-        }
-        {
-            ResponseException e = expectThrows(ResponseException.class, () -> getDeploymentStats("c*", false));
-            assertThat(
-                EntityUtils.toString(e.getResponse().getEntity()),
-                containsString("No known trained model with deployment with id [c*]")
-            );
-        }
-        {
-            ResponseException e = expectThrows(ResponseException.class, () -> getDeploymentStats("foo,c*", false));
-            assertThat(
-                EntityUtils.toString(e.getResponse().getEntity()),
-                containsString("No known trained model with deployment with id [c*]")
-            );
+            assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelBar));
         }
     }
 
     @SuppressWarnings("unchecked")
     public void testGetDeploymentStats_WithStartedStoppedDeployments() throws IOException {
         String modelFoo = "foo";
-        String modelBar = "bar";
+        String modelBar = "foo-2";
         createTrainedModel(modelFoo);
         putVocabulary(List.of("once", "twice"), modelFoo);
         putModelDefinition(modelFoo);
@@ -371,45 +335,47 @@ public class PyTorchModelIT extends ESRestTestCase {
         infer("once", modelFoo);
         infer("once", modelBar);
 
-        Response response = getDeploymentStats("*");
+        Response response = getTrainedModelStats("foo*");
         Map<String, Object> map = entityAsMap(response);
-        List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("deployment_stats");
+        List<Map<String, Object>> stats = (List<Map<String, Object>>) map.get("trained_model_stats");
         assertThat(stats, hasSize(2));
 
         // check all nodes are started
         for (int i : new int[] { 0, 1 }) {
-            List<Map<String, Object>> nodes = (List<Map<String, Object>>) stats.get(i).get("nodes");
+            List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
+                "deployment_stats.nodes",
+                stats.get(i)
+            );
             // 2 ml nodes
             assertThat(nodes, hasSize(2));
             for (int j : new int[] { 0, 1 }) {
-                Object state = MapHelper.dig("routing_state.routing_state", nodes.get(j));
+                Object state = XContentMapValues.extractValue("routing_state.routing_state", nodes.get(j));
                 assertEquals("started", state);
             }
         }
 
         stopDeployment(modelFoo);
 
-        response = getDeploymentStats("*");
+        response = getTrainedModelStats("foo*");
         map = entityAsMap(response);
-        stats = (List<Map<String, Object>>) map.get("deployment_stats");
+        stats = (List<Map<String, Object>>) map.get("trained_model_stats");
 
-        assertThat(stats, hasSize(1));
+        assertThat(stats, hasSize(2));
+        assertThat(stats.get(0), not(hasKey("deployment_stats")));
 
-        // check all nodes are started
-        List<Map<String, Object>> nodes = (List<Map<String, Object>>) stats.get(0).get("nodes");
+        // check all nodes are started for the non-stopped deployment
+        List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
+            "deployment_stats.nodes",
+            stats.get(1)
+        );
         // 2 ml nodes
         assertThat(nodes, hasSize(2));
         for (int j : new int[] { 0, 1 }) {
-            Object state = MapHelper.dig("routing_state.routing_state", nodes.get(j));
+            Object state = XContentMapValues.extractValue("routing_state.routing_state", nodes.get(j));
             assertEquals("started", state);
         }
 
         stopDeployment(modelBar);
-
-        response = getDeploymentStats("*");
-        map = entityAsMap(response);
-        stats = (List<Map<String, Object>>) map.get("deployment_stats");
-        assertThat(stats, empty());
     }
 
     public void testInferWithMissingModel() {
@@ -677,12 +643,8 @@ public class PyTorchModelIT extends ESRestTestCase {
         client().performRequest(request);
     }
 
-    private Response getDeploymentStats(String modelId) throws IOException {
-        return getDeploymentStats(modelId, true);
-    }
-
-    private Response getDeploymentStats(String modelId, boolean allowNoMatch) throws IOException {
-        Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/deployment/_stats?allow_no_match=" + allowNoMatch);
+    private Response getTrainedModelStats(String modelId) throws IOException {
+        Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats");
         return client().performRequest(request);
     }
 

+ 0 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -374,7 +374,6 @@ import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction;
 import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction;
-import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelDeploymentStatsAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestInferTrainedModelDeploymentAction;
@@ -1191,7 +1190,6 @@ public class MachineLearning extends Plugin
             new RestPutTrainedModelAliasAction(),
             new RestDeleteTrainedModelAliasAction(),
             new RestPreviewDataFrameAnalyticsAction(),
-            new RestGetTrainedModelDeploymentStatsAction(),
             new RestStartTrainedModelDeploymentAction(),
             new RestStopTrainedModelDeploymentAction(),
             new RestInferTrainedModelDeploymentAction(),

+ 23 - 43
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -26,10 +26,10 @@ import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
 import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
@@ -52,7 +52,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
     TrainedModelDeploymentTask,
     GetDeploymentStatsAction.Request,
     GetDeploymentStatsAction.Response,
-    GetDeploymentStatsAction.Response.AllocationStats> {
+    AllocationStats> {
 
     @Inject
     public TransportGetDeploymentStatsAction(
@@ -67,7 +67,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
             actionFilters,
             GetDeploymentStatsAction.Request::new,
             GetDeploymentStatsAction.Response::new,
-            GetDeploymentStatsAction.Response.AllocationStats::new,
+            AllocationStats::new,
             ThreadPool.Names.MANAGEMENT
         );
     }
@@ -75,18 +75,18 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
     @Override
     protected GetDeploymentStatsAction.Response newResponse(
         GetDeploymentStatsAction.Request request,
-        List<GetDeploymentStatsAction.Response.AllocationStats> taskResponse,
+        List<AllocationStats> taskResponse,
         List<TaskOperationFailure> taskOperationFailures,
         List<FailedNodeException> failedNodeExceptions
     ) {
         // group the stats by model and merge individual node stats
         var mergedNodeStatsByModel = taskResponse.stream()
-            .collect(Collectors.toMap(GetDeploymentStatsAction.Response.AllocationStats::getModelId, Function.identity(), (l, r) -> {
+            .collect(Collectors.toMap(AllocationStats::getModelId, Function.identity(), (l, r) -> {
                 l.getNodeStats().addAll(r.getNodeStats());
                 return l;
             }, TreeMap::new));
 
-        List<GetDeploymentStatsAction.Response.AllocationStats> bunchedAndSorted = new ArrayList<>(mergedNodeStatsByModel.values());
+        List<AllocationStats> bunchedAndSorted = new ArrayList<>(mergedNodeStatsByModel.values());
 
         return new GetDeploymentStatsAction.Response(
             taskOperationFailures,
@@ -130,12 +130,8 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
         }
 
         // check request has been satisfied
-        ExpandedIdsMatcher requiredIdsMatcher = new ExpandedIdsMatcher(tokenizedRequestIds, request.isAllowNoMatch());
+        ExpandedIdsMatcher requiredIdsMatcher = new ExpandedIdsMatcher(tokenizedRequestIds, true);
         requiredIdsMatcher.filterMatchedIds(matchedDeploymentIds);
-        if (requiredIdsMatcher.hasUnmatchedIds()) {
-            listener.onFailure(ExceptionsHelper.missingDeployment(requiredIdsMatcher.unmatchedIdsString()));
-            return;
-        }
         if (matchedDeploymentIds.isEmpty()) {
             listener.onResponse(
                 new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0L)
@@ -157,7 +153,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 .filter(StartTrainedModelDeploymentAction.TaskParams::mayAllocateToNode)
                 .collect(Collectors.toList());
             // Set the allocation state and reason if we have it
-            for (GetDeploymentStatsAction.Response.AllocationStats stats : updatedResponse.getStats().results()) {
+            for (AllocationStats stats : updatedResponse.getStats().results()) {
                 Optional<TrainedModelAllocation> modelAllocation = Optional.ofNullable(allocation.getModelAllocation(stats.getModelId()));
                 TrainedModelAllocation trainedModelAllocation = modelAllocation.orElse(null);
                 if (trainedModelAllocation != null) {
@@ -195,15 +191,15 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
             .stream()
             .collect(Collectors.toMap(TrainedModelAllocation::getModelId, Function.identity()));
 
-        final List<GetDeploymentStatsAction.Response.AllocationStats> updatedAllocationStats = new ArrayList<>();
+        final List<AllocationStats> updatedAllocationStats = new ArrayList<>();
 
-        for (GetDeploymentStatsAction.Response.AllocationStats stat : tasksResponse.getStats().results()) {
+        for (AllocationStats stat : tasksResponse.getStats().results()) {
             if (modelToAllocationWithNonStartedRoutes.containsKey(stat.getModelId())) {
                 // there is merging to be done
                 Map<String, RoutingStateAndReason> nodeToRoutingStates = allocationNonStartedRoutes.get(
                     modelToAllocationWithNonStartedRoutes.get(stat.getModelId())
                 );
-                List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> updatedNodeStats = new ArrayList<>();
+                List<AllocationStats.NodeStats> updatedNodeStats = new ArrayList<>();
 
                 Set<String> visitedNodes = new HashSet<>();
                 for (var nodeStat : stat.getNodeStats()) {
@@ -214,7 +210,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                         // of the state of the task - it may be starting, started, stopping, or stopped.
                         RoutingStateAndReason stateAndReason = nodeToRoutingStates.get(nodeStat.getNode().getId());
                         updatedNodeStats.add(
-                            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                            AllocationStats.NodeStats.forNotStartedState(
                                 nodeStat.getNode(),
                                 stateAndReason.getState(),
                                 stateAndReason.getReason()
@@ -231,7 +227,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 for (var nodeRoutingState : nodeToRoutingStates.entrySet()) {
                     if (visitedNodes.contains(nodeRoutingState.getKey()) == false) {
                         updatedNodeStats.add(
-                            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                            AllocationStats.NodeStats.forNotStartedState(
                                 nodes.get(nodeRoutingState.getKey()),
                                 nodeRoutingState.getValue().getState(),
                                 nodeRoutingState.getValue().getReason()
@@ -242,7 +238,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
 
                 updatedNodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
                 updatedAllocationStats.add(
-                    new GetDeploymentStatsAction.Response.AllocationStats(
+                    new AllocationStats(
                         stat.getModelId(),
                         stat.getModelSize(),
                         stat.getInferenceThreads(),
@@ -264,11 +260,11 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
             if (tasksResponse.getStats().results().stream().anyMatch(e -> modelId.equals(e.getModelId())) == false) {
 
                 // no tasks for this model so build the allocation stats from the non-started states
-                List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStats = new ArrayList<>();
+                List<AllocationStats.NodeStats> nodeStats = new ArrayList<>();
 
                 for (var routingEntry : nonStartedEntries.getValue().entrySet()) {
                     nodeStats.add(
-                        GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                        AllocationStats.NodeStats.forNotStartedState(
                             nodes.get(routingEntry.getKey()),
                             routingEntry.getValue().getState(),
                             routingEntry.getValue().getReason()
@@ -278,21 +274,11 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
 
                 nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
 
-                updatedAllocationStats.add(
-                    new GetDeploymentStatsAction.Response.AllocationStats(
-                        modelId,
-                        null,
-                        null,
-                        null,
-                        null,
-                        allocation.getStartTime(),
-                        nodeStats
-                    )
-                );
+                updatedAllocationStats.add(new AllocationStats(modelId, null, null, null, null, allocation.getStartTime(), nodeStats));
             }
         }
 
-        updatedAllocationStats.sort(Comparator.comparing(GetDeploymentStatsAction.Response.AllocationStats::getModelId));
+        updatedAllocationStats.sort(Comparator.comparing(AllocationStats::getModelId));
 
         return new GetDeploymentStatsAction.Response(
             tasksResponse.getTaskFailures(),
@@ -306,15 +292,15 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
     protected void taskOperation(
         GetDeploymentStatsAction.Request request,
         TrainedModelDeploymentTask task,
-        ActionListener<GetDeploymentStatsAction.Response.AllocationStats> listener
+        ActionListener<AllocationStats> listener
     ) {
         Optional<ModelStats> stats = task.modelStats();
 
-        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStats = new ArrayList<>();
+        List<AllocationStats.NodeStats> nodeStats = new ArrayList<>();
 
         if (stats.isPresent()) {
             nodeStats.add(
-                GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                AllocationStats.NodeStats.forStartedState(
                     clusterService.localNode(),
                     stats.get().getTimingStats().getCount(),
                     // avoid reporting the average time as 0 if count < 1
@@ -327,17 +313,11 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
         } else {
             // if there are no stats the process is missing.
             // Either because it is starting or stopped
-            nodeStats.add(
-                GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
-                    clusterService.localNode(),
-                    RoutingState.STOPPED,
-                    ""
-                )
-            );
+            nodeStats.add(AllocationStats.NodeStats.forNotStartedState(clusterService.localNode(), RoutingState.STOPPED, ""));
         }
 
         listener.onResponse(
-            new GetDeploymentStatsAction.Response.AllocationStats(
+            new AllocationStats(
                 task.getModelId(),
                 ByteSizeValue.ofBytes(task.getParams().getModelBytes()),
                 task.getParams().getInferenceThreads(),

+ 22 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java

@@ -27,7 +27,9 @@ import org.elasticsearch.ingest.IngestStats;
 import org.elasticsearch.ingest.Pipeline;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
@@ -83,15 +85,31 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
         final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterService.state());
         GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
 
-        ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(
-            inferenceStats -> listener.onResponse(
-                responseBuilder.setInferenceStatsByModelId(
-                    inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
+        ActionListener<GetDeploymentStatsAction.Response> getDeploymentStats = ActionListener.wrap(
+            deploymentStats -> listener.onResponse(
+                responseBuilder.setDeploymentStatsByModelId(
+                    deploymentStats.getStats()
+                        .results()
+                        .stream()
+                        .collect(Collectors.toMap(AllocationStats::getModelId, Function.identity()))
                 ).build()
             ),
             listener::onFailure
         );
 
+        ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(inferenceStats -> {
+            responseBuilder.setInferenceStatsByModelId(
+                inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
+            );
+            executeAsyncWithOrigin(
+                client,
+                ML_ORIGIN,
+                GetDeploymentStatsAction.INSTANCE,
+                new GetDeploymentStatsAction.Request(request.getResourceId()),
+                getDeploymentStats
+            );
+        }, listener::onFailure);
+
         ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(nodesStatsResponse -> {
             Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedIdsWithAliases()
                 .entrySet()

+ 0 - 53
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelDeploymentStatsAction.java

@@ -1,53 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.ml.rest.inference;
-
-import org.elasticsearch.client.node.NodeClient;
-import org.elasticsearch.rest.BaseRestHandler;
-import org.elasticsearch.rest.RestRequest;
-import org.elasticsearch.rest.action.RestToXContentListener;
-import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
-import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
-import org.elasticsearch.xpack.ml.MachineLearning;
-
-import java.io.IOException;
-import java.util.Collections;
-import java.util.List;
-
-import static org.elasticsearch.rest.RestRequest.Method.GET;
-
-public class RestGetTrainedModelDeploymentStatsAction extends BaseRestHandler {
-
-    @Override
-    public String getName() {
-        return "ml_get_trained_models_deployment_stats_action";
-    }
-
-    @Override
-    public List<Route> routes() {
-        return Collections.singletonList(
-            new Route(
-                GET,
-                MachineLearning.BASE_PATH
-                    + "trained_models/{"
-                    + StartTrainedModelDeploymentAction.Request.MODEL_ID.getPreferredName()
-                    + "}/deployment/_stats"
-            )
-        );
-    }
-
-    @Override
-    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
-        String modelId = restRequest.param(StartTrainedModelDeploymentAction.Request.MODEL_ID.getPreferredName());
-        GetDeploymentStatsAction.Request request = new GetDeploymentStatsAction.Request(modelId);
-
-        request.setAllowNoMatch(restRequest.paramAsBoolean(GetDeploymentStatsAction.Request.ALLOW_NO_MATCH, request.isAllowNoMatch()));
-
-        return channel -> client.execute(GetDeploymentStatsAction.INSTANCE, request, new RestToXContentListener<>(channel));
-    }
-}

+ 7 - 12
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -59,6 +59,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
@@ -339,16 +340,10 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                 List.of(),
                 List.of(),
                 List.of(
-                    new GetDeploymentStatsAction.Response.AllocationStats(
-                        "model_3",
-                        ByteSizeValue.ofMb(100),
-                        null,
-                        null,
-                        null,
-                        Instant.now(),
-                        List.of()
-                    ).setState(AllocationState.STOPPING),
-                    new GetDeploymentStatsAction.Response.AllocationStats(
+                    new AllocationStats("model_3", ByteSizeValue.ofMb(100), null, null, null, Instant.now(), List.of()).setState(
+                        AllocationState.STOPPING
+                    ),
+                    new AllocationStats(
                         "model_4",
                         ByteSizeValue.ofMb(200),
                         2,
@@ -356,7 +351,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                         1000,
                         Instant.now(),
                         List.of(
-                            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                            AllocationStats.NodeStats.forStartedState(
                                 new DiscoveryNode("foo", new TransportAddress(TransportAddress.META_ADDRESS, 2), Version.CURRENT),
                                 5,
                                 42.0,
@@ -364,7 +359,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                 Instant.now(),
                                 Instant.now()
                             ),
-                            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                            AllocationStats.NodeStats.forStartedState(
                                 new DiscoveryNode("bar", new TransportAddress(TransportAddress.META_ADDRESS, 3), Version.CURRENT),
                                 4,
                                 50.0,

+ 17 - 16
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
@@ -60,7 +61,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
 
         DiscoveryNodes nodes = buildNodes("nodeA", "nodeB");
         var modified = TransportGetDeploymentStatsAction.addFailedRoutes(emptyResponse, badRoutes, nodes);
-        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
+        List<AllocationStats> results = modified.getStats().results();
         assertThat(results, hasSize(2));
         assertEquals("model1", results.get(0).getModelId());
         assertThat(results.get(0).getNodeStats(), hasSize(2));
@@ -73,9 +74,9 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
     public void testAddFailedRoutes_GivenMixedResponses() throws UnknownHostException {
         DiscoveryNodes nodes = buildNodes("node1", "node2", "node3");
 
-        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
+        List<AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
         nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+            AllocationStats.NodeStats.forStartedState(
                 nodes.get("node1"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
@@ -85,7 +86,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
             )
         );
         nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+            AllocationStats.NodeStats.forStartedState(
                 nodes.get("node2"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
@@ -95,7 +96,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
             )
         );
 
-        var model1 = new GetDeploymentStatsAction.Response.AllocationStats(
+        var model1 = new AllocationStats(
             "model1",
             ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
@@ -113,7 +114,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
         var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), List.of(model1), 1);
 
         var modified = TransportGetDeploymentStatsAction.addFailedRoutes(response, badRoutes, nodes);
-        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
+        List<AllocationStats> results = modified.getStats().results();
         assertThat(results, hasSize(1));
         assertThat(results.get(0).getNodeStats(), hasSize(3));
         assertEquals("node1", results.get(0).getNodeStats().get(0).getNode().getId());
@@ -127,9 +128,9 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
     public void testAddFailedRoutes_TaskResultIsOverwritten() throws UnknownHostException {
         DiscoveryNodes nodes = buildNodes("node1", "node2");
 
-        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
+        List<AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
         nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+            AllocationStats.NodeStats.forStartedState(
                 nodes.get("node1"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
@@ -139,7 +140,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
             )
         );
         nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+            AllocationStats.NodeStats.forStartedState(
                 nodes.get("node2"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
@@ -149,7 +150,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
             )
         );
 
-        var model1 = new GetDeploymentStatsAction.Response.AllocationStats(
+        var model1 = new AllocationStats(
             "model1",
             ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
@@ -167,7 +168,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
         badRoutes.put(createAllocation("model1"), nodeRoutes);
 
         var modified = TransportGetDeploymentStatsAction.addFailedRoutes(response, badRoutes, nodes);
-        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
+        List<AllocationStats> results = modified.getStats().results();
         assertThat(results, hasSize(1));
         assertThat(results.get(0).getNodeStats(), hasSize(2));
         assertEquals("node1", results.get(0).getNodeStats().get(0).getNode().getId());
@@ -187,14 +188,14 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
         return builder.build();
     }
 
-    private GetDeploymentStatsAction.Response.AllocationStats randomDeploymentStats() {
-        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
+    private AllocationStats randomDeploymentStats() {
+        List<AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
         int numNodes = randomIntBetween(1, 4);
         for (int i = 0; i < numNodes; i++) {
             var node = new DiscoveryNode("node_" + i, new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT);
             if (randomBoolean()) {
                 nodeStatsList.add(
-                    GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                    AllocationStats.NodeStats.forStartedState(
                         node,
                         randomNonNegativeLong(),
                         randomDoubleBetween(0.0, 100.0, true),
@@ -205,7 +206,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
                 );
             } else {
                 nodeStatsList.add(
-                    GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                    AllocationStats.NodeStats.forNotStartedState(
                         node,
                         randomFrom(RoutingState.values()),
                         randomBoolean() ? null : "a good reason"
@@ -216,7 +217,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
 
         nodeStatsList.sort(Comparator.comparing(n -> n.getNode().getId()));
 
-        return new GetDeploymentStatsAction.Response.AllocationStats(
+        return new AllocationStats(
             randomAlphaOfLength(5),
             ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),

+ 1 - 1
x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java

@@ -304,7 +304,7 @@ public class Constants {
         "cluster:monitor/xpack/ml/job/results/records/get",
         "cluster:monitor/xpack/ml/job/stats/get",
         "cluster:monitor/xpack/ml/trained_models/deployment/infer",
-        "cluster:monitor/xpack/ml/trained_models/deployments/stats/get",
+        "cluster:internal/xpack/ml/trained_models/deployments/stats/get",
         "cluster:monitor/xpack/repositories_metering/clear_metering_archive",
         "cluster:monitor/xpack/repositories_metering/get_metrics",
         "cluster:monitor/xpack/rollup/get",