Răsfoiți Sursa

[ML] Report start_time for trained model deployments and allocations (#80188)

Adds `start_time` to the get deployment stats API for the deployment
and each allocation.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dimitris Athanasiou 4 ani în urmă
părinte
comite
d13baade69

+ 8 - 0
docs/reference/ml/df-analytics/apis/get-trained-model-deployment-stats.asciidoc

@@ -73,6 +73,10 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 (<<byte-units,byte value>>)
 The size of the loaded model in bytes.
 
+`start_time`:::
+(long)
+The epoch timestamp, in milliseconds, when the deployment started.
+
 `state`:::
 (string)
 The overall state of the deployment. The values may be:
@@ -171,6 +175,10 @@ The current routing state and reason for the current routing state for this allo
 (string)
 The reason for the current state. Usually only populated when the `routing_state` is `failed`.
 
+`start_time`:::
+(long)
+The epoch timestamp, in milliseconds, when the allocation started on the node.
+
 =====
 ====
 

+ 2 - 1
docs/reference/ml/df-analytics/apis/start-trained-model-deployment.asciidoc

@@ -96,7 +96,8 @@ The API returns the following results:
                 "reason": ""
             }
         },
-        "allocation_state": "started"
+        "allocation_state": "started",
+        "start_time": "2021-11-02T11:50:34.766591Z"
     }
 }
 ----

+ 29 - 125
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java

@@ -13,7 +13,6 @@ import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.tasks.BaseTasksRequest;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -32,14 +31,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.time.Instant;
-import java.util.ArrayList;
 import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
-import java.util.Set;
 
 public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsAction.Response> {
 
@@ -129,13 +123,15 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 private final Instant lastAccess;
                 private final Integer pendingCount;
                 private final RoutingStateAndReason routingState;
+                private final Instant startTime;
 
                 public static NodeStats forStartedState(
                     DiscoveryNode node,
                     long inferenceCount,
                     Double avgInferenceTime,
                     int pendingCount,
-                    Instant lastAccess
+                    Instant lastAccess,
+                    Instant startTime
                 ) {
                     return new NodeStats(
                         node,
@@ -143,12 +139,13 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                         avgInferenceTime,
                         lastAccess,
                         pendingCount,
-                        new RoutingStateAndReason(RoutingState.STARTED, null)
+                        new RoutingStateAndReason(RoutingState.STARTED, null),
+                        Objects.requireNonNull(startTime)
                     );
                 }
 
                 public static NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
-                    return new NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason));
+                    return new NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null);
                 }
 
                 private NodeStats(
@@ -157,7 +154,8 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     Double avgInferenceTime,
                     Instant lastAccess,
                     Integer pendingCount,
-                    RoutingStateAndReason routingState
+                    RoutingStateAndReason routingState,
+                    @Nullable Instant startTime
                 ) {
                     this.node = node;
                     this.inferenceCount = inferenceCount;
@@ -165,6 +163,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     this.lastAccess = lastAccess;
                     this.pendingCount = pendingCount;
                     this.routingState = routingState;
+                    this.startTime = startTime;
 
                     // if lastAccess time is null there have been no inferences
                     assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
@@ -177,6 +176,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     this.lastAccess = in.readOptionalInstant();
                     this.pendingCount = in.readOptionalVInt();
                     this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
+                    this.startTime = in.readOptionalInstant();
                 }
 
                 public DiscoveryNode getNode() {
@@ -208,6 +208,9 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     if (pendingCount != null) {
                         builder.field("number_of_pending_requests", pendingCount);
                     }
+                    if (startTime != null) {
+                        builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
+                    }
                     builder.endObject();
                     return builder;
                 }
@@ -220,6 +223,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     out.writeOptionalInstant(lastAccess);
                     out.writeOptionalVInt(pendingCount);
                     out.writeOptionalWriteable(routingState);
+                    out.writeOptionalInstant(startTime);
                 }
 
                 @Override
@@ -232,12 +236,13 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                         && Objects.equals(node, that.node)
                         && Objects.equals(lastAccess, that.lastAccess)
                         && Objects.equals(pendingCount, that.pendingCount)
-                        && Objects.equals(routingState, that.routingState);
+                        && Objects.equals(routingState, that.routingState)
+                        && Objects.equals(startTime, that.startTime);
                 }
 
                 @Override
                 public int hashCode() {
-                    return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState);
+                    return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState, startTime);
                 }
             }
 
@@ -253,6 +258,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
             private final Integer modelThreads;
             @Nullable
             private final Integer queueCapacity;
+            private final Instant startTime;
             private final List<NodeStats> nodeStats;
 
             public AllocationStats(
@@ -261,6 +267,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 @Nullable Integer inferenceThreads,
                 @Nullable Integer modelThreads,
                 @Nullable Integer queueCapacity,
+                Instant startTime,
                 List<NodeStats> nodeStats
             ) {
                 this.modelId = modelId;
@@ -268,6 +275,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 this.inferenceThreads = inferenceThreads;
                 this.modelThreads = modelThreads;
                 this.queueCapacity = queueCapacity;
+                this.startTime = Objects.requireNonNull(startTime);
                 this.nodeStats = nodeStats;
                 this.state = null;
                 this.reason = null;
@@ -279,6 +287,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 inferenceThreads = in.readOptionalVInt();
                 modelThreads = in.readOptionalVInt();
                 queueCapacity = in.readOptionalVInt();
+                startTime = in.readInstant();
                 nodeStats = in.readList(NodeStats::new);
                 state = in.readOptionalEnum(AllocationState.class);
                 reason = in.readOptionalString();
@@ -308,6 +317,10 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 return queueCapacity;
             }
 
+            public Instant getStartTime() {
+                return startTime;
+            }
+
             public List<NodeStats> getNodeStats() {
                 return nodeStats;
             }
@@ -360,6 +373,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 if (allocationStatus != null) {
                     builder.field("allocation_status", allocationStatus);
                 }
+                builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
                 builder.startArray("nodes");
                 for (NodeStats nodeStat : nodeStats) {
                     nodeStat.toXContent(builder, params);
@@ -376,6 +390,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 out.writeOptionalVInt(inferenceThreads);
                 out.writeOptionalVInt(modelThreads);
                 out.writeOptionalVInt(queueCapacity);
+                out.writeInstant(startTime);
                 out.writeList(nodeStats);
                 out.writeOptionalEnum(state);
                 out.writeOptionalString(reason);
@@ -392,6 +407,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     && Objects.equals(inferenceThreads, that.inferenceThreads)
                     && Objects.equals(modelThreads, that.modelThreads)
                     && Objects.equals(queueCapacity, that.queueCapacity)
+                    && Objects.equals(startTime, that.startTime)
                     && Objects.equals(state, that.state)
                     && Objects.equals(reason, that.reason)
                     && Objects.equals(allocationStatus, that.allocationStatus)
@@ -406,6 +422,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     inferenceThreads,
                     modelThreads,
                     queueCapacity,
+                    startTime,
                     nodeStats,
                     state,
                     reason,
@@ -464,118 +481,5 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
             return Objects.hash(super.hashCode(), stats);
         }
 
-        /**
-         * Update the collected task responses with the non-started
-         * allocation information. The result is the task responses
-         * merged with the non-started model allocations.
-         *
-         * Where there is a merge collision for the pair {@code <model_id, node_id>}
-         * the non-started allocations are used.
-         *
-         * @param tasksResponse All the responses from the tasks
-         * @param nonStartedModelRoutes Non-started routes
-         * @param nodes current cluster nodes
-         * @return The result of merging tasksResponse and the non-started routes
-         */
-        public static GetDeploymentStatsAction.Response addFailedRoutes(
-            GetDeploymentStatsAction.Response tasksResponse,
-            Map<String, Map<String, RoutingStateAndReason>> nonStartedModelRoutes,
-            DiscoveryNodes nodes
-        ) {
-
-            List<GetDeploymentStatsAction.Response.AllocationStats> updatedAllocationStats = new ArrayList<>();
-
-            for (GetDeploymentStatsAction.Response.AllocationStats stat : tasksResponse.getStats().results()) {
-                if (nonStartedModelRoutes.containsKey(stat.getModelId())) {
-                    // there is merging to be done
-                    Map<String, RoutingStateAndReason> nodeToRoutingStates = nonStartedModelRoutes.get(stat.getModelId());
-                    List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> updatedNodeStats = new ArrayList<>();
-
-                    Set<String> visitedNodes = new HashSet<>();
-                    for (var nodeStat : stat.getNodeStats()) {
-                        if (nodeToRoutingStates.containsKey(nodeStat.getNode().getId())) {
-                            // conflict as there is both a task response for the model/node pair
-                            // and we have a non-started routing entry.
-                            // Prefer the entry from nonStartedModelRoutes as we cannot be sure
-                            // of the state of the task - it may be starting, started, stopping, or stopped.
-                            RoutingStateAndReason stateAndReason = nodeToRoutingStates.get(nodeStat.getNode().getId());
-                            updatedNodeStats.add(
-                                GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
-                                    nodeStat.getNode(),
-                                    stateAndReason.getState(),
-                                    stateAndReason.getReason()
-                                )
-                            );
-                        } else {
-                            updatedNodeStats.add(nodeStat);
-                        }
-
-                        visitedNodes.add(nodeStat.node.getId());
-                    }
-
-                    // add nodes from the failures that were not in the task responses
-                    for (var nodeRoutingState : nodeToRoutingStates.entrySet()) {
-                        if (visitedNodes.contains(nodeRoutingState.getKey()) == false) {
-                            updatedNodeStats.add(
-                                GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
-                                    nodes.get(nodeRoutingState.getKey()),
-                                    nodeRoutingState.getValue().getState(),
-                                    nodeRoutingState.getValue().getReason()
-                                )
-                            );
-                        }
-                    }
-
-                    updatedNodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
-                    updatedAllocationStats.add(
-                        new GetDeploymentStatsAction.Response.AllocationStats(
-                            stat.getModelId(),
-                            stat.getModelSize(),
-                            stat.getInferenceThreads(),
-                            stat.getModelThreads(),
-                            stat.getQueueCapacity(),
-                            updatedNodeStats
-                        )
-                    );
-                } else {
-                    updatedAllocationStats.add(stat);
-                }
-            }
-
-            // Merge any models in the non-started that were not in the task responses
-            for (var nonStartedEntries : nonStartedModelRoutes.entrySet()) {
-                String modelId = nonStartedEntries.getKey();
-                if (tasksResponse.getStats().results().stream().anyMatch(e -> modelId.equals(e.getModelId())) == false) {
-
-                    // no tasks for this model so build the allocation stats from the non-started states
-                    List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStats = new ArrayList<>();
-
-                    for (var routingEntry : nonStartedEntries.getValue().entrySet()) {
-                        nodeStats.add(
-                            AllocationStats.NodeStats.forNotStartedState(
-                                nodes.get(routingEntry.getKey()),
-                                routingEntry.getValue().getState(),
-                                routingEntry.getValue().getReason()
-                            )
-                        );
-                    }
-
-                    nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
-
-                    updatedAllocationStats.add(
-                        new GetDeploymentStatsAction.Response.AllocationStats(modelId, null, null, null, null, nodeStats)
-                    );
-                }
-            }
-
-            updatedAllocationStats.sort(Comparator.comparing(GetDeploymentStatsAction.Response.AllocationStats::getModelId));
-
-            return new GetDeploymentStatsAction.Response(
-                tasksResponse.getTaskFailures(),
-                tasksResponse.getNodeFailures(),
-                updatedAllocationStats,
-                updatedAllocationStats.size()
-            );
-        }
     }
 }

+ 43 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java

@@ -15,14 +15,17 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.common.time.TimeUtils;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
+import java.time.Instant;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -43,6 +46,7 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
     private static final ParseField ALLOCATION_STATE = new ParseField("allocation_state");
     private static final ParseField ROUTING_TABLE = new ParseField("routing_table");
     private static final ParseField TASK_PARAMETERS = new ParseField("task_parameters");
+    private static final ParseField START_TIME = new ParseField("start_time");
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<TrainedModelAllocation, Void> PARSER = new ConstructingObjectParser<>(
@@ -52,7 +56,8 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
             (StartTrainedModelDeploymentAction.TaskParams) a[0],
             (Map<String, RoutingStateAndReason>) a[1],
             AllocationState.fromString((String) a[2]),
-            (String) a[3]
+            (String) a[3],
+            (Instant) a[4]
         )
     );
     static {
@@ -68,12 +73,19 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         );
         PARSER.declareString(ConstructingObjectParser.constructorArg(), ALLOCATION_STATE);
         PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
+        PARSER.declareField(
+            ConstructingObjectParser.constructorArg(),
+            p -> TimeUtils.parseTimeFieldToInstant(p, START_TIME.getPreferredName()),
+            START_TIME,
+            ObjectParser.ValueType.VALUE
+        );
     }
 
     private final StartTrainedModelDeploymentAction.TaskParams taskParams;
     private final Map<String, RoutingStateAndReason> nodeRoutingTable;
     private final AllocationState allocationState;
     private final String reason;
+    private final Instant startTime;
 
     public static TrainedModelAllocation fromXContent(XContentParser parser) throws IOException {
         return PARSER.apply(parser, null);
@@ -83,12 +95,14 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         StartTrainedModelDeploymentAction.TaskParams taskParams,
         Map<String, RoutingStateAndReason> nodeRoutingTable,
         AllocationState allocationState,
-        String reason
+        String reason,
+        Instant startTime
     ) {
         this.taskParams = ExceptionsHelper.requireNonNull(taskParams, TASK_PARAMETERS);
         this.nodeRoutingTable = ExceptionsHelper.requireNonNull(nodeRoutingTable, ROUTING_TABLE);
         this.allocationState = ExceptionsHelper.requireNonNull(allocationState, ALLOCATION_STATE);
         this.reason = reason;
+        this.startTime = ExceptionsHelper.requireNonNull(startTime, START_TIME);
     }
 
     public TrainedModelAllocation(StreamInput in) throws IOException {
@@ -96,6 +110,7 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingStateAndReason::new);
         this.allocationState = in.readEnum(AllocationState.class);
         this.reason = in.readOptionalString();
+        this.startTime = in.readInstant();
     }
 
     public boolean isRoutedToNode(String nodeId) {
@@ -106,6 +121,10 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         return Collections.unmodifiableMap(nodeRoutingTable);
     }
 
+    public String getModelId() {
+        return taskParams.getModelId();
+    }
+
     public StartTrainedModelDeploymentAction.TaskParams getTaskParams() {
         return taskParams;
     }
@@ -126,6 +145,10 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         return Optional.ofNullable(reason);
     }
 
+    public Instant getStartTime() {
+        return startTime;
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -134,12 +157,13 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         return Objects.equals(nodeRoutingTable, that.nodeRoutingTable)
             && Objects.equals(taskParams, that.taskParams)
             && Objects.equals(reason, that.reason)
-            && Objects.equals(allocationState, that.allocationState);
+            && Objects.equals(allocationState, that.allocationState)
+            && Objects.equals(startTime, that.startTime);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(nodeRoutingTable, taskParams, allocationState, reason);
+        return Objects.hash(nodeRoutingTable, taskParams, allocationState, reason, startTime);
     }
 
     @Override
@@ -151,6 +175,7 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         if (reason != null) {
             builder.field(REASON.getPreferredName(), reason);
         }
+        builder.timeField(START_TIME.getPreferredName(), startTime);
         builder.endObject();
         return builder;
     }
@@ -161,6 +186,7 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         out.writeMap(nodeRoutingTable, StreamOutput::writeString, (o, w) -> w.writeTo(o));
         out.writeEnum(allocationState);
         out.writeOptionalString(reason);
+        out.writeInstant(startTime);
     }
 
     public Optional<AllocationStatus> calculateAllocationStatus(List<DiscoveryNode> allocatableNodes) {
@@ -189,9 +215,16 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         private AllocationState allocationState;
         private boolean isChanged;
         private String reason;
+        private Instant startTime;
 
         public static Builder fromAllocation(TrainedModelAllocation allocation) {
-            return new Builder(allocation.taskParams, allocation.nodeRoutingTable, allocation.allocationState, allocation.reason);
+            return new Builder(
+                allocation.taskParams,
+                allocation.nodeRoutingTable,
+                allocation.allocationState,
+                allocation.reason,
+                allocation.startTime
+            );
         }
 
         public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskParams) {
@@ -202,18 +235,18 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
             StartTrainedModelDeploymentAction.TaskParams taskParams,
             Map<String, RoutingStateAndReason> nodeRoutingTable,
             AllocationState allocationState,
-            String reason
+            String reason,
+            Instant startTime
         ) {
             this.taskParams = taskParams;
             this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable);
             this.allocationState = allocationState;
             this.reason = reason;
+            this.startTime = startTime;
         }
 
         private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) {
-            this.nodeRoutingTable = new LinkedHashMap<>();
-            this.taskParams = taskParams;
-            this.allocationState = AllocationState.STARTING;
+            this(taskParams, new LinkedHashMap<>(), AllocationState.STARTING, null, Instant.now());
         }
 
         public Builder addNewRoutingEntry(String nodeId) {
@@ -331,7 +364,7 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         }
 
         public TrainedModelAllocation build() {
-            return new TrainedModelAllocation(taskParams, nodeRoutingTable, allocationState, reason);
+            return new TrainedModelAllocation(taskParams, nodeRoutingTable, allocationState, reason, startTime);
         }
     }
 

+ 7 - 154
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsActionResponseTests.java

@@ -9,25 +9,18 @@ package org.elasticsearch.xpack.core.ml.action;
 
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
-import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 
 import java.net.InetAddress;
-import java.net.UnknownHostException;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
-
-import static org.hamcrest.Matchers.hasSize;
 
 public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializingTestCase<GetDeploymentStatsAction.Response> {
     @Override
@@ -37,6 +30,10 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
 
     @Override
     protected GetDeploymentStatsAction.Response createTestInstance() {
+        return createRandom();
+    }
+
+    public static GetDeploymentStatsAction.Response createRandom() {
         int numStats = randomIntBetween(0, 2);
         var stats = new ArrayList<GetDeploymentStatsAction.Response.AllocationStats>(numStats);
         for (var i = 0; i < numStats; i++) {
@@ -46,153 +43,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
         return new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), stats, stats.size());
     }
 
-    public void testAddFailedRoutes_GivenNoFailures() throws UnknownHostException {
-        var response = createTestInstance();
-        var modifed = GetDeploymentStatsAction.Response.addFailedRoutes(response, Collections.emptyMap(), buildNodes("node_foo"));
-        assertEquals(response, modifed);
-    }
-
-    public void testAddFailedRoutes_GivenNoTaskResponses() throws UnknownHostException {
-        var emptyResponse = new GetDeploymentStatsAction.Response(
-            Collections.emptyList(),
-            Collections.emptyList(),
-            Collections.emptyList(),
-            0
-        );
-
-        Map<String, Map<String, RoutingStateAndReason>> badRoutes = new HashMap<>();
-        for (var modelId : new String[] { "model1", "model2" }) {
-            Map<String, RoutingStateAndReason> nodeRoutes = new HashMap<>();
-            for (var nodeId : new String[] { "nodeA", "nodeB" }) {
-                nodeRoutes.put(nodeId, new RoutingStateAndReason(RoutingState.FAILED, "failure reason"));
-            }
-            badRoutes.put(modelId, nodeRoutes);
-        }
-
-        DiscoveryNodes nodes = buildNodes("nodeA", "nodeB");
-        var modified = GetDeploymentStatsAction.Response.addFailedRoutes(emptyResponse, badRoutes, nodes);
-        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
-        assertThat(results, hasSize(2));
-        assertEquals("model1", results.get(0).getModelId());
-        assertThat(results.get(0).getNodeStats(), hasSize(2));
-        assertEquals("nodeA", results.get(0).getNodeStats().get(0).getNode().getId());
-        assertEquals("nodeB", results.get(0).getNodeStats().get(1).getNode().getId());
-        assertEquals("nodeA", results.get(1).getNodeStats().get(0).getNode().getId());
-        assertEquals("nodeB", results.get(1).getNodeStats().get(1).getNode().getId());
-    }
-
-    public void testAddFailedRoutes_GivenMixedResponses() throws UnknownHostException {
-        DiscoveryNodes nodes = buildNodes("node1", "node2", "node3");
-
-        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
-        nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
-                nodes.get("node1"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(0, 100),
-                Instant.now()
-            )
-        );
-        nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
-                nodes.get("node2"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(0, 100),
-                Instant.now()
-            )
-        );
-
-        var model1 = new GetDeploymentStatsAction.Response.AllocationStats(
-            "model1",
-            ByteSizeValue.ofBytes(randomNonNegativeLong()),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 10000),
-            nodeStatsList
-        );
-
-        Map<String, Map<String, RoutingStateAndReason>> badRoutes = new HashMap<>();
-        Map<String, RoutingStateAndReason> nodeRoutes = new HashMap<>();
-        nodeRoutes.put("node3", new RoutingStateAndReason(RoutingState.FAILED, "failed on node3"));
-        badRoutes.put("model1", nodeRoutes);
-
-        var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), List.of(model1), 1);
-
-        var modified = GetDeploymentStatsAction.Response.addFailedRoutes(response, badRoutes, nodes);
-        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
-        assertThat(results, hasSize(1));
-        assertThat(results.get(0).getNodeStats(), hasSize(3));
-        assertEquals("node1", results.get(0).getNodeStats().get(0).getNode().getId());
-        assertEquals(RoutingState.STARTED, results.get(0).getNodeStats().get(0).getRoutingState().getState());
-        assertEquals("node2", results.get(0).getNodeStats().get(1).getNode().getId());
-        assertEquals(RoutingState.STARTED, results.get(0).getNodeStats().get(1).getRoutingState().getState());
-        assertEquals("node3", results.get(0).getNodeStats().get(2).getNode().getId());
-        assertEquals(RoutingState.FAILED, results.get(0).getNodeStats().get(2).getRoutingState().getState());
-    }
-
-    public void testAddFailedRoutes_TaskResultIsOverwritten() throws UnknownHostException {
-        DiscoveryNodes nodes = buildNodes("node1", "node2");
-
-        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
-        nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
-                nodes.get("node1"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(0, 100),
-                Instant.now()
-            )
-        );
-        nodeStatsList.add(
-            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
-                nodes.get("node2"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(0, 100),
-                Instant.now()
-            )
-        );
-
-        var model1 = new GetDeploymentStatsAction.Response.AllocationStats(
-            "model1",
-            ByteSizeValue.ofBytes(randomNonNegativeLong()),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 10000),
-            nodeStatsList
-        );
-        var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), List.of(model1), 1);
-
-        // failed state for node 2 conflicts with the task response
-        Map<String, Map<String, RoutingStateAndReason>> badRoutes = new HashMap<>();
-        Map<String, RoutingStateAndReason> nodeRoutes = new HashMap<>();
-        nodeRoutes.put("node2", new RoutingStateAndReason(RoutingState.FAILED, "failed on node3"));
-        badRoutes.put("model1", nodeRoutes);
-
-        var modified = GetDeploymentStatsAction.Response.addFailedRoutes(response, badRoutes, nodes);
-        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
-        assertThat(results, hasSize(1));
-        assertThat(results.get(0).getNodeStats(), hasSize(2));
-        assertEquals("node1", results.get(0).getNodeStats().get(0).getNode().getId());
-        assertEquals(RoutingState.STARTED, results.get(0).getNodeStats().get(0).getRoutingState().getState());
-        assertEquals("node2", results.get(0).getNodeStats().get(1).getNode().getId());
-        // routing state from the bad routes map is chosen to resolve teh conflict
-        assertEquals(RoutingState.FAILED, results.get(0).getNodeStats().get(1).getRoutingState().getState());
-    }
-
-    private DiscoveryNodes buildNodes(String... nodeIds) throws UnknownHostException {
-        InetAddress inetAddress = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 });
-        DiscoveryNodes.Builder builder = DiscoveryNodes.builder();
-        int port = 9200;
-        for (String nodeId : nodeIds) {
-            builder.add(new DiscoveryNode(nodeId, new TransportAddress(inetAddress, port++), Version.CURRENT));
-        }
-        return builder.build();
-    }
-
-    private GetDeploymentStatsAction.Response.AllocationStats randomDeploymentStats() {
+    private static GetDeploymentStatsAction.Response.AllocationStats randomDeploymentStats() {
         List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
         int numNodes = randomIntBetween(1, 4);
         for (int i = 0; i < numNodes; i++) {
@@ -204,6 +55,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
                         randomNonNegativeLong(),
                         randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
                         randomIntBetween(0, 100),
+                        Instant.now(),
                         Instant.now()
                     )
                 );
@@ -226,6 +78,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
+            Instant.now(),
             nodeStatsList
         );
     }

+ 139 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.tasks.TransportTasksAction;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.inject.Inject;
@@ -36,6 +37,7 @@ import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTas
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -100,14 +102,15 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
         GetDeploymentStatsAction.Request request,
         ActionListener<GetDeploymentStatsAction.Response> listener
     ) {
+        final ClusterState clusterState = clusterService.state();
+        final TrainedModelAllocationMetadata allocation = TrainedModelAllocationMetadata.fromState(clusterState);
 
         String[] tokenizedRequestIds = Strings.tokenizeToStringArray(request.getDeploymentId(), ",");
         ExpandedIdsMatcher.SimpleIdsMatcher idsMatcher = new ExpandedIdsMatcher.SimpleIdsMatcher(tokenizedRequestIds);
 
-        TrainedModelAllocationMetadata allocation = TrainedModelAllocationMetadata.fromState(clusterService.state());
         List<String> matchedDeploymentIds = new ArrayList<>();
         Set<String> taskNodes = new HashSet<>();
-        Map<String, Map<String, RoutingStateAndReason>> nonStartedAllocationsForModel = new HashMap<>();
+        Map<TrainedModelAllocation, Map<String, RoutingStateAndReason>> allocationNonStartedRoutes = new HashMap<>();
         for (var allocationEntry : allocation.modelAllocations().entrySet()) {
             String modelId = allocationEntry.getKey();
             if (idsMatcher.idMatches(modelId)) {
@@ -122,7 +125,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                     .filter(routingEntry -> RoutingState.STARTED.equals(routingEntry.getValue().getState()) == false)
                     .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
 
-                nonStartedAllocationsForModel.put(modelId, routings);
+                allocationNonStartedRoutes.put(allocationEntry.getValue(), routings);
             }
         }
 
@@ -144,11 +147,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
         request.setExpandedIds(matchedDeploymentIds);
 
         ActionListener<GetDeploymentStatsAction.Response> addFailedListener = listener.delegateFailure((l, response) -> {
-            var updatedResponse = GetDeploymentStatsAction.Response.addFailedRoutes(
-                response,
-                nonStartedAllocationsForModel,
-                clusterService.state().nodes()
-            );
+            var updatedResponse = addFailedRoutes(response, allocationNonStartedRoutes, clusterState.nodes());
             ClusterState latestState = clusterService.state();
             Set<String> nodesShuttingDown = TransportStartTrainedModelDeploymentAction.nodesShuttingDown(latestState);
             List<DiscoveryNode> nodes = latestState.getNodes()
@@ -174,6 +173,135 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
         super.doExecute(task, request, addFailedListener);
     }
 
+    /**
+     * Merges the stats collected from the deployment tasks with the
+     * non-started routes recorded in the model allocations.
+     *
+     * Where there is a merge collision for the pair {@code <model_id, node_id>}
+     * the non-started route is used, because the state of the task cannot be
+     * relied upon - it may be starting, started, stopping, or stopped.
+     * Allocations without any task response are reported with the
+     * allocation's start time and their non-started routes only.
+     *
+     * Non-started routes that refer to a node ID that cannot be resolved
+     * in {@code nodes} are left out of the result, as there is no node
+     * to report them against.
+     *
+     * The node stats of every merged allocation are sorted by node ID and
+     * the returned allocation stats are sorted by model ID.
+     *
+     * @param tasksResponse All the responses from the tasks
+     * @param allocationNonStartedRoutes Non-started routes (node ID to routing state) of each allocation
+     * @param nodes The cluster nodes used to resolve node IDs
+     * @return The result of merging tasksResponse and the non-started routes
+     */
+    static GetDeploymentStatsAction.Response addFailedRoutes(
+        GetDeploymentStatsAction.Response tasksResponse,
+        Map<TrainedModelAllocation, Map<String, RoutingStateAndReason>> allocationNonStartedRoutes,
+        DiscoveryNodes nodes
+    ) {
+        final Map<String, TrainedModelAllocation> modelToAllocationWithNonStartedRoutes = allocationNonStartedRoutes.keySet()
+            .stream()
+            .collect(Collectors.toMap(TrainedModelAllocation::getModelId, Function.identity()));
+
+        final List<GetDeploymentStatsAction.Response.AllocationStats> updatedAllocationStats = new ArrayList<>();
+        // IDs of the models that have a task response; gathered once so that the
+        // second pass does not rescan the task responses for every allocation
+        final Set<String> modelsWithTaskResponse = new HashSet<>();
+
+        for (GetDeploymentStatsAction.Response.AllocationStats stat : tasksResponse.getStats().results()) {
+            modelsWithTaskResponse.add(stat.getModelId());
+
+            TrainedModelAllocation allocationWithNonStartedRoutes = modelToAllocationWithNonStartedRoutes.get(stat.getModelId());
+            if (allocationWithNonStartedRoutes == null) {
+                // no non-started routes for this model, nothing to merge
+                updatedAllocationStats.add(stat);
+                continue;
+            }
+
+            // there is merging to be done
+            Map<String, RoutingStateAndReason> nodeToRoutingStates = allocationNonStartedRoutes.get(allocationWithNonStartedRoutes);
+            List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> updatedNodeStats = new ArrayList<>();
+
+            Set<String> visitedNodes = new HashSet<>();
+            for (var nodeStat : stat.getNodeStats()) {
+                String nodeId = nodeStat.getNode().getId();
+                RoutingStateAndReason stateAndReason = nodeToRoutingStates.get(nodeId);
+                if (stateAndReason != null) {
+                    // conflict as there is both a task response for the model/node pair
+                    // and we have a non-started routing entry.
+                    // Prefer the entry from allocationNonStartedRoutes as we cannot be sure
+                    // of the state of the task - it may be starting, started, stopping, or stopped.
+                    updatedNodeStats.add(
+                        GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                            nodeStat.getNode(),
+                            stateAndReason.getState(),
+                            stateAndReason.getReason()
+                        )
+                    );
+                } else {
+                    updatedNodeStats.add(nodeStat);
+                }
+
+                visitedNodes.add(nodeId);
+            }
+
+            // add nodes from the failures that were not in the task responses
+            for (var nodeRoutingState : nodeToRoutingStates.entrySet()) {
+                if (visitedNodes.contains(nodeRoutingState.getKey())) {
+                    continue;
+                }
+                DiscoveryNode node = nodes.get(nodeRoutingState.getKey());
+                if (node == null) {
+                    // the routed node cannot be resolved (e.g. it has left the cluster);
+                    // a null node would break the sort by node ID below
+                    continue;
+                }
+                updatedNodeStats.add(
+                    GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                        node,
+                        nodeRoutingState.getValue().getState(),
+                        nodeRoutingState.getValue().getReason()
+                    )
+                );
+            }
+
+            updatedNodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
+            updatedAllocationStats.add(
+                new GetDeploymentStatsAction.Response.AllocationStats(
+                    stat.getModelId(),
+                    stat.getModelSize(),
+                    stat.getInferenceThreads(),
+                    stat.getModelThreads(),
+                    stat.getQueueCapacity(),
+                    stat.getStartTime(),
+                    updatedNodeStats
+                )
+            );
+        }
+
+        // Merge any models in the non-started that were not in the task responses
+        for (var nonStartedEntries : allocationNonStartedRoutes.entrySet()) {
+            final TrainedModelAllocation allocation = nonStartedEntries.getKey();
+            final String modelId = allocation.getTaskParams().getModelId();
+            if (modelsWithTaskResponse.contains(modelId)) {
+                continue;
+            }
+
+            // no tasks for this model so build the allocation stats from the non-started states
+            List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStats = new ArrayList<>();
+
+            for (var routingEntry : nonStartedEntries.getValue().entrySet()) {
+                DiscoveryNode node = nodes.get(routingEntry.getKey());
+                if (node == null) {
+                    // unresolvable node, see the equivalent check in the merge above
+                    continue;
+                }
+                nodeStats.add(
+                    GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                        node,
+                        routingEntry.getValue().getState(),
+                        routingEntry.getValue().getReason()
+                    )
+                );
+            }
+
+            nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
+
+            // model size, thread settings and queue capacity are only reported by
+            // the tasks, so they are not available for these allocations
+            updatedAllocationStats.add(
+                new GetDeploymentStatsAction.Response.AllocationStats(
+                    modelId,
+                    null,
+                    null,
+                    null,
+                    null,
+                    allocation.getStartTime(),
+                    nodeStats
+                )
+            );
+        }
+
+        updatedAllocationStats.sort(Comparator.comparing(GetDeploymentStatsAction.Response.AllocationStats::getModelId));
+
+        return new GetDeploymentStatsAction.Response(
+            tasksResponse.getTaskFailures(),
+            tasksResponse.getNodeFailures(),
+            updatedAllocationStats,
+            updatedAllocationStats.size()
+        );
+    }
+
     @Override
     protected void taskOperation(
         GetDeploymentStatsAction.Request request,
@@ -192,7 +320,8 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                     // avoid reporting the average time as 0 if count < 1
                     (stats.get().getTimingStats().getCount() > 0) ? stats.get().getTimingStats().getAverage() : null,
                     stats.get().getPendingCount(),
-                    stats.get().getLastUsed()
+                    stats.get().getLastUsed(),
+                    stats.get().getStartTime()
                 )
             );
         } else {
@@ -214,6 +343,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 task.getParams().getInferenceThreads(),
                 task.getParams().getModelThreads(),
                 task.getParams().getQueueCapacity(),
+                TrainedModelAllocationMetadata.fromState(clusterService.state()).getModelAllocation(task.getModelId()).getStartTime(),
                 nodeStats
             )
         );

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -52,6 +52,7 @@ import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.time.Instant;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -102,6 +103,7 @@ public class DeploymentManager {
         return Optional.ofNullable(processContextByAllocation.get(task.getId()))
             .map(
                 processContext -> new ModelStats(
+                    processContext.startTime,
                     processContext.getResultProcessor().getTimingStats(),
                     processContext.getResultProcessor().getLastUsed(),
                     processContext.executorService.queueSize() + processContext.getResultProcessor().numberOfPendingResults()
@@ -415,6 +417,7 @@ public class DeploymentManager {
         private final PyTorchResultProcessor resultProcessor;
         private final PyTorchStateStreamer stateStreamer;
         private final ProcessWorkerExecutorService executorService;
+        private volatile Instant startTime;
 
         ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
             this.task = Objects.requireNonNull(task);
@@ -433,6 +436,7 @@ public class DeploymentManager {
 
         synchronized void startProcess() {
             process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
+            startTime = Instant.now();
             executorServiceForProcess.submit(executorService::start);
         }
 

+ 7 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

@@ -12,16 +12,22 @@ import java.util.LongSummaryStatistics;
 
 public class ModelStats {
 
+    private final Instant startTime;
     private final LongSummaryStatistics timingStats;
     private final Instant lastUsed;
     private final int pendingCount;
 
-    ModelStats(LongSummaryStatistics timingStats, Instant lastUsed, int pendingCount) {
+    ModelStats(Instant startTime, LongSummaryStatistics timingStats, Instant lastUsed, int pendingCount) {
+        this.startTime = startTime;
         this.timingStats = timingStats;
         this.lastUsed = lastUsed;
         this.pendingCount = pendingCount;
     }
 
+    /**
+     * @return the start time these stats were constructed with.
+     *         NOTE(review): the caller passes a field that is only assigned when the
+     *         process is started, so this is presumably {@code null} before then - confirm.
+     */
+    public Instant getStartTime() {
+        return startTime;
+    }
+
     public LongSummaryStatistics getTimingStats() {
         return timingStats;
     }

+ 233 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java

@@ -0,0 +1,233 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.action;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
+import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.hasSize;
+
+public class TransportGetDeploymentStatsActionTests extends ESTestCase {
+
+    public void testAddFailedRoutes_GivenNoFailures() throws UnknownHostException {
+        // with no non-started routes the merge must hand back an equal response
+        final GetDeploymentStatsAction.Response original = GetDeploymentStatsActionResponseTests.createRandom();
+        final DiscoveryNodes clusterNodes = buildNodes("node_foo");
+
+        final GetDeploymentStatsAction.Response merged = TransportGetDeploymentStatsAction.addFailedRoutes(
+            original,
+            Collections.emptyMap(),
+            clusterNodes
+        );
+
+        assertEquals(original, merged);
+    }
+
+    public void testAddFailedRoutes_GivenNoTaskResponses() throws UnknownHostException {
+        var emptyResponse = new GetDeploymentStatsAction.Response(
+            Collections.emptyList(),
+            Collections.emptyList(),
+            Collections.emptyList(),
+            0
+        );
+
+        // listed in sorted order, which is the order the merged response must use
+        String[] modelIds = { "model1", "model2" };
+        String[] nodeIds = { "nodeA", "nodeB" };
+
+        // every model has a failed route on every node and no task response
+        Map<TrainedModelAllocation, Map<String, RoutingStateAndReason>> badRoutes = new HashMap<>();
+        for (var modelId : modelIds) {
+            TrainedModelAllocation allocation = createAllocation(modelId);
+            Map<String, RoutingStateAndReason> nodeRoutes = new HashMap<>();
+            for (var nodeId : nodeIds) {
+                nodeRoutes.put(nodeId, new RoutingStateAndReason(RoutingState.FAILED, "failure reason"));
+            }
+            badRoutes.put(allocation, nodeRoutes);
+        }
+
+        DiscoveryNodes nodes = buildNodes(nodeIds);
+        var modified = TransportGetDeploymentStatsAction.addFailedRoutes(emptyResponse, badRoutes, nodes);
+        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
+        assertThat(results, hasSize(modelIds.length));
+
+        // check every model, not just the first one: ID, number of nodes,
+        // node order and the routing state taken from the failed routes
+        for (int i = 0; i < modelIds.length; i++) {
+            var allocationStats = results.get(i);
+            assertEquals(modelIds[i], allocationStats.getModelId());
+            assertThat(allocationStats.getNodeStats(), hasSize(nodeIds.length));
+            for (int j = 0; j < nodeIds.length; j++) {
+                var nodeStats = allocationStats.getNodeStats().get(j);
+                assertEquals(nodeIds[j], nodeStats.getNode().getId());
+                assertEquals(RoutingState.FAILED, nodeStats.getRoutingState().getState());
+            }
+        }
+    }
+
+    public void testAddFailedRoutes_GivenMixedResponses() throws UnknownHostException {
+        final DiscoveryNodes clusterNodes = buildNodes("node1", "node2", "node3");
+
+        // node1 and node2 report a started deployment through their tasks
+        final List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> startedStats = new ArrayList<>();
+        for (String startedNodeId : List.of("node1", "node2")) {
+            startedStats.add(
+                GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                    clusterNodes.get(startedNodeId),
+                    randomNonNegativeLong(),
+                    randomDoubleBetween(0.0, 100.0, true),
+                    randomIntBetween(1, 100),
+                    Instant.now(),
+                    Instant.now()
+                )
+            );
+        }
+
+        final var taskStats = new GetDeploymentStatsAction.Response.AllocationStats(
+            "model1",
+            ByteSizeValue.ofBytes(randomNonNegativeLong()),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 10000),
+            Instant.now(),
+            startedStats
+        );
+        final var taskResponse = new GetDeploymentStatsAction.Response(
+            Collections.emptyList(),
+            Collections.emptyList(),
+            List.of(taskStats),
+            1
+        );
+
+        // node3 is only known through a failed route of the same model
+        final Map<String, RoutingStateAndReason> routesByNode = new HashMap<>();
+        routesByNode.put("node3", new RoutingStateAndReason(RoutingState.FAILED, "failed on node3"));
+        final Map<TrainedModelAllocation, Map<String, RoutingStateAndReason>> failedRoutes = new HashMap<>();
+        failedRoutes.put(createAllocation("model1"), routesByNode);
+
+        final var merged = TransportGetDeploymentStatsAction.addFailedRoutes(taskResponse, failedRoutes, clusterNodes);
+        final List<GetDeploymentStatsAction.Response.AllocationStats> mergedStats = merged.getStats().results();
+        assertThat(mergedStats, hasSize(1));
+        assertThat(mergedStats.get(0).getNodeStats(), hasSize(3));
+
+        // the started nodes keep their state and the failed route is appended
+        final String[] expectedNodeIds = { "node1", "node2", "node3" };
+        final RoutingState[] expectedStates = { RoutingState.STARTED, RoutingState.STARTED, RoutingState.FAILED };
+        for (int i = 0; i < expectedNodeIds.length; i++) {
+            final var mergedNodeStats = mergedStats.get(0).getNodeStats().get(i);
+            assertEquals(expectedNodeIds[i], mergedNodeStats.getNode().getId());
+            assertEquals(expectedStates[i], mergedNodeStats.getRoutingState().getState());
+        }
+    }
+
+    public void testAddFailedRoutes_TaskResultIsOverwritten() throws UnknownHostException {
+        DiscoveryNodes nodes = buildNodes("node1", "node2");
+
+        // both nodes report a started deployment through their tasks
+        List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
+        nodeStatsList.add(
+            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                nodes.get("node1"),
+                randomNonNegativeLong(),
+                randomDoubleBetween(0.0, 100.0, true),
+                randomIntBetween(1, 100),
+                Instant.now(),
+                Instant.now()
+            )
+        );
+        nodeStatsList.add(
+            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                nodes.get("node2"),
+                randomNonNegativeLong(),
+                randomDoubleBetween(0.0, 100.0, true),
+                randomIntBetween(1, 100),
+                Instant.now(),
+                Instant.now()
+            )
+        );
+
+        var model1 = new GetDeploymentStatsAction.Response.AllocationStats(
+            "model1",
+            ByteSizeValue.ofBytes(randomNonNegativeLong()),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 10000),
+            Instant.now(),
+            nodeStatsList
+        );
+        var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), List.of(model1), 1);
+
+        // failed state for node 2 conflicts with the task response
+        String failureReason = "failed on node2";
+        Map<TrainedModelAllocation, Map<String, RoutingStateAndReason>> badRoutes = new HashMap<>();
+        Map<String, RoutingStateAndReason> nodeRoutes = new HashMap<>();
+        nodeRoutes.put("node2", new RoutingStateAndReason(RoutingState.FAILED, failureReason));
+        badRoutes.put(createAllocation("model1"), nodeRoutes);
+
+        var modified = TransportGetDeploymentStatsAction.addFailedRoutes(response, badRoutes, nodes);
+        List<GetDeploymentStatsAction.Response.AllocationStats> results = modified.getStats().results();
+        assertThat(results, hasSize(1));
+        assertThat(results.get(0).getNodeStats(), hasSize(2));
+        assertEquals("node1", results.get(0).getNodeStats().get(0).getNode().getId());
+        assertEquals(RoutingState.STARTED, results.get(0).getNodeStats().get(0).getRoutingState().getState());
+        assertEquals("node2", results.get(0).getNodeStats().get(1).getNode().getId());
+        // routing state and reason from the bad routes map are chosen to resolve the conflict
+        assertEquals(RoutingState.FAILED, results.get(0).getNodeStats().get(1).getRoutingState().getState());
+        assertEquals(failureReason, results.get(0).getNodeStats().get(1).getRoutingState().getReason());
+    }
+
+    /**
+     * Builds a {@link DiscoveryNodes} holding one node per given ID.
+     * All nodes share the address 192.168.0.1 and get consecutive
+     * ports starting at 9200, in the order the IDs are given.
+     */
+    private DiscoveryNodes buildNodes(String... nodeIds) throws UnknownHostException {
+        final byte[] rawAddress = { (byte) 192, (byte) 168, (byte) 0, (byte) 1 };
+        final InetAddress sharedAddress = InetAddress.getByAddress(rawAddress);
+        final int firstPort = 9200;
+
+        final DiscoveryNodes.Builder nodesBuilder = DiscoveryNodes.builder();
+        for (int i = 0; i < nodeIds.length; i++) {
+            final TransportAddress address = new TransportAddress(sharedAddress, firstPort + i);
+            nodesBuilder.add(new DiscoveryNode(nodeIds[i], address, Version.CURRENT));
+        }
+        return nodesBuilder.build();
+    }
+
+    /**
+     * Creates allocation stats for a random model ID with between one and
+     * four nodes, each randomly in the started or a not-started state.
+     * The node stats are sorted by node ID.
+     */
+    private GetDeploymentStatsAction.Response.AllocationStats randomDeploymentStats() {
+        final int nodeCount = randomIntBetween(1, 4);
+        final List<GetDeploymentStatsAction.Response.AllocationStats.NodeStats> perNodeStats = new ArrayList<>(nodeCount);
+
+        for (int nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) {
+            final DiscoveryNode node = new DiscoveryNode(
+                "node_" + nodeIndex,
+                new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
+                Version.CURRENT
+            );
+
+            final GetDeploymentStatsAction.Response.AllocationStats.NodeStats statsForNode;
+            if (randomBoolean()) {
+                statsForNode = GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                    node,
+                    randomNonNegativeLong(),
+                    randomDoubleBetween(0.0, 100.0, true),
+                    randomIntBetween(1, 100),
+                    Instant.now(),
+                    Instant.now()
+                );
+            } else {
+                statsForNode = GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forNotStartedState(
+                    node,
+                    randomFrom(RoutingState.values()),
+                    randomBoolean() ? null : "a good reason"
+                );
+            }
+            perNodeStats.add(statsForNode);
+        }
+
+        perNodeStats.sort(Comparator.comparing(stats -> stats.getNode().getId()));
+
+        return new GetDeploymentStatsAction.Response.AllocationStats(
+            randomAlphaOfLength(5),
+            ByteSizeValue.ofBytes(randomNonNegativeLong()),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 10000),
+            Instant.now(),
+            perNodeStats
+        );
+    }
+
+    /**
+     * Builds an allocation for the given model from an empty builder,
+     * using fixed task parameters.
+     */
+    private static TrainedModelAllocation createAllocation(String modelId) {
+        final var taskParams = new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1);
+        return TrainedModelAllocation.Builder.empty(taskParams).build();
+    }
+}