Browse Source

[ML] add allocation state reason and support for partial model allocations (#76925)

Previously, if a model failed to be allocated on any node, the deployment failed.

This commit allows for an allocation to be partially_started and indicates its
current state via a new state value in the deployment stats API.

Additionally, when starting a deployment, the user may set wait_for to
starting, partially_started, or started, and the API will block (as long as the timeout doesn't expire) until that state is reached.
Benjamin Trent 4 years ago
parent
commit
708491d0d3
18 changed files with 811 additions and 119 deletions
  1. 11 2
      docs/reference/ml/df-analytics/apis/start-trained-model-deployment.asciidoc
  2. 7 0
      rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json
  3. 57 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java
  4. 30 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
  5. 6 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java
  6. 121 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatus.java
  7. 102 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java
  8. 58 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatusTests.java
  9. 123 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java
  10. 49 6
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  11. 31 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java
  12. 66 9
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java
  13. 69 50
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java
  14. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java
  15. 3 6
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationService.java
  16. 7 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java
  17. 56 26
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java
  18. 14 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java

+ 11 - 2
docs/reference/ml/df-analytics/apis/start-trained-model-deployment.asciidoc

@@ -10,7 +10,7 @@
 [[start-trained-model-deployment-request]]
 == {api-request-title}
 
-`POST _ml/trained_models/<model_id>/deployent/_start` 
+`POST _ml/trained_models/<model_id>/deployment/_start`
 ////
 [[start-trained-model-deployment-prereq]]
 == {api-prereq-title}
@@ -37,6 +37,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 Controls the amount of time to wait for the model to deploy. Defaults
 to 20 seconds.
 
+`wait_for`::
+(Optional, string)
+Which allocation status to wait for before returning. Defaults to "started".
+Valid values are: "starting", "started", and "fully_allocated". These
+indicate, respectively, that the deployment is starting but not yet on
+any node, that the model has started on at least one node, and that the
+deployment has started on all valid nodes.
+
 ////
 [role="child_attributes"]
 [[start-trained-model-deployment-results]]
@@ -51,4 +59,5 @@ to 20 seconds.
 ////
 [[start-trained-model-deployment-example]]
 == {api-examples-title}
-////
+
+////

+ 7 - 0
rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json

@@ -33,6 +33,13 @@
         "required":false,
         "description":"Controls the amount of time to wait for the model to deploy.",
         "default": "20s"
+      },
+      "wait_for":{
+        "type":"string",
+        "required":false,
+        "description":"The allocation status for which to wait",
+        "options": ["starting", "started", "fully_allocated"],
+        "default": "started"
       }
     }
   }

+ 57 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java

@@ -23,6 +23,8 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -224,19 +226,31 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
 
 
             private final String modelId;
+            private AllocationState state;
+            private AllocationStatus allocationStatus;
+            private String reason;
             private final ByteSizeValue modelSize;
             private final List<NodeStats> nodeStats;
 
-            public AllocationStats(String modelId, ByteSizeValue modelSize, List<NodeStats> nodeStats) {
+            public AllocationStats(
+                String modelId,
+                ByteSizeValue modelSize,
+                List<NodeStats> nodeStats
+            ) {
                 this.modelId = modelId;
                 this.modelSize = modelSize;
                 this.nodeStats = nodeStats;
+                this.state = null;
+                this.reason = null;
             }
 
             public AllocationStats(StreamInput in) throws IOException {
                 modelId = in.readString();
                 modelSize = in.readOptionalWriteable(ByteSizeValue::new);
                 nodeStats = in.readList(NodeStats::new);
+                state = in.readOptionalEnum(AllocationState.class);
+                reason = in.readOptionalString();
+                allocationStatus = in.readOptionalWriteable(AllocationStatus::new);
             }
 
             public String getModelId() {
@@ -251,6 +265,29 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 return nodeStats;
             }
 
+            public AllocationState getState() {
+                return state;
+            }
+
+            public AllocationStats setState(AllocationState state) {
+                this.state = state;
+                return this;
+            }
+
+            public AllocationStats setAllocationStatus(AllocationStatus allocationStatus) {
+                this.allocationStatus = allocationStatus;
+                return this;
+            }
+
+            public String getReason() {
+                return reason;
+            }
+
+            public AllocationStats setReason(String reason) {
+                this.reason = reason;
+                return this;
+            }
+
             @Override
             public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
                 builder.startObject();
@@ -258,6 +295,15 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 if (modelSize != null) {
                     builder.field("model_size", modelSize);
                 }
+                if (state != null) {
+                    builder.field("state", state);
+                }
+                if (reason != null) {
+                    builder.field("reason", reason);
+                }
+                if (allocationStatus != null) {
+                    builder.field("allocation_status", allocationStatus);
+                }
                 builder.startArray("nodes");
                 for (NodeStats nodeStat : nodeStats){
                     nodeStat.toXContent(builder, params);
@@ -272,6 +318,9 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 out.writeString(modelId);
                 out.writeOptionalWriteable(modelSize);
                 out.writeList(nodeStats);
+                out.writeOptionalEnum(state);
+                out.writeOptionalString(reason);
+                out.writeOptionalWriteable(allocationStatus);
             }
 
             @Override
@@ -281,12 +330,15 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 AllocationStats that = (AllocationStats) o;
                 return Objects.equals(modelId, that.modelId) &&
                     Objects.equals(modelSize, that.modelSize) &&
+                    Objects.equals(state, that.state) &&
+                    Objects.equals(reason, that.reason) &&
+                    Objects.equals(allocationStatus, that.allocationStatus) &&
                     Objects.equals(nodeStats, that.nodeStats);
             }
 
             @Override
             public int hashCode() {
-                return Objects.hash(modelId, modelSize, nodeStats);
+                return Objects.hash(modelId, modelSize, nodeStats, state, reason, allocationStatus);
             }
         }
 
@@ -346,12 +398,14 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
          *
          * @param tasksResponse All the responses from the tasks
          * @param nonStartedModelRoutes Non-started routes
+         * @param nodes current cluster nodes
          * @return The result of merging tasksResponse and the non-started routes
          */
         public static GetDeploymentStatsAction.Response addFailedRoutes(
             GetDeploymentStatsAction.Response tasksResponse,
             Map<String, Map<String, RoutingStateAndReason>> nonStartedModelRoutes,
-            DiscoveryNodes nodes) {
+            DiscoveryNodes nodes
+        ) {
 
             List<GetDeploymentStatsAction.Response.AllocationStats> updatedAllocationStats = new ArrayList<>();
 

+ 30 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -26,6 +26,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
@@ -48,11 +49,17 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 
     public static class Request extends MasterNodeRequest<Request> implements ToXContentObject {
 
+        private static final AllocationStatus.State[] VALID_WAIT_STATES = new AllocationStatus.State[] {
+            AllocationStatus.State.STARTED,
+            AllocationStatus.State.STARTING,
+            AllocationStatus.State.FULLY_ALLOCATED };
         public static final ParseField MODEL_ID = new ParseField("model_id");
         public static final ParseField TIMEOUT = new ParseField("timeout");
+        public static final ParseField WAIT_FOR = new ParseField("wait_for");
 
         private String modelId;
         private TimeValue timeout = DEFAULT_TIMEOUT;
+        private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
 
         public Request(String modelId) {
             setModelId(modelId);
@@ -62,6 +69,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             super(in);
             modelId = in.readString();
             timeout = in.readTimeValue();
+            waitForState = in.readEnum(AllocationStatus.State.class);
         }
 
         public final void setModelId(String modelId) {
@@ -80,23 +88,44 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return timeout;
         }
 
+        public AllocationStatus.State getWaitForState() {
+            return waitForState;
+        }
+
+        public Request setWaitForState(AllocationStatus.State waitForState) {
+            this.waitForState = ExceptionsHelper.requireNonNull(waitForState, WAIT_FOR);
+            return this;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
             out.writeString(modelId);
             out.writeTimeValue(timeout);
+            out.writeEnum(waitForState);
         }
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.field(MODEL_ID.getPreferredName(), modelId);
             builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep());
+            builder.field(WAIT_FOR.getPreferredName(), waitForState);
             return builder;
         }
 
         @Override
         public ActionRequestValidationException validate() {
-            return null;
+            if (waitForState.isAnyOf(VALID_WAIT_STATES)) {
+                return null;
+            }
+            ActionRequestValidationException validationException = new ActionRequestValidationException();
+            validationException.addValidationError(
+                "invalid [wait_for] state ["
+                    + waitForState
+                    + "]; must be one of ["
+                    + Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES)
+            );
+            return validationException;
         }
 
         @Override

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationState.java

@@ -7,9 +7,11 @@
 
 package org.elasticsearch.xpack.core.ml.inference.allocation;
 
+import java.util.Arrays;
 import java.util.Locale;
 
 public enum AllocationState {
+    STARTING,
     STARTED,
     STOPPING;
 
@@ -17,6 +19,10 @@ public enum AllocationState {
         return valueOf(value.toUpperCase(Locale.ROOT));
     }
 
+    public boolean isAnyOf(AllocationState... candidates) {
+        return Arrays.stream(candidates).anyMatch(candidate -> this == candidate);
+    }
+
     @Override
     public String toString() {
         return name().toLowerCase(Locale.ROOT);

+ 121 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatus.java

@@ -0,0 +1,121 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Locale;
+import java.util.Objects;
+
+public class AllocationStatus implements Writeable, ToXContentObject {
+
+    public enum State {
+        STARTING,
+        STARTED,
+        FULLY_ALLOCATED;
+
+        public static State fromString(String value) {
+            return valueOf(value.toUpperCase(Locale.ROOT));
+        }
+
+        public boolean isAnyOf(State... candidates) {
+            return Arrays.stream(candidates).anyMatch(candidate -> this == candidate);
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
+    }
+
+    // Field names used by both the parser and toXContent; final so they cannot be reassigned.
+    public static final ParseField ALLOCATION_COUNT = new ParseField("allocation_count");
+    public static final ParseField TARGET_ALLOCATION_COUNT = new ParseField("target_allocation_count");
+    public static final ParseField STATE = new ParseField("state");
+
+    private static final ConstructingObjectParser<AllocationStatus, Void> PARSER = new ConstructingObjectParser<>(
+        "allocation_health",
+        a -> new AllocationStatus((int)a[0], (int)a[1])
+    );
+    static {
+        PARSER.declareInt(ConstructingObjectParser.constructorArg(), ALLOCATION_COUNT);
+        PARSER.declareInt(ConstructingObjectParser.constructorArg(), TARGET_ALLOCATION_COUNT);
+        // NOTE: We ignore this as we calculate it given allocation and target allocation counts
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), STATE);
+    }
+
+    public static AllocationStatus fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final int allocationCount;
+    private final int targetAllocationCount;
+
+    public AllocationStatus(int allocationCount, int targetAllocationCount) {
+        this.allocationCount = allocationCount;
+        this.targetAllocationCount = targetAllocationCount;
+        if (allocationCount < 0) {
+            throw new IllegalArgumentException("[" + ALLOCATION_COUNT.getPreferredName() + "] must be greater than or equal to 0");
+        }
+        if (targetAllocationCount < 0) {
+            throw new IllegalArgumentException("[" + TARGET_ALLOCATION_COUNT.getPreferredName() + "] must be greater than or equal to 0");
+        }
+    }
+
+    public AllocationStatus(StreamInput in) throws IOException {
+        this.allocationCount = in.readVInt();
+        this.targetAllocationCount = in.readVInt();
+    }
+
+    public State calculateState() {
+        if (allocationCount == 0) {
+            return State.STARTING;
+        }
+        if (allocationCount < targetAllocationCount) {
+            return State.STARTED;
+        }
+        return State.FULLY_ALLOCATED;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ALLOCATION_COUNT.getPreferredName(), allocationCount);
+        builder.field(TARGET_ALLOCATION_COUNT.getPreferredName(), targetAllocationCount);
+        builder.field(STATE.getPreferredName(), calculateState());
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeVInt(allocationCount);
+        out.writeVInt(targetAllocationCount);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AllocationStatus that = (AllocationStatus) o;
+        return allocationCount == that.allocationCount && targetAllocationCount == that.targetAllocationCount;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(allocationCount, targetAllocationCount);
+    }
+}

+ 102 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocation.java

@@ -11,6 +11,7 @@ import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.cluster.AbstractDiffable;
 import org.elasticsearch.cluster.Diffable;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -24,8 +25,10 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 
 // TODO implement better diffable logic so that whole diff does not need to be serialized if only one part changes
 /**
@@ -36,6 +39,7 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         Diffable<TrainedModelAllocation>,
         ToXContentObject {
 
+    private static final ParseField REASON = new ParseField("reason");
     private static final ParseField ALLOCATION_STATE = new ParseField("allocation_state");
     private static final ParseField ROUTING_TABLE = new ParseField("routing_table");
     private static final ParseField TASK_PARAMETERS = new ParseField("task_parameters");
@@ -47,7 +51,8 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         a -> new TrainedModelAllocation(
             (StartTrainedModelDeploymentAction.TaskParams) a[0],
             (Map<String, RoutingStateAndReason>) a[1],
-            AllocationState.fromString((String)a[2])
+            AllocationState.fromString((String)a[2]),
+            (String) a[3]
         )
     );
     static {
@@ -62,11 +67,13 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
             ROUTING_TABLE
         );
         PARSER.declareString(ConstructingObjectParser.constructorArg(), ALLOCATION_STATE);
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
     }
 
     private final StartTrainedModelDeploymentAction.TaskParams taskParams;
     private final Map<String, RoutingStateAndReason> nodeRoutingTable;
     private final AllocationState allocationState;
+    private final String reason;
 
     public static TrainedModelAllocation fromXContent(XContentParser parser) throws IOException {
         return PARSER.apply(parser, null);
@@ -75,17 +82,20 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
     TrainedModelAllocation(
         StartTrainedModelDeploymentAction.TaskParams taskParams,
         Map<String, RoutingStateAndReason> nodeRoutingTable,
-        AllocationState allocationState
+        AllocationState allocationState,
+        String reason
     ) {
         this.taskParams = ExceptionsHelper.requireNonNull(taskParams, TASK_PARAMETERS);
         this.nodeRoutingTable = ExceptionsHelper.requireNonNull(nodeRoutingTable, ROUTING_TABLE);
         this.allocationState = ExceptionsHelper.requireNonNull(allocationState, ALLOCATION_STATE);
+        this.reason = reason;
     }
 
     public TrainedModelAllocation(StreamInput in) throws IOException {
         this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in);
         this.nodeRoutingTable = in.readOrderedMap(StreamInput::readString, RoutingStateAndReason::new);
         this.allocationState = in.readEnum(AllocationState.class);
+        this.reason = in.readOptionalString();
     }
 
     public boolean isRoutedToNode(String nodeId) {
@@ -113,6 +123,10 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
             .toArray(String[]::new);
     }
 
+    public Optional<String> getReason() {
+        return Optional.ofNullable(reason);
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -120,12 +134,13 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         TrainedModelAllocation that = (TrainedModelAllocation) o;
         return Objects.equals(nodeRoutingTable, that.nodeRoutingTable)
             && Objects.equals(taskParams, that.taskParams)
+            && Objects.equals(reason, that.reason)
             && Objects.equals(allocationState, that.allocationState);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(nodeRoutingTable, taskParams, allocationState);
+        return Objects.hash(nodeRoutingTable, taskParams, allocationState, reason);
     }
 
     @Override
@@ -134,6 +149,9 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         builder.field(TASK_PARAMETERS.getPreferredName(), taskParams);
         builder.field(ROUTING_TABLE.getPreferredName(), nodeRoutingTable);
         builder.field(ALLOCATION_STATE.getPreferredName(), allocationState);
+        if (reason != null) {
+            builder.field(REASON.getPreferredName(), reason);
+        }
         builder.endObject();
         return builder;
     }
@@ -143,16 +161,39 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         taskParams.writeTo(out);
         out.writeMap(nodeRoutingTable, StreamOutput::writeString, (o, w) -> w.writeTo(o));
         out.writeEnum(allocationState);
+        out.writeOptionalString(reason);
+    }
+
+    public Optional<AllocationStatus> calculateAllocationStatus(List<DiscoveryNode> allocatableNodes) {
+        if (allocationState.equals(AllocationState.STOPPING)) {
+            return Optional.empty();
+        }
+        int numAllocatableNodes = 0;
+        int numStarted = 0;
+        for (DiscoveryNode node : allocatableNodes) {
+            if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)) {
+                RoutingState nodeState = Optional.ofNullable(nodeRoutingTable.get(node.getId()))
+                    .map(RoutingStateAndReason::getState)
+                    .orElse(RoutingState.STOPPED);
+                numAllocatableNodes++;
+                if (nodeState.equals(RoutingState.STARTED)) {
+                    numStarted++;
+                }
+            }
+        }
+        return Optional.of(new AllocationStatus(numStarted, numAllocatableNodes));
     }
 
+
     public static class Builder {
         private final Map<String, RoutingStateAndReason> nodeRoutingTable;
         private final StartTrainedModelDeploymentAction.TaskParams taskParams;
         private AllocationState allocationState;
         private boolean isChanged;
+        private String reason;
 
         public static Builder fromAllocation(TrainedModelAllocation allocation) {
-            return new Builder(allocation.taskParams, allocation.nodeRoutingTable, allocation.allocationState);
+            return new Builder(allocation.taskParams, allocation.nodeRoutingTable, allocation.allocationState, allocation.reason);
         }
 
         public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskParams) {
@@ -162,17 +203,19 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
         private Builder(
             StartTrainedModelDeploymentAction.TaskParams taskParams,
             Map<String, RoutingStateAndReason> nodeRoutingTable,
-            AllocationState allocationState
+            AllocationState allocationState,
+            String reason
         ) {
             this.taskParams = taskParams;
             this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable);
             this.allocationState = allocationState;
+            this.reason = reason;
         }
 
         private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) {
             this.nodeRoutingTable = new LinkedHashMap<>();
             this.taskParams = taskParams;
-            this.allocationState = AllocationState.STARTED;
+            this.allocationState = AllocationState.STARTING;
         }
 
         public Builder addNewRoutingEntry(String nodeId) {
@@ -186,6 +229,12 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
             return this;
         }
 
+        // For testing purposes
+        Builder addRoutingEntry(String nodeId, RoutingState state) {
+            nodeRoutingTable.put(nodeId, new RoutingStateAndReason(state, ""));
+            return this;
+        }
+
         public Builder addNewFailedRoutingEntry(String nodeId, String reason) {
             if (nodeRoutingTable.containsKey(nodeId)) {
                 throw new ResourceAlreadyExistsException(
@@ -219,21 +268,66 @@ public class TrainedModelAllocation extends AbstractDiffable<TrainedModelAllocat
             return this;
         }
 
-        public Builder stopAllocation() {
+        public Builder setReason(String reason) {
+            if (Objects.equals(reason, this.reason)) {
+                return this;
+            }
+            isChanged = true;
+            this.reason = reason;
+            return this;
+        }
+
+        public Builder stopAllocation(String reason) {
             if (allocationState.equals(AllocationState.STOPPING)) {
                 return this;
             }
             isChanged = true;
+            this.reason = reason;
             allocationState = AllocationState.STOPPING;
             return this;
         }
 
+        public AllocationState calculateAllocationState() {
+            if (allocationState.equals(AllocationState.STOPPING)) {
+                return allocationState;
+            }
+            if (nodeRoutingTable.values().stream().anyMatch(r -> r.getState().equals(RoutingState.STARTED))) {
+                return AllocationState.STARTED;
+            }
+            return AllocationState.STARTING;
+        }
+
+        public Builder calculateAndSetAllocationState() {
+            return setAllocationState(calculateAllocationState());
+        }
+
+        public Builder setAllocationState(AllocationState state) {
+            if (allocationState.equals(AllocationState.STOPPING)) {
+                return this;
+            }
+            if (allocationState.equals(state)) {
+                return this;
+            }
+            isChanged = true;
+            allocationState = state;
+            return this;
+        }
+
+        public Builder clearReason() {
+            if (this.reason == null) {
+                return this;
+            }
+            isChanged = true;
+            reason = null;
+            return this;
+        }
+
         public boolean isChanged() {
             return isChanged;
         }
 
         public TrainedModelAllocation build() {
-            return new TrainedModelAllocation(taskParams, nodeRoutingTable, allocationState);
+            return new TrainedModelAllocation(taskParams, nodeRoutingTable, allocationState, reason);
         }
     }
 

+ 58 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatusTests.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.allocation;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class AllocationStatusTests extends AbstractSerializingTestCase<AllocationStatus> {
+
+    public static AllocationStatus randomInstance() {
+        return new AllocationStatus(randomInt(10), randomIntBetween(1, 10));
+    }
+
+    @Override
+    protected AllocationStatus doParseInstance(XContentParser parser) throws IOException {
+        return AllocationStatus.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<AllocationStatus> instanceReader() {
+        return AllocationStatus::new;
+    }
+
+    @Override
+    protected AllocationStatus createTestInstance() {
+        return randomInstance();
+    }
+
+    public void testCalculateState() {
+        int targetAllocation = randomIntBetween(2, 10);
+
+        assertThat(
+            new AllocationStatus(randomIntBetween(1, targetAllocation - 1), targetAllocation).calculateState(),
+            equalTo(AllocationStatus.State.STARTED)
+        );
+
+        assertThat(
+            new AllocationStatus(0, targetAllocation).calculateState(),
+            equalTo(AllocationStatus.State.STARTING)
+        );
+
+        assertThat(
+            new AllocationStatus(targetAllocation, targetAllocation).calculateState(),
+            equalTo(AllocationStatus.State.FULLY_ALLOCATED)
+        );
+    }
+
+}

+ 123 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java

@@ -9,6 +9,9 @@ package org.elasticsearch.xpack.core.ml.inference.allocation;
 
 import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractSerializingTestCase;
@@ -16,11 +19,13 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 
 import java.io.IOException;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
 public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {
@@ -37,6 +42,10 @@ public class TrainedModelAllocationTests extends AbstractSerializingTestCase<Tra
                 builder.addNewRoutingEntry(node);
             }
         }
+        builder.setAllocationState(randomFrom(AllocationState.values()));
+        if (randomBoolean()) {
+            builder.setReason(randomAlphaOfLength(10));
+        }
         return builder.build();
     }
 
@@ -105,9 +114,7 @@ public class TrainedModelAllocationTests extends AbstractSerializingTestCase<Tra
         String startedNode2 = "started-node-2";
         String nodeInAnotherState1 = "another-state-node-1";
         String nodeInAnotherState2 = "another-state-node-2";
-        TrainedModelAllocation allocation = TrainedModelAllocation.Builder.empty(
-            new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong())
-        )
+        TrainedModelAllocation allocation = TrainedModelAllocation.Builder.empty(randomParams())
             .addNewRoutingEntry(startedNode1)
             .addNewRoutingEntry(startedNode2)
             .addNewRoutingEntry(nodeInAnotherState1)
@@ -132,6 +139,119 @@ public class TrainedModelAllocationTests extends AbstractSerializingTestCase<Tra
         assertThat(allocation.getStartedNodes(), arrayContainingInAnyOrder(startedNode1, startedNode2));
     }
 
+    /**
+     * Checks the counts reported by {@code calculateAllocationStatus} for an allocation without routes,
+     * a stopping allocation, and allocations with some or all of five nodes routed as STARTED.
+     */
+    public void testCalculateAllocationStatus() {
+        List<DiscoveryNode> nodes = Stream.generate(TrainedModelAllocationTests::buildNode).limit(5).collect(Collectors.toList());
+        final boolean includeNodes = randomBoolean();
+        final List<DiscoveryNode> candidateNodes = includeNodes ? nodes : List.of();
+
+        // No routing entries: nothing is allocated and the target follows the supplied node list
+        TrainedModelAllocation unrouted = TrainedModelAllocation.Builder.empty(randomParams()).build();
+        assertThat(
+            unrouted.calculateAllocationStatus(candidateNodes).orElseThrow(),
+            equalTo(new AllocationStatus(0, includeNodes ? 5 : 0))
+        );
+
+        // After stopAllocation no status is reported at all
+        TrainedModelAllocation stopping = TrainedModelAllocation.Builder.empty(randomParams()).stopAllocation("test").build();
+        assertThat(stopping.calculateAllocationStatus(candidateNodes).isPresent(), is(false));
+
+        // STARTED routes only, on a random-length prefix of the node list
+        TrainedModelAllocation.Builder startedOnly = TrainedModelAllocation.Builder.empty(randomParams());
+        final int startedOnlyCount = randomInt(4);
+        for (DiscoveryNode node : nodes.subList(0, startedOnlyCount)) {
+            startedOnly.addRoutingEntry(node.getId(), RoutingState.STARTED);
+        }
+        assertThat(
+            startedOnly.build().calculateAllocationStatus(nodes).orElseThrow(),
+            equalTo(new AllocationStatus(startedOnlyCount, 5))
+        );
+
+        // Every node gets a non-STARTED route first, then a prefix of them is switched to STARTED
+        TrainedModelAllocation.Builder mixedStates = TrainedModelAllocation.Builder.empty(randomParams());
+        for (DiscoveryNode node : nodes) {
+            mixedStates.addRoutingEntry(
+                node.getId(),
+                randomFrom(RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STARTING, RoutingState.STOPPING)
+            );
+        }
+        final int mixedStartedCount = randomIntBetween(1, 4);
+        for (DiscoveryNode node : nodes.subList(0, mixedStartedCount)) {
+            mixedStates.addRoutingEntry(node.getId(), RoutingState.STARTED);
+        }
+        assertThat(
+            mixedStates.build().calculateAllocationStatus(nodes).orElseThrow(),
+            equalTo(new AllocationStatus(mixedStartedCount, 5))
+        );
+
+        // All five nodes routed as STARTED
+        TrainedModelAllocation.Builder allStarted = TrainedModelAllocation.Builder.empty(randomParams());
+        for (DiscoveryNode node : nodes) {
+            allStarted.addRoutingEntry(node.getId(), RoutingState.STARTED);
+        }
+        assertThat(allStarted.build().calculateAllocationStatus(nodes).orElseThrow(), equalTo(new AllocationStatus(5, 5)));
+    }
+
+    /**
+     * Checks the state computed by {@code Builder#calculateAllocationState} for an empty builder, a stopped
+     * builder, and builders whose routes are all non-STARTED, partly STARTED, or entirely STARTED.
+     */
+    public void testCalculateAllocationState() {
+        List<DiscoveryNode> nodes = Stream.generate(TrainedModelAllocationTests::buildNode).limit(5).collect(Collectors.toList());
+
+        // No routing entries at all
+        TrainedModelAllocation.Builder emptyBuilder = TrainedModelAllocation.Builder.empty(randomParams());
+        assertThat(emptyBuilder.calculateAllocationState(), equalTo(AllocationState.STARTING));
+
+        // After stopAllocation the computed state is STOPPING
+        TrainedModelAllocation.Builder stoppedBuilder = TrainedModelAllocation.Builder.empty(randomParams()).stopAllocation("test");
+        assertThat(stoppedBuilder.calculateAllocationState(), equalTo(AllocationState.STOPPING));
+
+        // A random-length prefix of the nodes routed, none of them STARTED
+        TrainedModelAllocation.Builder noneStarted = TrainedModelAllocation.Builder.empty(randomParams());
+        final int notStartedCount = randomInt(4);
+        for (DiscoveryNode node : nodes.subList(0, notStartedCount)) {
+            noneStarted.addRoutingEntry(
+                node.getId(),
+                randomFrom(RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STARTING, RoutingState.STOPPING)
+            );
+        }
+        assertThat(noneStarted.calculateAllocationState(), equalTo(AllocationState.STARTING));
+
+        // Every node gets a non-STARTED route first, then a prefix of them is switched to STARTED
+        TrainedModelAllocation.Builder someStarted = TrainedModelAllocation.Builder.empty(randomParams());
+        for (DiscoveryNode node : nodes) {
+            someStarted.addRoutingEntry(
+                node.getId(),
+                randomFrom(RoutingState.FAILED, RoutingState.STOPPED, RoutingState.STARTING, RoutingState.STOPPING)
+            );
+        }
+        final int startedCount = randomIntBetween(1, 4);
+        for (DiscoveryNode node : nodes.subList(0, startedCount)) {
+            someStarted.addRoutingEntry(node.getId(), RoutingState.STARTED);
+        }
+        assertThat(someStarted.calculateAllocationState(), equalTo(AllocationState.STARTED));
+
+        // All nodes routed as STARTED
+        TrainedModelAllocation.Builder allStarted = TrainedModelAllocation.Builder.empty(randomParams());
+        for (DiscoveryNode node : nodes) {
+            allStarted.addRoutingEntry(node.getId(), RoutingState.STARTED);
+        }
+        assertThat(allStarted.calculateAllocationState(), equalTo(AllocationState.STARTED));
+    }
+
+
+    private static DiscoveryNode buildNode() {
+        return new DiscoveryNode(
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            buildNewFakeTransportAddress(),
+            Map.of(),
+            DiscoveryNodeRole.roles(),
+            Version.CURRENT
+        );
+    }
+
+    private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
+        return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong());
+    }
+
     private static void assertUnchanged(
         TrainedModelAllocation.Builder builder,
         Function<TrainedModelAllocation.Builder, TrainedModelAllocation.Builder> function

+ 49 - 6
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -12,10 +12,13 @@ import org.elasticsearch.client.Request;
 import org.elasticsearch.client.RequestOptions;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.test.SecuritySettingsSourceField;
 import org.elasticsearch.test.rest.ESRestTestCase;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
@@ -40,7 +43,11 @@ import java.util.stream.Collectors;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 /**
  * This test uses a tiny hardcoded base64 encoded PyTorch TorchScript model.
@@ -168,6 +175,38 @@ public class PyTorchModelIT extends ESRestTestCase {
         }
     }
 
+    /**
+     * Starts three deployments, each waiting for a different allocation status state, and verifies that the
+     * deployment stats API reports an {@code allocation_status.state} at least as advanced as the one waited for.
+     *
+     * @throws IOException if any REST request fails
+     */
+    @SuppressWarnings("unchecked")
+    public void testDeploymentStats() throws IOException {
+        String model = "model_starting_test";
+        String modelPartial = "model_partially_started";
+        String modelStarted = "model_started";
+        // Same setup for every model, in the same order as before
+        for (String modelId : List.of(model, modelPartial, modelStarted)) {
+            createTrainedModel(modelId);
+            putVocabulary(List.of("once", "twice"), modelId);
+            putModelDefinition(modelId);
+        }
+
+        CheckedBiConsumer<String, AllocationStatus.State, IOException> assertAtLeast = (modelId, state) -> {
+            startDeployment(modelId, state.toString());
+            Response response = getDeploymentStats(modelId);
+            List<Map<String, Object>> stats = (List<Map<String, Object>>)entityAsMap(response).get("deployment_stats");
+            assertThat(stats, hasSize(1));
+            String statusState = (String)XContentMapValues.extractValue("allocation_status.state", stats.get(0));
+            assertThat(stats.toString(), statusState, is(not(nullValue())));
+            assertThat(AllocationStatus.State.fromString(statusState), greaterThanOrEqualTo(state));
+            // Stop the deployment that was just started; previously this always stopped the first model,
+            // leaving the other two deployments running.
+            stopDeployment(modelId);
+        };
+
+        assertAtLeast.accept(model, AllocationStatus.State.STARTING);
+        assertAtLeast.accept(modelPartial, AllocationStatus.State.STARTED);
+        assertAtLeast.accept(modelStarted, AllocationStatus.State.FULLY_ALLOCATED);
+    }
+
+    @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1961")
     @SuppressWarnings("unchecked")
     public void testLiveDeploymentStats() throws IOException {
         String modelA = "model_a";
@@ -175,7 +214,7 @@ public class PyTorchModelIT extends ESRestTestCase {
         createTrainedModel(modelA);
         putVocabulary(List.of("once", "twice"), modelA);
         putModelDefinition(modelA);
-        startDeployment(modelA);
+        startDeployment(modelA, AllocationStatus.State.FULLY_ALLOCATED.toString());
         infer("once", modelA);
         infer("twice", modelA);
         Response response = getDeploymentStats(modelA);
@@ -209,8 +248,8 @@ public class PyTorchModelIT extends ESRestTestCase {
         putVocabulary(List.of("once", "twice"), modelBar);
         putModelDefinition(modelBar);
 
-        startDeployment(modelFoo);
-        startDeployment(modelBar);
+        startDeployment(modelFoo, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        startDeployment(modelBar, AllocationStatus.State.FULLY_ALLOCATED.toString());
         infer("once", modelFoo);
         infer("once", modelBar);
         {
@@ -266,8 +305,8 @@ public class PyTorchModelIT extends ESRestTestCase {
         putVocabulary(List.of("once", "twice"), modelBar);
         putModelDefinition(modelBar);
 
-        startDeployment(modelFoo);
-        startDeployment(modelBar);
+        startDeployment(modelFoo, AllocationStatus.State.FULLY_ALLOCATED.toString());
+        startDeployment(modelBar, AllocationStatus.State.FULLY_ALLOCATED.toString());
         infer("once", modelFoo);
         infer("once", modelBar);
 
@@ -380,7 +419,11 @@ public class PyTorchModelIT extends ESRestTestCase {
     }
 
+    /**
+     * Starts a deployment of {@code modelId}, passing {@link AllocationStatus.State#STARTED} as the state to wait for.
+     */
     private Response startDeployment(String modelId) throws IOException {
-        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_start?timeout=40s");
+        return startDeployment(modelId, AllocationStatus.State.STARTED.toString());
     }
+
+    /**
+     * Sends the start-deployment request for {@code modelId} with a {@code 40s} timeout.
+     *
+     * @param modelId the model to deploy
+     * @param waitForState value appended verbatim as the {@code wait_for} query parameter
+     * @return the response to the start request
+     * @throws IOException if performing the request fails
+     */
+    private Response startDeployment(String modelId, String waitForState) throws IOException {
+        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_start?timeout=40s&wait_for=" + waitForState);
         return client().performRequest(request);
     }
 

+ 31 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -12,6 +12,8 @@ import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.tasks.TransportTasksAction;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.inject.Inject;
@@ -20,8 +22,11 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
@@ -120,10 +125,32 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<Trai
 
         ActionListener<GetDeploymentStatsAction.Response> addFailedListener = listener.delegateFailure(
             (l, response) -> {
-                var updatedResponse =
-                    GetDeploymentStatsAction.Response.addFailedRoutes(response,
-                        nonStartedAllocationsForModel,
-                        clusterService.state().nodes());
+                var updatedResponse= GetDeploymentStatsAction.Response.addFailedRoutes(response,
+                    nonStartedAllocationsForModel,
+                    clusterService.state().nodes()
+                );
+                ClusterState latestState = clusterService.state();
+                Set<String> nodesShuttingDown = TransportStartTrainedModelDeploymentAction.nodesShuttingDown(latestState);
+                List<DiscoveryNode> nodes = latestState.getNodes()
+                    .getAllNodes()
+                    .stream()
+                    .filter(d -> nodesShuttingDown.contains(d.getId()) == false)
+                    .filter(StartTrainedModelDeploymentAction.TaskParams::mayAllocateToNode)
+                    .collect(Collectors.toList());
+                // Set the allocation state and reason if we have it
+                for (GetDeploymentStatsAction.Response.AllocationStats stats : updatedResponse.getStats().results()) {
+                    Optional<TrainedModelAllocation> modelAllocation = Optional.ofNullable(
+                        allocation.getModelAllocation(stats.getModelId())
+                    );
+                    TrainedModelAllocation trainedModelAllocation = modelAllocation.orElse(null);
+                    if (trainedModelAllocation != null) {
+                        stats.setState(trainedModelAllocation.getAllocationState())
+                            .setReason(trainedModelAllocation.getReason().orElse(null));
+                        if (trainedModelAllocation.getAllocationState().isAnyOf(AllocationState.STARTED, AllocationState.STARTING)) {
+                            stats.setAllocationStatus(trainedModelAllocation.calculateAllocationStatus(nodes).orElse(null));
+                        }
+                    }
+                }
                 l.onResponse(updatedResponse);
             }
         );

+ 66 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

@@ -21,6 +21,8 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.metadata.NodesShutdownMetadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -39,21 +41,26 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
 import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 
+import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.function.Predicate;
+import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 
@@ -95,7 +102,12 @@ public class TransportStartTrainedModelDeploymentAction
 
         ActionListener<CreateTrainedModelAllocationAction.Response> waitForDeploymentToStart =
             ActionListener.wrap(
-                modelAllocation -> waitForDeploymentStarted(request.getModelId(), request.getTimeout(), listener),
+                modelAllocation -> waitForDeploymentState(
+                    request.getModelId(),
+                    request.getTimeout(),
+                    request.getWaitForState(),
+                    listener
+                ),
                 e -> {
                     logger.warn(() -> new ParameterizedMessage("[{}] creating new allocation failed", request.getModelId()), e);
                     if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) {
@@ -180,12 +192,13 @@ public class TransportStartTrainedModelDeploymentAction
         );
     }
 
-    private void waitForDeploymentStarted(
+    private void waitForDeploymentState(
         String modelId,
         TimeValue timeout,
+        AllocationStatus.State state,
         ActionListener<CreateTrainedModelAllocationAction.Response> listener
     ) {
-        DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(modelId);
+        DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(modelId, state);
         trainedModelAllocationService.waitForAllocationCondition(modelId, predicate, timeout,
             new TrainedModelAllocationService.WaitForAllocationListener() {
                 @Override
@@ -234,19 +247,23 @@ public class TransportStartTrainedModelDeploymentAction
         return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
     }
 
-    private static class DeploymentStartedPredicate implements Predicate<TrainedModelAllocation> {
+    private static class DeploymentStartedPredicate implements Predicate<ClusterState> {
 
         private volatile Exception exception;
 
         // for logging
         private final String modelId;
+        private final AllocationStatus.State waitForState;
 
-        DeploymentStartedPredicate(String modelId) {
+        DeploymentStartedPredicate(String modelId, AllocationStatus.State waitForState) {
             this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
+            this.waitForState = waitForState;
         }
 
         @Override
-        public boolean test(TrainedModelAllocation trainedModelAllocation) {
+        public boolean test(ClusterState clusterState) {
+            TrainedModelAllocation trainedModelAllocation = TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId)
+                .orElse(null);
             if (trainedModelAllocation == null) {
                 // Something weird happened, it should NEVER be null...
                 return true;
@@ -257,7 +274,7 @@ public class TransportStartTrainedModelDeploymentAction
                 .entrySet();
 
             Map<String, String> nodeFailuresAndReasons = new HashMap<>();
-            Set<String> nodesStillInitializing = new HashSet<>();
+            Set<String> nodesStillInitializing = new LinkedHashSet<>();
             for (Map.Entry<String, RoutingStateAndReason> nodeIdAndState : nodesAndState) {
                 if (RoutingState.FAILED.equals(nodeIdAndState.getValue().getState())) {
                     nodeFailuresAndReasons.put(nodeIdAndState.getKey(), nodeIdAndState.getValue().getReason());
@@ -276,14 +293,54 @@ public class TransportStartTrainedModelDeploymentAction
                 return true;
             }
 
+            // No nodes allocated at all!
+            // TODO when we support autoscaling for this, check for `maxLazyNodes` setting
+            if (nodesAndState.isEmpty()) {
+                String msg = "Could not start deployment because no suitable nodes were found, allocation explanation ["
+                    + trainedModelAllocation.getReason()
+                    + "]";
+                logger.warn("[{}] {}", modelId, msg);
+                Exception detail = new IllegalStateException(msg);
+                exception = new ElasticsearchStatusException(
+                    "Could not start deployment because no ML nodes with sufficient capacity were found",
+                    RestStatus.TOO_MANY_REQUESTS,
+                    detail
+                );
+                return true;
+            }
+
+            Set<String> nodesShuttingDown = nodesShuttingDown(clusterState);
+            List<DiscoveryNode> nodes = clusterState.nodes()
+                .getAllNodes()
+                .stream()
+                .filter(d -> nodesShuttingDown.contains(d.getId()) == false)
+                .filter(TaskParams::mayAllocateToNode)
+                .collect(Collectors.toList());
+            AllocationStatus allocationStatus = trainedModelAllocation.calculateAllocationStatus(nodes).orElse(null);
+            if (allocationStatus == null || allocationStatus.calculateState().compareTo(waitForState) >= 0) {
+                return true;
+            }
+
             if (nodesStillInitializing.isEmpty()) {
                 return true;
             }
             logger.trace(
-                () -> new ParameterizedMessage("[{}] tested and nodes {} still initializing", modelId, nodesStillInitializing)
+                () -> new ParameterizedMessage(
+                    "[{}] tested with state [{}] and nodes {} still initializing",
+                    modelId,
+                    trainedModelAllocation.getAllocationState(),
+                    nodesStillInitializing
+                )
             );
             return false;
         }
     }
 
+    /**
+     * Collects the keys of the node shutdown metadata map held in the given cluster state.
+     *
+     * @param state the cluster state to inspect
+     * @return the key set of the shutdown metadata map, or an empty set when no shutdown metadata is available
+     */
+    static Set<String> nodesShuttingDown(final ClusterState state) {
+        return NodesShutdownMetadata.getShutdowns(state)
+            .map(shutdowns -> shutdowns.getAllNodeMetadataMap())
+            .map(nodeMetadataMap -> nodeMetadataMap.keySet())
+            .orElseGet(Collections::emptySet);
+    }
+
 }

+ 69 - 50
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java

@@ -39,13 +39,14 @@ import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
 
-import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
 
 public class TrainedModelAllocationClusterService implements ClusterStateListener {
 
@@ -161,7 +162,7 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         clusterService.submitStateUpdateTask("set model allocation stopping", new ClusterStateUpdateTask() {
             @Override
             public ClusterState execute(ClusterState currentState) {
-                return setToStopping(currentState, modelId);
+                return setToStopping(currentState, modelId, "client API call");
             }
 
             @Override
@@ -215,12 +216,8 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         });
     }
 
-    private static ClusterState update(
-        ClusterState currentState,
-        TrainedModelAllocationMetadata.Builder modelAllocations,
-        boolean force
-    ) {
-        if (force || modelAllocations.isChanged()) {
+    private static ClusterState update(ClusterState currentState, TrainedModelAllocationMetadata.Builder modelAllocations) {
+        if (modelAllocations.isChanged()) {
             return ClusterState.builder(currentState)
                 .metadata(
                     Metadata.builder(currentState.metadata()).putCustom(TrainedModelAllocationMetadata.NAME, modelAllocations.build())
@@ -231,10 +228,6 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         }
     }
 
-    private static ClusterState update(ClusterState currentState, TrainedModelAllocationMetadata.Builder modelAllocations) {
-        return update(currentState, modelAllocations, false);
-    }
-
     ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) {
         if (MlMetadata.getMlMetadata(currentState).isResetMode()) {
             throw new ElasticsearchStatusException(
@@ -250,22 +243,31 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         TrainedModelAllocation.Builder allocationBuilder = TrainedModelAllocation.Builder.empty(params);
 
         Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
+        Map<String, String> nodeToReason = new TreeMap<>();
         for (DiscoveryNode node : currentState.getNodes().getAllNodes()) {
             if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
                 && shuttingDownNodes.contains(node.getId()) == false) {
                 Optional<String> maybeError = nodeHasCapacity(currentState, params, node);
                 if (maybeError.isPresent()) {
-                    allocationBuilder.addNewFailedRoutingEntry(node.getId(), maybeError.get());
+                    nodeToReason.put(node.getName(), maybeError.get());
                 } else {
                     allocationBuilder.addNewRoutingEntry(node.getId());
                 }
             }
         }
+        if (nodeToReason.isEmpty() == false) {
+            allocationBuilder.setReason(
+                nodeToReason.entrySet()
+                    .stream()
+                    .map(entry -> String.format(Locale.ROOT, "Not allocating on node [%s]. Reason: %s", entry.getKey(), entry.getValue()))
+                    .collect(Collectors.joining("|"))
+            );
+        }
         builder.addNewAllocation(params.getModelId(), allocationBuilder);
         return update(currentState, builder);
     }
 
-    static ClusterState setToStopping(ClusterState clusterState,  String modelId) {
+    static ClusterState setToStopping(ClusterState clusterState, String modelId, String reason) {
         TrainedModelAllocationMetadata metadata = TrainedModelAllocationMetadata.fromState(clusterState);
         final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
         if (existingAllocation == null) {
@@ -276,8 +278,8 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
             return clusterState;
         }
         TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(clusterState);
-        final boolean isChanged = builder.getAllocation(modelId).stopAllocation().isChanged();
-        return update(clusterState, builder, isChanged);
+        builder.getAllocation(modelId).stopAllocation(reason);
+        return update(clusterState, builder);
     }
 
     static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTrainedModelAllocationStateAction.Request request) {
@@ -287,6 +289,14 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         logger.trace(
             () -> new ParameterizedMessage("[{}] [{}] current metadata before update {}", modelId, nodeId, Strings.toString(metadata))
         );
+        Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
+        List<DiscoveryNode> allocatableNodes = currentState.nodes()
+            .getAllNodes()
+            .stream()
+            .filter(
+                d -> StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(d) && shuttingDownNodes.contains(d.getId()) == false
+            )
+            .collect(Collectors.toList());
         final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
         final TrainedModelAllocationMetadata.Builder builder =  TrainedModelAllocationMetadata.builder(currentState);
         // If state is stopped, this indicates the node process is closed, remove the node from the allocation
@@ -294,8 +304,8 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
             if (existingAllocation == null || existingAllocation.isRoutedToNode(nodeId) == false) {
                 return currentState;
             }
-            final boolean isChanged = builder.getAllocation(modelId).removeRoutingEntry(nodeId).isChanged();
-            return update(currentState, builder, isChanged);
+            builder.getAllocation(modelId).removeRoutingEntry(nodeId).calculateAndSetAllocationState();
+            return update(currentState, builder);
         }
 
         if (existingAllocation == null) {
@@ -314,8 +324,11 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         if (existingAllocation.isRoutedToNode(nodeId) == false) {
             throw new ResourceNotFoundException("allocation for model with id [{}]] is not routed to node [{}]", modelId, nodeId);
         }
-        final boolean isChanged = builder.getAllocation(modelId).updateExistingRoutingEntry(nodeId, request.getRoutingState()).isChanged();
-        return update(currentState, builder, isChanged);
+        builder.getAllocation(modelId)
+            .updateExistingRoutingEntry(nodeId, request.getRoutingState())
+            .calculateAndSetAllocationState();
+
+        return update(currentState, builder);
     }
 
     static ClusterState removeAllocation(ClusterState currentState, String modelId) {
@@ -342,18 +355,24 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
     ClusterState addRemoveAllocationNodes(ClusterState currentState) {
         final TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
         final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
-        Map<String, List<String>> removedNodeModelLookUp = new HashMap<>();
         Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
+        Set<String> currentNotShuttingDownNodes = currentState.getNodes()
+            .getAllNodes()
+            .stream()
+            .map(DiscoveryNode::getId)
+            .filter(id -> shuttingDownNodes.contains(id) == false)
+            .collect(Collectors.toSet());
         // TODO: make more efficient, right now this is O(nm) where n = sizeof(models) and m = sizeof(nodes)
         // It could probably be O(max(n, m))
         // Add nodes and keep track of currently routed nodes
         // Should we indicate a partial allocation somehow if some nodes don't have space?
-        boolean isChanged = false;
         for (Map.Entry<String, TrainedModelAllocation> modelAllocationEntry : previousState.modelAllocations().entrySet()) {
             // Don't bother adding/removing nodes if this allocation is stopping
             if (modelAllocationEntry.getValue().getAllocationState().equals(AllocationState.STOPPING)) {
                 continue;
             }
+            final String modelId = modelAllocationEntry.getKey();
+            Map<String, String> nodeToReason = new TreeMap<>();
             for (DiscoveryNode node : currentState.getNodes()) {
                 // Only add the route if the node is NOT shutting down, this would be a weird case of the node
                 // just being added to the cluster and immediately shutting down...
@@ -362,39 +381,40 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
                     && modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
                     Optional<String> failure = nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
                     if (failure.isPresent()) {
-                        isChanged |= builder.getAllocation(modelAllocationEntry.getKey())
-                            .addNewFailedRoutingEntry(node.getId(), failure.get())
-                            .isChanged();
+                        nodeToReason.put(node.getName(), failure.get());
                     } else {
-                        isChanged |= builder.getAllocation(modelAllocationEntry.getKey()).addNewRoutingEntry(node.getId()).isChanged();
+                        builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
                     }
                 }
             }
-            for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
-                removedNodeModelLookUp.computeIfAbsent(nodeId, k -> new ArrayList<>()).add(modelAllocationEntry.getKey());
+            if (nodeToReason.isEmpty() == false) {
+                builder.getAllocation(modelId)
+                    .setReason(
+                        nodeToReason.entrySet()
+                            .stream()
+                            .map(
+                                entry -> String.format(
+                                    Locale.ROOT,
+                                    "Not allocating on node [%s]. Reason: %s",
+                                    entry.getKey(),
+                                    entry.getValue()
+                                )
+                            )
+                            .collect(Collectors.joining("|"))
+                    );
+            } else {
+                builder.getAllocation(modelId).clearReason();
             }
-        }
-
-        // Remove nodes
-        currentState.getNodes()
-            .forEach(
-                d -> {
-                    // If a node is referenced in the current state, we shouldn't remove the node
-                    // But, if that node that is referenced is shutting down, we should remove the node
-                    if (shuttingDownNodes.contains(d.getId()) == false) {
-                        removedNodeModelLookUp.remove(d.getId());
-                    }
+            for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
+                if (currentNotShuttingDownNodes.contains(nodeId) == false) {
+                    builder.getAllocation(modelId).removeRoutingEntry(nodeId);
                 }
-            );
-        for (Map.Entry<String, List<String>> nodeToModels : removedNodeModelLookUp.entrySet()) {
-            final String nodeId = nodeToModels.getKey();
-            for (String modelId : nodeToModels.getValue()) {
-                isChanged |= Optional.ofNullable(builder.getAllocation(modelId))
-                    .map(allocation -> allocation.removeRoutingEntry(nodeId).isChanged())
-                    .orElse(false);
             }
+            // It may be we moved from STARTED to PARTIALLY_STARTED with the addition of new nodes
+            // Or moved from PARTIALLY_STARTED to STARTED if a node was removed
+            builder.getAllocation(modelId).calculateAndSetAllocationState();
         }
-        return update(currentState, builder, isChanged);
+        return update(currentState, builder);
     }
 
     static boolean shouldAllocateModels(final ClusterChangedEvent event) {
@@ -453,8 +473,7 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
     }
 
     /**
-     * Returns true if the given node is marked as shutting down with any
-     * shutdown type.
+     * Returns the set of IDs of the nodes that are currently shutting down
      */
     static Set<String> nodesShuttingDown(final ClusterState state) {
         return NodesShutdownMetadata.getShutdowns(state)

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java

@@ -174,7 +174,7 @@ public class TrainedModelAllocationMetadata implements Metadata.Custom {
         }
 
         /**
          * Returns {@code true} when this builder's own {@code isChanged} flag is set, or when any of the
          * allocation builders held in {@code modelRoutingEntries} reports {@code isChanged()}.
          * The replaced line consulted only this builder's own flag.
          */
         public boolean isChanged() {
-            return isChanged;
+            return isChanged || modelRoutingEntries.values().stream().anyMatch(TrainedModelAllocation.Builder::isChanged);
         }
 
         public TrainedModelAllocationMetadata build() {

+ 3 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationService.java

@@ -96,17 +96,14 @@ public class TrainedModelAllocationService {
 
     public void waitForAllocationCondition(
         final String modelId,
-        final Predicate<TrainedModelAllocation> predicate,
+        final Predicate<ClusterState> predicate,
         final @Nullable TimeValue timeout,
         final WaitForAllocationListener listener
     ) {
-        final Predicate<ClusterState> clusterStatePredicate = clusterState -> predicate.test(
-            TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null)
-        );
 
         final ClusterStateObserver observer = new ClusterStateObserver(clusterService, timeout, logger, threadPool.getThreadContext());
         final ClusterState clusterState = observer.setAndGetObservedState();
-        if (clusterStatePredicate.test(clusterState)) {
+        if (predicate.test(clusterState)) {
             listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null));
         } else {
             observer.waitForNextChange(new ClusterStateObserver.Listener() {
@@ -124,7 +121,7 @@ public class TrainedModelAllocationService {
                 public void onTimeout(TimeValue timeout) {
                     listener.onTimeout(timeout);
                 }
-            }, clusterStatePredicate);
+            }, predicate);
         }
     }
 

+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java

@@ -13,6 +13,7 @@ import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.ml.MachineLearning;
 
 import java.io.IOException;
@@ -45,6 +46,12 @@ public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {
                 StartTrainedModelDeploymentAction.DEFAULT_TIMEOUT);
             request.setTimeout(openTimeout);
         }
+        request.setWaitForState(AllocationStatus.State.fromString(
+            restRequest.param(
+                StartTrainedModelDeploymentAction.Request.WAIT_FOR.getPreferredName(),
+                AllocationStatus.State.STARTED.toString()
+            )
+        ));
 
         return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
     }

+ 56 - 26
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java

@@ -76,8 +76,14 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
     public void testUpdateModelRoutingTable() {
         String modelId = "existing-model";
         String nodeId = "ml-node-with-room";
+        String startedNode = "started-ml-node-with-room";
         ClusterState currentState = ClusterState.builder(new ClusterName("testUpdateModelRoutingTable"))
-            .nodes(DiscoveryNodes.builder().add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes())).build())
+            .nodes(
+                DiscoveryNodes.builder()
+                    .add(buildNode(nodeId, true, ByteSizeValue.ofGb(4).getBytes()))
+                    .add(buildNode(startedNode, true, ByteSizeValue.ofGb(4).getBytes()))
+                    .build()
+            )
             .metadata(
                 Metadata.builder()
                     .putCustom(
@@ -86,6 +92,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             .addNewAllocation(
                                 modelId,
                                 TrainedModelAllocation.Builder.empty(newParams(modelId, 10_000L)).addNewRoutingEntry(nodeId)
+                                    .addNewRoutingEntry(startedNode)
                             )
                             .build()
                     )
@@ -96,24 +103,37 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         assertThatStoppingAllocationPreventsMutation(
             state -> TrainedModelAllocationClusterService.updateModelRoutingTable(
                 state,
-                new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, new RoutingStateAndReason(RoutingState.STARTED, ""))
+                new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, started())
             ),
             currentState
         );
 
+        assertThat(
+            TrainedModelAllocationMetadata.fromState(currentState).getModelAllocation(modelId).getAllocationState(),
+            equalTo(AllocationState.STARTING)
+        );
+
         ClusterState newState = TrainedModelAllocationClusterService.updateModelRoutingTable(
             currentState,
-            new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, new RoutingStateAndReason(RoutingState.STARTED, ""))
+            new UpdateTrainedModelAllocationStateAction.Request(startedNode, modelId, started())
         );
         assertThat(
-            TrainedModelAllocationMetadata.fromState(newState).getModelAllocation(modelId).getNodeRoutingTable().get(nodeId).getState(),
+            TrainedModelAllocationMetadata.fromState(newState)
+                .getModelAllocation(modelId)
+                .getNodeRoutingTable()
+                .get(startedNode)
+                .getState(),
             equalTo(RoutingState.STARTED)
         );
+        assertThat(
+            TrainedModelAllocationMetadata.fromState(newState).getModelAllocation(modelId).getAllocationState(),
+            equalTo(AllocationState.STARTED)
+        );
 
         expectThrows(
             ResourceNotFoundException.class,
             () -> TrainedModelAllocationClusterService.updateModelRoutingTable(
-                currentState,
+                newState,
                 new UpdateTrainedModelAllocationStateAction.Request(
                     "missingNode",
                     modelId,
@@ -124,7 +144,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         expectThrows(
             ResourceNotFoundException.class,
             () -> TrainedModelAllocationClusterService.updateModelRoutingTable(
-                currentState,
+                newState,
                 new UpdateTrainedModelAllocationStateAction.Request(
                     nodeId,
                     "missingModel",
@@ -137,22 +157,26 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
 
         // We should allow a "stopped" update on missing models and nodes as entries may have already been deleted
         TrainedModelAllocationClusterService.updateModelRoutingTable(
-            currentState,
+            newState,
             new UpdateTrainedModelAllocationStateAction.Request("missingNode", modelId, new RoutingStateAndReason(RoutingState.STOPPED, ""))
         );
         TrainedModelAllocationClusterService.updateModelRoutingTable(
-            currentState,
+            newState,
             new UpdateTrainedModelAllocationStateAction.Request(nodeId, "missingModel", new RoutingStateAndReason(RoutingState.STOPPED, ""))
         );
 
         ClusterState updateState = TrainedModelAllocationClusterService.updateModelRoutingTable(
-            currentState,
+            newState,
             new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, new RoutingStateAndReason(RoutingState.STOPPED, ""))
         );
         assertThat(
             TrainedModelAllocationMetadata.fromState(updateState).getModelAllocation(modelId).getNodeRoutingTable(),
             not(hasKey(nodeId))
         );
+        assertThat(
+            TrainedModelAllocationMetadata.fromState(updateState).getModelAllocation(modelId).getAllocationState(),
+            equalTo(AllocationState.STARTED)
+        );
     }
 
     public void testRemoveAllocation() {
@@ -226,15 +250,15 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         TrainedModelAllocation createdAllocation = TrainedModelAllocationMetadata.fromState(newState).getModelAllocation("new-model");
 
         assertThat(createdAllocation, is(not(nullValue())));
-        assertThat(createdAllocation.getNodeRoutingTable().keySet(), hasSize(2));
+        assertThat(createdAllocation.getNodeRoutingTable().keySet(), hasSize(1));
         assertThat(createdAllocation.getNodeRoutingTable(), hasKey("ml-node-with-room"));
         assertThat(createdAllocation.getNodeRoutingTable().get("ml-node-with-room").getState(), equalTo(RoutingState.STARTING));
-        assertThat(createdAllocation.getNodeRoutingTable(), hasKey("ml-node-without-room"));
-        assertThat(createdAllocation.getNodeRoutingTable().get("ml-node-without-room").getState(), equalTo(RoutingState.FAILED));
+        assertThat(createdAllocation.getReason().isPresent(), is(true));
         assertThat(
-            createdAllocation.getNodeRoutingTable().get("ml-node-without-room").getReason(),
-            containsString("This node has insufficient available memory.")
+            createdAllocation.getReason().get(),
+            containsString("Not allocating on node [ml-node-without-room]")
         );
+        assertThat(createdAllocation.getAllocationState(), equalTo(AllocationState.STARTING));
 
         expectThrows(
             ResourceAlreadyExistsException.class,
@@ -319,23 +343,29 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         assertThat(trainedModelAllocationMetadata.modelAllocations().keySet(), hasSize(2));
         assertThat(trainedModelAllocationMetadata.modelAllocations(), allOf(hasKey("model-1"), hasKey("model-2")));
 
-        assertThat(trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable().keySet(), hasSize(3));
+        assertThat(trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable().keySet(), hasSize(2));
         assertThat(
             trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable(),
-            allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room"), hasKey("ml-node-without-room"))
+            allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room"))
         );
         assertNodeState(trainedModelAllocationMetadata, "model-1", "ml-node-with-room", RoutingState.STARTED);
         assertNodeState(trainedModelAllocationMetadata, "model-1", "new-ml-node-with-room", RoutingState.STARTING);
-        assertNodeState(trainedModelAllocationMetadata, "model-1", "ml-node-without-room", RoutingState.FAILED);
+        assertThat(
+            trainedModelAllocationMetadata.modelAllocations().get("model-1").getAllocationState(),
+            equalTo(AllocationState.STARTED)
+        );
 
-        assertThat(trainedModelAllocationMetadata.getModelAllocation("model-2").getNodeRoutingTable().keySet(), hasSize(3));
+        assertThat(trainedModelAllocationMetadata.getModelAllocation("model-2").getNodeRoutingTable().keySet(), hasSize(2));
         assertThat(
             trainedModelAllocationMetadata.getModelAllocation("model-2").getNodeRoutingTable(),
-            allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room"), hasKey("ml-node-without-room"))
+            allOf(hasKey("ml-node-with-room"), hasKey("new-ml-node-with-room"))
         );
         assertNodeState(trainedModelAllocationMetadata, "model-2", "ml-node-with-room", RoutingState.STARTING);
         assertNodeState(trainedModelAllocationMetadata, "model-2", "new-ml-node-with-room", RoutingState.STARTING);
-        assertNodeState(trainedModelAllocationMetadata, "model-2", "ml-node-without-room", RoutingState.FAILED);
+        assertThat(
+            trainedModelAllocationMetadata.modelAllocations().get("model-2").getAllocationState(),
+            equalTo(AllocationState.STARTING)
+        );
     }
 
     public void testShouldAllocateModels() {
@@ -518,7 +548,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                                     TrainedModelAllocationMetadata.Builder.empty()
                                         .addNewAllocation(
                                             model1,
-                                            TrainedModelAllocation.Builder.empty(newParams(model1, 100)).stopAllocation()
+                                            TrainedModelAllocation.Builder.empty(newParams(model1, 100)).stopAllocation("test")
                                         )
                                         .build()
                                 )
@@ -649,7 +679,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                                             TrainedModelAllocation.Builder.empty(newParams("model-2", 100))
                                                 .addNewRoutingEntry(mlNode1)
                                                 .addNewRoutingEntry(mlNode2)
-                                                .stopAllocation()
+                                                .stopAllocation("test")
                                         )
                                         .build()
                                 )
@@ -691,7 +721,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
 
         expectThrows(
             ResourceNotFoundException.class,
-            () -> TrainedModelAllocationClusterService.setToStopping(clusterStateWithoutAllocation, modelId)
+            () -> TrainedModelAllocationClusterService.setToStopping(clusterStateWithoutAllocation, modelId, "test")
         );
 
         ClusterState clusterStateWithAllocation = ClusterState.builder(new ClusterName("testSetAllocationToStopping"))
@@ -708,9 +738,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
             .build();
         TrainedModelAllocationMetadata before = TrainedModelAllocationMetadata.fromState(clusterStateWithAllocation);
         assertThat(before.getModelAllocation(modelId), is(not(nullValue())));
-        assertThat(before.getModelAllocation(modelId).getAllocationState(), equalTo(AllocationState.STARTED));
+        assertThat(before.getModelAllocation(modelId).getAllocationState(), equalTo(AllocationState.STARTING));
 
-        ClusterState modified = TrainedModelAllocationClusterService.setToStopping(clusterStateWithAllocation, modelId);
+        ClusterState modified = TrainedModelAllocationClusterService.setToStopping(clusterStateWithAllocation, modelId, "test");
         assertThat(
             TrainedModelAllocationMetadata.fromState(modified).getModelAllocation(modelId).getAllocationState(),
             equalTo(AllocationState.STOPPING)
@@ -727,7 +757,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         }
         TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(original);
         for (String modelId : tempMetadata.modelAllocations().keySet()) {
-            builder.getAllocation(modelId).stopAllocation();
+            builder.getAllocation(modelId).stopAllocation("test");
         }
         TrainedModelAllocationMetadata metadataWithStopping = builder.build();
         ClusterState originalWithStoppingAllocations = ClusterState.builder(original)

+ 14 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java

@@ -62,6 +62,20 @@ public class TrainedModelAllocationMetadataTests extends AbstractSerializingTest
         assertThat(builder.isChanged(), is(true));
     }
 
+    /**
+     * A metadata builder created from already-built metadata is expected to report no change, and to
+     * report a change once one of the allocation builders it holds has a routing entry added.
+     */
+    public void testBuilderChangedWhenAllocationChanged() {
+        final String modelId = "test_model_id";
+        final TrainedModelAllocationMetadata initialMetadata = TrainedModelAllocationMetadata.Builder.empty()
+            .addNewAllocation(modelId, TrainedModelAllocation.Builder.empty(randomParams(modelId)))
+            .build();
+        final TrainedModelAllocationMetadata.Builder metadataBuilder = TrainedModelAllocationMetadata.Builder.fromMetadata(
+            initialMetadata
+        );
+
+        // Nothing has been touched yet
+        assertThat(metadataBuilder.isChanged(), is(false));
+
+        // Mutating a nested allocation builder must surface through the enclosing metadata builder
+        metadataBuilder.getAllocation(modelId).addNewRoutingEntry("new-node");
+        assertThat(metadataBuilder.isChanged(), is(true));
+    }
+
     public void testIsAllocated() {
         String allocatedModelId = "test_model_id";
         TrainedModelAllocationMetadata metadata = TrainedModelAllocationMetadata.Builder.empty()