
[ML] Add api to update trained model deployment number_of_allocations (#90728)

This commit adds a new API, which users call as follows:

```
POST _ml/trained_models/{model_id}/deployment/_update
{
  "number_of_allocations": 4
}
```

This allows a user to update the number of allocations for a deployment
that is `started`.
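
For illustration, a client could build this request with the JDK's `java.net.http` API roughly as follows. This is a minimal sketch, not code from this commit: the class name, the `buildUpdateRequest` helper, the base URL `http://localhost:9200`, and the model id `my_model` are all assumptions, and the request is only built, not sent.

```java
import java.net.URI;
import java.net.http.HttpRequest;

public class UpdateDeploymentRequestSketch {

    // Builds (but does not send) the update request.
    // baseUrl is an assumed cluster address; modelId is a hypothetical model id.
    static HttpRequest buildUpdateRequest(String baseUrl, String modelId, int numberOfAllocations) {
        // Mirrors the server-side rule that number_of_allocations must be positive.
        if (numberOfAllocations < 1) {
            throw new IllegalArgumentException("[number_of_allocations] must be a positive integer");
        }
        String body = "{\"number_of_allocations\":" + numberOfAllocations + "}";
        return HttpRequest.newBuilder(URI.create(baseUrl + "/_ml/trained_models/" + modelId + "/deployment/_update"))
            .header("Content-Type", "application/json")
            .POST(HttpRequest.BodyPublishers.ofString(body))
            .build();
    }

    public static void main(String[] args) {
        HttpRequest request = buildUpdateRequest("http://localhost:9200", "my_model", 4);
        System.out.println(request.method() + " " + request.uri());
    }
}
```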

If the allocations are increased, we rebalance and let the assignment
planner decide where to place the additional allocations.

If the allocations are decreased we cannot use the assignment planner.
Instead, we implement the reduction in a new class `AllocationReducer`
that tries to reduce the allocations so that:

  1. availability zone balance is maintained
  2. assignments that can be completely stopped are preferred to release memory
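
The two goals above can be illustrated with a simple greedy heuristic. This is a hedged sketch only and is not the `AllocationReducer` implementation from this commit: the `Node` record, the `reduce` method, and the one-allocation-at-a-time strategy are assumptions chosen for illustration. Each step removes one allocation from the zone that currently holds the most, taking it from the node in that zone with the fewest remaining allocations, so small assignments tend to be drained completely.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class AllocationReductionSketch {

    // Hypothetical view of one node's assignment: node id, availability zone, allocations.
    record Node(String id, String zone, int allocations) {}

    // Removes up to `toRemove` allocations and returns the remaining allocations per node id.
    static Map<String, Integer> reduce(List<Node> nodes, int toRemove) {
        Map<String, Integer> remaining = new TreeMap<>();
        Map<String, Integer> perZone = new TreeMap<>();
        for (Node n : nodes) {
            remaining.put(n.id(), n.allocations());
            perZone.merge(n.zone(), n.allocations(), Integer::sum);
        }
        for (int i = 0; i < toRemove; i++) {
            // Pick the zone with the most allocations to keep zones balanced.
            String zone = null;
            for (Map.Entry<String, Integer> e : perZone.entrySet()) {
                if (e.getValue() > 0 && (zone == null || e.getValue() > perZone.get(zone))) {
                    zone = e.getKey();
                }
            }
            if (zone == null) {
                break; // nothing left to remove
            }
            // Within that zone, drain the node with the fewest remaining allocations,
            // so that it can be stopped completely as early as possible.
            Node victim = null;
            for (Node n : nodes) {
                int r = remaining.get(n.id());
                if (n.zone().equals(zone) && r > 0 && (victim == null || r < remaining.get(victim.id()))) {
                    victim = n;
                }
            }
            remaining.merge(victim.id(), -1, Integer::sum);
            perZone.merge(zone, -1, Integer::sum);
        }
        return remaining;
    }

    public static void main(String[] args) {
        List<Node> nodes = List.of(new Node("a", "z1", 3), new Node("b", "z1", 1), new Node("c", "z2", 2));
        // Node b is stopped completely and both zones end up with 2 allocations.
        System.out.println(reduce(nodes, 2)); // prints {a=2, b=0, c=2}
    }
}
```
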
Dimitris Athanasiou 3 years ago
parent
commit
16bfc550ea
20 changed files with 1148 additions and 5 deletions
  1. docs/changelog/90728.yaml (+5 -0)
  2. docs/reference/ml/trained-models/apis/index.asciidoc (+2 -0)
  3. docs/reference/ml/trained-models/apis/ml-trained-models-apis.asciidoc (+1 -0)
  4. docs/reference/ml/trained-models/apis/update-trained-model-deployment.asciidoc (+88 -0)
  5. rest-api-spec/src/main/resources/rest-api-spec/api/ml.update_trained_model_deployment.json (+34 -0)
  6. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java (+142 -0)
  7. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java (+17 -1)
  8. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java (+1 -0)
  9. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentRequestTests.java (+67 -0)
  10. x-pack/plugin/ml/qa/ml-with-security/build.gradle (+2 -1)
  11. x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java (+89 -3)
  12. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java (+5 -0)
  13. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java (+100 -0)
  14. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java (+145 -0)
  15. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducer.java (+164 -0)
  16. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentAction.java (+55 -0)
  17. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java (+195 -0)
  18. x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java (+1 -0)
  19. x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml (+25 -0)
  20. x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/update_trained_model_deployment.yml (+10 -0)

+ 5 - 0
docs/changelog/90728.yaml

@@ -0,0 +1,5 @@
+pr: 90728
+summary: Add api to update trained model deployment `number_of_allocations`
+area: Machine Learning
+type: enhancement
+issues: []

+ 2 - 0
docs/reference/ml/trained-models/apis/index.asciidoc

@@ -17,3 +17,5 @@ include::infer-trained-model.asciidoc[leveloffset=+2]
 //START/STOP
 include::start-trained-model-deployment.asciidoc[leveloffset=+2]
 include::stop-trained-model-deployment.asciidoc[leveloffset=+2]
+//UPDATE
+include::update-trained-model-deployment.asciidoc[leveloffset=+2]

+ 1 - 0
docs/reference/ml/trained-models/apis/ml-trained-models-apis.asciidoc

@@ -16,6 +16,7 @@ You can use the following APIs to perform model management operations:
 * <<infer-trained-model>>
 * <<start-trained-model-deployment>>
 * <<stop-trained-model-deployment>>
+* <<update-trained-model-deployment>>
 
 You can deploy a trained model to make predictions in an ingest pipeline or in
 an aggregation. Refer to the following documentation to learn more:

+ 88 - 0
docs/reference/ml/trained-models/apis/update-trained-model-deployment.asciidoc

@@ -0,0 +1,88 @@
+[role="xpack"]
+[[update-trained-model-deployment]]
+= Update trained model deployment API
+
+[subs="attributes"]
+++++
+<titleabbrev>Update trained model deployment</titleabbrev>
+++++
+
+Updates certain properties of a trained model deployment.
+
+beta::[]
+
+[[update-trained-model-deployment-request]]
+== {api-request-title}
+
+`POST _ml/trained_models/<model_id>/deployment/_update`
+
+
+[[update-trained-model-deployments-prereqs]]
+== {api-prereq-title}
+
+Requires the `manage_ml` cluster privilege. This privilege is included in the
+`machine_learning_admin` built-in role.
+
+[[update-trained-model-deployment-desc]]
+== {api-description-title}
+
+You can update a trained model deployment whose `assignment_state` is `started`.
+You can either increase or decrease the number of allocations of such a deployment.
+
+[[update-trained-model-deployments-path-parms]]
+== {api-path-parms-title}
+
+`<model_id>`::
+(Required, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+
+[[update-trained-model-deployment-request-body]]
+== {api-request-body-title}
+
+`number_of_allocations`::
+(Optional, integer)
+The total number of allocations this model is assigned across {ml} nodes.
+Increasing this value generally increases the throughput.
+
+
+[[update-trained-model-deployment-example]]
+== {api-examples-title}
+
+The following example updates the deployment for a
+ `elastic__distilbert-base-uncased-finetuned-conll03-english` trained model to have 4 allocations:
+
+[source,console]
+--------------------------------------------------
+POST _ml/trained_models/elastic__distilbert-base-uncased-finetuned-conll03-english/deployment/_update
+{
+  "number_of_allocations": 4
+}
+--------------------------------------------------
+// TEST[skip:TBD]
+
+The API returns the following results:
+
+[source,console-result]
+----
+{
+    "assignment": {
+        "task_parameters": {
+            "model_id": "elastic__distilbert-base-uncased-finetuned-conll03-english",
+            "model_bytes": 265632637,
+            "threads_per_allocation" : 1,
+            "number_of_allocations" : 4,
+            "queue_capacity" : 1024
+        },
+        "routing_table": {
+            "uckeG3R8TLe2MMNBQ6AGrw": {
+                "current_allocations": 1,
+                "target_allocations": 4,
+                "routing_state": "started",
+                "reason": ""
+            }
+        },
+        "assignment_state": "started",
+        "start_time": "2022-11-02T11:50:34.766591Z"
+    }
+}
+----

+ 34 - 0
rest-api-spec/src/main/resources/rest-api-spec/api/ml.update_trained_model_deployment.json

@@ -0,0 +1,34 @@
+{
+  "ml.update_trained_model_deployment":{
+    "documentation":{
+      "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-update-trained-model-deployment.html",
+      "description":"Updates certain properties of trained model deployment."
+    },
+    "stability":"beta",
+    "visibility":"public",
+    "headers":{
+      "accept": [ "application/json"],
+      "content_type": ["application/json"]
+    },
+    "url":{
+      "paths":[
+        {
+          "path":"/_ml/trained_models/{model_id}/deployment/_update",
+          "methods":[
+            "POST"
+          ],
+          "parts":{
+            "model_id":{
+              "type":"string",
+              "description":"The unique identifier of the trained model."
+            }
+          }
+        }
+      ]
+    },
+    "body":{
+      "description":"The updated trained model deployment settings",
+      "required":true
+    }
+  }
+}

+ 142 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java

@@ -0,0 +1,142 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_ID;
+import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS;
+
+public class UpdateTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
+
+    public static final UpdateTrainedModelDeploymentAction INSTANCE = new UpdateTrainedModelDeploymentAction();
+    public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/update";
+
+    public UpdateTrainedModelDeploymentAction() {
+        super(NAME, CreateTrainedModelAssignmentAction.Response::new);
+    }
+
+    public static class Request extends AcknowledgedRequest<Request> implements ToXContentObject {
+
+        public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
+
+        public static final ParseField TIMEOUT = new ParseField("timeout");
+
+        static {
+            PARSER.declareString(Request::setModelId, MODEL_ID);
+            PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
+            PARSER.declareString((r, val) -> r.timeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT);
+        }
+
+        public static Request parseRequest(String modelId, XContentParser parser) {
+            Request request = PARSER.apply(parser, null);
+            if (request.getModelId() == null) {
+                request.setModelId(modelId);
+            } else if (Strings.isNullOrEmpty(modelId) == false && modelId.equals(request.getModelId()) == false) {
+                throw ExceptionsHelper.badRequestException(
+                    Messages.getMessage(Messages.INCONSISTENT_ID, MODEL_ID, request.getModelId(), modelId)
+                );
+            }
+            return request;
+        }
+
+        private String modelId;
+        private int numberOfAllocations;
+
+        private Request() {}
+
+        public Request(String modelId) {
+            setModelId(modelId);
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            modelId = in.readString();
+            numberOfAllocations = in.readVInt();
+        }
+
+        public final void setModelId(String modelId) {
+            this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
+        }
+
+        public String getModelId() {
+            return modelId;
+        }
+
+        public void setNumberOfAllocations(int numberOfAllocations) {
+            this.numberOfAllocations = numberOfAllocations;
+        }
+
+        public int getNumberOfAllocations() {
+            return numberOfAllocations;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(modelId);
+            out.writeVInt(numberOfAllocations);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(MODEL_ID.getPreferredName(), modelId);
+            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            ActionRequestValidationException validationException = new ActionRequestValidationException();
+            if (numberOfAllocations < 1) {
+                validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
+            }
+            return validationException.validationErrors().isEmpty() ? null : validationException;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(modelId, numberOfAllocations);
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) {
+                return true;
+            }
+            if (obj == null || obj.getClass() != getClass()) {
+                return false;
+            }
+            Request other = (Request) obj;
+            return Objects.equals(modelId, other.modelId) && numberOfAllocations == other.numberOfAllocations;
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+    }
+}

+ 17 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java

@@ -233,6 +233,10 @@ public class TrainedModelAssignment implements SimpleDiffable<TrainedModelAssign
         return nodeRoutingTable.values().stream().mapToInt(RoutingInfo::getCurrentAllocations).sum();
     }
 
+    public int totalTargetAllocations() {
+        return nodeRoutingTable.values().stream().mapToInt(RoutingInfo::getTargetAllocations).sum();
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -292,7 +296,7 @@ public class TrainedModelAssignment implements SimpleDiffable<TrainedModelAssign
 
     public static class Builder {
         private final Map<String, RoutingInfo> nodeRoutingTable;
-        private final StartTrainedModelDeploymentAction.TaskParams taskParams;
+        private StartTrainedModelDeploymentAction.TaskParams taskParams;
         private AssignmentState assignmentState;
         private String reason;
         private Instant startTime;
@@ -426,6 +430,18 @@ public class TrainedModelAssignment implements SimpleDiffable<TrainedModelAssign
             return this;
         }
 
+        public Builder setNumberOfAllocations(int numberOfAllocations) {
+            this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(
+                taskParams.getModelId(),
+                taskParams.getModelBytes(),
+                taskParams.getThreadsPerAllocation(),
+                numberOfAllocations,
+                taskParams.getQueueCapacity(),
+                taskParams.getCacheSize().orElse(null)
+            );
+            return this;
+        }
+
         public TrainedModelAssignment build() {
             return new TrainedModelAssignment(taskParams, nodeRoutingTable, assignmentState, reason, startTime, maxAssignedAllocations);
         }

+ 1 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -132,6 +132,7 @@ public final class Messages {
     public static final String INFERENCE_DEPLOYMENT_STARTED = "Started deployment";
     public static final String INFERENCE_DEPLOYMENT_STOPPED = "Stopped deployment";
     public static final String INFERENCE_DEPLOYMENT_REBALANCED = "Rebalanced trained model allocations because [{0}]";
+    public static final String INFERENCE_DEPLOYMENT_UPDATED_NUMBER_OF_ALLOCATIONS = "Updated number_of_allocations to [{0}]";
 
     public static final String INVALID_MODEL_ALIAS = "Invalid model_alias; ''{0}'' can contain lowercase alphanumeric (a-z and 0-9), "
         + "hyphens or underscores; must start with alphanumeric and cannot end with numbers";

+ 67 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentRequestTests.java

@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractXContentSerializingTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction.Request;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+
+public class UpdateTrainedModelDeploymentRequestTests extends AbstractXContentSerializingTestCase<Request> {
+
+    @Override
+    protected Request doParseInstance(XContentParser parser) throws IOException {
+        return Request.parseRequest(null, parser);
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+
+    @Override
+    protected Request createTestInstance() {
+        return createRandom();
+    }
+
+    public static Request createRandom() {
+        Request request = new Request(randomAlphaOfLength(10));
+        if (randomBoolean()) {
+            request.setNumberOfAllocations(randomIntBetween(1, 512));
+        }
+        return request;
+    }
+
+    public void testValidate_GivenNumberOfAllocationsIsZero() {
+        Request request = createRandom();
+        request.setNumberOfAllocations(0);
+
+        ActionRequestValidationException e = request.validate();
+
+        assertThat(e, is(not(nullValue())));
+        assertThat(e.getMessage(), containsString("[number_of_allocations] must be a positive integer"));
+    }
+
+    public void testValidate_GivenNumberOfAllocationsIsNegative() {
+        Request request = createRandom();
+        request.setNumberOfAllocations(randomIntBetween(-100, -1));
+
+        ActionRequestValidationException e = request.validate();
+
+        assertThat(e, is(not(nullValue())));
+        assertThat(e.getMessage(), containsString("[number_of_allocations] must be a positive integer"));
+    }
+}

+ 2 - 1
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -234,7 +234,8 @@ tasks.named("yamlRestTest").configure {
     'ml/set_upgrade_mode/Attempt to open job when upgrade_mode is enabled',
     'ml/set_upgrade_mode/Setting upgrade_mode to enabled',
     'ml/set_upgrade_mode/Setting upgrade mode to disabled from enabled',
-    'ml/set_upgrade_mode/Test setting upgrade_mode to false when it is already false'
+    'ml/set_upgrade_mode/Test setting upgrade_mode to false when it is already false',
+    'ml/update_trained_model_deployment/Test with unknown model id'
   ].join(',')
 }
 

+ 89 - 3
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -877,14 +877,94 @@ public class PyTorchModelIT extends ESRestTestCase {
         stopDeployment(modelId2);
     }
 
-    @SuppressWarnings("unchecked")
+    public void testUpdateDeployment_GivenMissingModel() throws IOException {
+        ResponseException ex = expectThrows(ResponseException.class, () -> updateDeployment("missing", 4));
+        assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(404));
+        assertThat(EntityUtils.toString(ex.getResponse().getEntity()), containsString("deployment for model with id [missing] not found"));
+    }
+
+    public void testUpdateDeployment_GivenAllocationsAreIncreased() throws Exception {
+        String modelId = "update_deployment_allocations_increased";
+        createTrainedModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        assertAllocationCount(modelId, 1);
+
+        updateDeployment(modelId, 2);
+
+        assertBusy(() -> assertAllocationCount(modelId, 2));
+    }
+
+    public void testUpdateDeployment_GivenAllocationsAreIncreasedOverResources_AndScalingIsPossible() throws Exception {
+        Request maxLazyNodeSetting = new Request("PUT", "_cluster/settings");
+        maxLazyNodeSetting.setJsonEntity("""
+            {"persistent" : {
+                    "xpack.ml.max_lazy_ml_nodes": 5
+                }}""");
+        client().performRequest(maxLazyNodeSetting);
+
+        String modelId = "update_deployment_allocations_increased_scaling_possible";
+        createTrainedModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        assertAllocationCount(modelId, 1);
+
+        updateDeployment(modelId, 42);
+
+        assertBusy(() -> {
+            int allocationCount = getAllocationCount(modelId);
+            assertThat(allocationCount, greaterThanOrEqualTo(2));
+        });
+    }
+
+    public void testUpdateDeployment_GivenAllocationsAreIncreasedOverResources_AndScalingIsNotPossible() throws Exception {
+        String modelId = "update_deployment_allocations_increased_scaling_not_possible";
+        createTrainedModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        assertAllocationCount(modelId, 1);
+
+        ResponseException ex = expectThrows(ResponseException.class, () -> updateDeployment(modelId, 257));
+        assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(429));
+        assertThat(
+            EntityUtils.toString(ex.getResponse().getEntity()),
+            containsString("Could not update deployment because there are not enough resources to provide all requested allocations")
+        );
+        assertAllocationCount(modelId, 1);
+    }
+
+    public void testUpdateDeployment_GivenAllocationsAreDecreased() throws Exception {
+        String modelId = "update_deployment_allocations_decreased";
+        createTrainedModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId, "started", 2, 1);
+
+        assertAllocationCount(modelId, 2);
+
+        updateDeployment(modelId, 1);
+
+        assertBusy(() -> assertAllocationCount(modelId, 1));
+    }
+
     private void assertAllocationCount(String modelId, int expectedAllocationCount) throws IOException {
+        int allocations = getAllocationCount(modelId);
+        assertThat(allocations, equalTo(expectedAllocationCount));
+    }
+
+    @SuppressWarnings("unchecked")
+    private int getAllocationCount(String modelId) throws IOException {
         Response response = getTrainedModelStats(modelId);
         var responseMap = entityAsMap(response);
         List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
         assertThat(stats, hasSize(1));
-        int allocations = (int) XContentMapValues.extractValue("deployment_stats.allocation_status.allocation_count", stats.get(0));
-        assertThat(allocations, equalTo(expectedAllocationCount));
+        return (int) XContentMapValues.extractValue("deployment_stats.allocation_status.allocation_count", stats.get(0));
     }
 
     private int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {
@@ -973,6 +1053,12 @@ public class PyTorchModelIT extends ESRestTestCase {
         client().performRequest(request);
     }
 
+    private Response updateDeployment(String modelId, int numberOfAllocations) throws IOException {
+        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_update");
+        request.setJsonEntity("{\"number_of_allocations\":" + numberOfAllocations + "}");
+        return client().performRequest(request);
+    }
+
     private Response getTrainedModelStats(String modelId) throws IOException {
         Request request = new Request("GET", "/_ml/trained_models/" + modelId + "/_stats");
         return client().performRequest(request);

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -177,6 +177,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateJobAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction;
 import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
 import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
@@ -276,6 +277,7 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateJobAction;
 import org.elasticsearch.xpack.ml.action.TransportUpdateModelSnapshotAction;
 import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction;
 import org.elasticsearch.xpack.ml.action.TransportUpdateTrainedModelAssignmentStateAction;
+import org.elasticsearch.xpack.ml.action.TransportUpdateTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction;
 import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction;
 import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction;
@@ -412,6 +414,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelDefinitionPa
 import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelVocabularyAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestStartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestStopTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.ml.rest.inference.RestUpdateTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
 import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
 import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
@@ -1306,6 +1309,7 @@ public class MachineLearning extends Plugin
             new RestStartTrainedModelDeploymentAction(),
             new RestStopTrainedModelDeploymentAction(),
             new RestInferTrainedModelDeploymentAction(),
+            new RestUpdateTrainedModelDeploymentAction(),
             new RestPutTrainedModelDefinitionPartAction(),
             new RestPutTrainedModelVocabularyAction(),
             new RestInferTrainedModelAction(),
@@ -1404,6 +1408,7 @@ public class MachineLearning extends Plugin
             new ActionHandler<>(StartTrainedModelDeploymentAction.INSTANCE, TransportStartTrainedModelDeploymentAction.class),
             new ActionHandler<>(StopTrainedModelDeploymentAction.INSTANCE, TransportStopTrainedModelDeploymentAction.class),
             new ActionHandler<>(InferTrainedModelDeploymentAction.INSTANCE, TransportInferTrainedModelDeploymentAction.class),
+            new ActionHandler<>(UpdateTrainedModelDeploymentAction.INSTANCE, TransportUpdateTrainedModelDeploymentAction.class),
             new ActionHandler<>(GetDeploymentStatsAction.INSTANCE, TransportGetDeploymentStatsAction.class),
             new ActionHandler<>(GetDatafeedRunningStateAction.INSTANCE, TransportGetDatafeedRunningStateAction.class),
             new ActionHandler<>(CreateTrainedModelAssignmentAction.INSTANCE, TransportCreateTrainedModelAssignmentAction.class),

+ 100 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java

@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.action;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.master.TransportMasterNodeAction;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.block.ClusterBlockException;
+import org.elasticsearch.cluster.block.ClusterBlockLevel;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentClusterService;
+import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
+
+import java.util.Objects;
+
+import static org.elasticsearch.core.Strings.format;
+
+public class TransportUpdateTrainedModelDeploymentAction extends TransportMasterNodeAction<
+    UpdateTrainedModelDeploymentAction.Request,
+    CreateTrainedModelAssignmentAction.Response> {
+
+    private static final Logger logger = LogManager.getLogger(TransportUpdateTrainedModelDeploymentAction.class);
+
+    private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService;
+    private final InferenceAuditor auditor;
+
+    @Inject
+    public TransportUpdateTrainedModelDeploymentAction(
+        TransportService transportService,
+        ClusterService clusterService,
+        ThreadPool threadPool,
+        ActionFilters actionFilters,
+        IndexNameExpressionResolver indexNameExpressionResolver,
+        TrainedModelAssignmentClusterService trainedModelAssignmentClusterService,
+        InferenceAuditor auditor
+    ) {
+        super(
+            UpdateTrainedModelDeploymentAction.NAME,
+            transportService,
+            clusterService,
+            threadPool,
+            actionFilters,
+            UpdateTrainedModelDeploymentAction.Request::new,
+            indexNameExpressionResolver,
+            CreateTrainedModelAssignmentAction.Response::new,
+            ThreadPool.Names.SAME
+        );
+        this.trainedModelAssignmentClusterService = Objects.requireNonNull(trainedModelAssignmentClusterService);
+        this.auditor = Objects.requireNonNull(auditor);
+    }
+
+    @Override
+    protected void masterOperation(
+        Task task,
+        UpdateTrainedModelDeploymentAction.Request request,
+        ClusterState state,
+        ActionListener<CreateTrainedModelAssignmentAction.Response> listener
+    ) throws Exception {
+        logger.debug(
+            () -> format(
+                "[%s] received request to update number of allocations to [%s]",
+                request.getModelId(),
+                request.getNumberOfAllocations()
+            )
+        );
+
+        trainedModelAssignmentClusterService.updateNumberOfAllocations(
+            request.getModelId(),
+            request.getNumberOfAllocations(),
+            ActionListener.wrap(updatedAssignment -> {
+                auditor.info(
+                    request.getModelId(),
+                    Messages.getMessage(Messages.INFERENCE_DEPLOYMENT_UPDATED_NUMBER_OF_ALLOCATIONS, request.getNumberOfAllocations())
+                );
+                listener.onResponse(new CreateTrainedModelAssignmentAction.Response(updatedAssignment));
+            }, listener::onFailure)
+        );
+    }
+
+    @Override
+    protected ClusterBlockException checkBlock(UpdateTrainedModelDeploymentAction.Request request, ClusterState state) {
+        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
+    }
+}

+ 145 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java

@@ -44,6 +44,7 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignme
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
+import org.elasticsearch.xpack.ml.inference.assignment.planning.AllocationReducer;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
 import org.elasticsearch.xpack.ml.notifications.SystemAuditor;
@@ -555,6 +556,150 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
             || (smallestMLNode.isPresent() && smallestMLNode.getAsLong() < maxMLNodeSize);
     }
 
+    public void updateNumberOfAllocations(String modelId, int numberOfAllocations, ActionListener<TrainedModelAssignment> listener) {
+        updateNumberOfAllocations(clusterService.state(), modelId, numberOfAllocations, listener);
+    }
+
+    private void updateNumberOfAllocations(
+        ClusterState clusterState,
+        String modelId,
+        int numberOfAllocations,
+        ActionListener<TrainedModelAssignment> listener
+    ) {
+        TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterState);
+        final TrainedModelAssignment existingAssignment = metadata.getModelAssignment(modelId);
+        if (existingAssignment == null) {
+            throw new ResourceNotFoundException("deployment for model with id [{}] not found", modelId);
+        }
+        if (existingAssignment.getTaskParams().getNumberOfAllocations() == numberOfAllocations) {
+            listener.onResponse(existingAssignment);
+            return;
+        }
+        if (existingAssignment.getAssignmentState() != AssignmentState.STARTED) {
+            listener.onFailure(
+                new ElasticsearchStatusException(
+                    "cannot update deployment that is not in [{}] state",
+                    RestStatus.CONFLICT,
+                    AssignmentState.STARTED
+                )
+            );
+            return;
+        }
+        if (clusterState.nodes().getMinNodeVersion().before(DISTRIBUTED_MODEL_ALLOCATION_VERSION)) {
+            listener.onFailure(
+                new ElasticsearchStatusException(
+                    "cannot update number_of_allocations for deployment with model id [{}] while there are nodes older than version [{}]",
+                    RestStatus.CONFLICT,
+                    modelId,
+                    DISTRIBUTED_MODEL_ALLOCATION_VERSION
+                )
+            );
+            return;
+        }
+
+        ActionListener<ClusterState> updatedStateListener = ActionListener.wrap(
+            updatedState -> submitUnbatchedTask("update model deployment number_of_allocations", new ClusterStateUpdateTask() {
+
+                private volatile boolean isUpdated;
+
+                @Override
+                public ClusterState execute(ClusterState currentState) {
+                    if (areClusterStatesCompatibleForRebalance(clusterState, currentState)) {
+                        isUpdated = true;
+                        return updatedState;
+                    }
+                    logger.debug(() -> format("[%s] Retrying update as cluster state has been modified", modelId));
+                    updateNumberOfAllocations(currentState, modelId, numberOfAllocations, listener);
+                    return currentState;
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    listener.onFailure(e);
+                }
+
+                @Override
+                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
+                    if (isUpdated) {
+                        listener.onResponse(TrainedModelAssignmentMetadata.fromState(newState).getModelAssignment(modelId));
+                    }
+                }
+            }),
+            listener::onFailure
+        );
+
+        adjustNumberOfAllocations(clusterState, existingAssignment, numberOfAllocations, updatedStateListener);
+    }
+
+    private void adjustNumberOfAllocations(
+        ClusterState clusterState,
+        TrainedModelAssignment assignment,
+        int numberOfAllocations,
+        ActionListener<ClusterState> listener
+    ) {
+        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
+            if (numberOfAllocations > assignment.getTaskParams().getNumberOfAllocations()) {
+                increaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, listener);
+            } else {
+                decreaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, listener);
+            }
+        });
+    }
+
+    private void increaseNumberOfAllocations(
+        ClusterState clusterState,
+        TrainedModelAssignment assignment,
+        int numberOfAllocations,
+        ActionListener<ClusterState> listener
+    ) {
+        try {
+            final ClusterState updatedClusterState = update(
+                clusterState,
+                TrainedModelAssignmentMetadata.builder(clusterState)
+                    .updateAssignment(
+                        assignment.getModelId(),
+                        TrainedModelAssignment.Builder.fromAssignment(assignment).setNumberOfAllocations(numberOfAllocations)
+                    )
+            );
+            TrainedModelAssignmentMetadata.Builder rebalancedMetadata = rebalanceAssignments(updatedClusterState, Optional.empty());
+            if (isScalingPossible(getAssignableNodes(clusterState)) == false
+                && rebalancedMetadata.getAssignment(assignment.getModelId()).build().totalTargetAllocations() < numberOfAllocations) {
+                listener.onFailure(
+                    new ElasticsearchStatusException(
+                        "Could not update deployment because there are not enough resources to provide all requested allocations",
+                        RestStatus.TOO_MANY_REQUESTS
+                    )
+                );
+            } else {
+                listener.onResponse(update(clusterState, rebalancedMetadata));
+            }
+        } catch (Exception e) {
+            listener.onFailure(e);
+        }
+    }
+
+    private void decreaseNumberOfAllocations(
+        ClusterState clusterState,
+        TrainedModelAssignment assignment,
+        int numberOfAllocations,
+        ActionListener<ClusterState> listener
+    ) {
+        TrainedModelAssignment.Builder updatedAssignment = numberOfAllocations < assignment.totalTargetAllocations()
+            ? new AllocationReducer(assignment, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(clusterState)).reduceTo(
+                numberOfAllocations
+            )
+            : TrainedModelAssignment.Builder.fromAssignment(assignment).setNumberOfAllocations(numberOfAllocations);
+
+        // We have now reduced allocations to a number that we can be sure is satisfied,
+        // and thus we should clear the assignment reason.
+        if (numberOfAllocations <= assignment.totalTargetAllocations()) {
+            updatedAssignment.setReason(null);
+        }
+        TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState);
+        builder.updateAssignment(assignment.getModelId(), updatedAssignment);
+        listener.onResponse(update(clusterState, builder));
+    }
+
     static ClusterState setToStopping(ClusterState clusterState, String modelId, String reason) {
         TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterState);
         final TrainedModelAssignment existingAssignment = metadata.getModelAssignment(modelId);

+ 164 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducer.java

@@ -0,0 +1,164 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.assignment.planning;
+
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
+import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Reduces the number of allocations of a {@link TrainedModelAssignment}.
+ */
+public class AllocationReducer {
+
+    private static final Logger logger = LogManager.getLogger(AllocationReducer.class);
+
+    private final TrainedModelAssignment assignment;
+    private final Map<List<String>, Set<String>> nodeIdsByZone;
+
+    public AllocationReducer(TrainedModelAssignment assignment, Map<List<String>, Collection<DiscoveryNode>> nodesByZone) {
+        this.assignment = Objects.requireNonNull(assignment);
+        this.nodeIdsByZone = nodesByZone.entrySet()
+            .stream()
+            .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().stream().map(DiscoveryNode::getId).collect(Collectors.toSet())));
+    }
+
+    public TrainedModelAssignment.Builder reduceTo(int numberOfAllocations) {
+        final Map<String, Integer> allocationsByNode = assignment.getNodeRoutingTable()
+            .entrySet()
+            .stream()
+            .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations()));
+
+        final Map<List<String>, Integer> allocationsByZone = nodeIdsByZone.entrySet()
+            .stream()
+            .collect(
+                Collectors.toMap(
+                    Map.Entry::getKey,
+                    e -> e.getValue().stream().mapToInt(nodeId -> allocationsByNode.getOrDefault(nodeId, 0)).sum()
+                )
+            );
+
+        int totalRemainingAllocations = allocationsByZone.values().stream().mapToInt(Integer::intValue).sum();
+
+        if (totalRemainingAllocations <= numberOfAllocations) {
+            String msg = "request to reduce allocations is greater than or equal to the existing target number of allocations";
+            throw new IllegalArgumentException(msg);
+        }
+
+        while (totalRemainingAllocations > numberOfAllocations) {
+            // While we reduce allocations there are 2 concerns:
+            // 1. Remove entire assignments when possible to free the used memory
+            // 2. Preserve balance across zones
+            // The way we achieve this here is simple. Each iteration, we pick the zone with the most allocations
+            // and find its smallest assignment. We then check if we can remove that assignment entirely, or we
+            // reduce its allocations by 1.
+
+            final int allocationsToRemove = totalRemainingAllocations - numberOfAllocations;
+            List<Map.Entry<List<String>, Integer>> allocationsPerZoneInAscendingOrder = allocationsByZone.entrySet()
+                .stream()
+                .sorted(Map.Entry.comparingByValue())
+                .toList();
+            if (allocationsPerZoneInAscendingOrder.isEmpty()) {
+                logger.warn("no allocations remain in any zone");
+                throw new IllegalStateException("no allocations remain in any zone");
+            }
+            final List<String> largestZone = allocationsPerZoneInAscendingOrder.get(allocationsPerZoneInAscendingOrder.size() - 1).getKey();
+            final int largestZoneAllocations = allocationsPerZoneInAscendingOrder.get(allocationsPerZoneInAscendingOrder.size() - 1)
+                .getValue();
+            final int minAllocationsInOtherZones = allocationsPerZoneInAscendingOrder.size() <= 1
+                ? 0
+                : allocationsPerZoneInAscendingOrder.get(0).getValue();
+
+            List<Map.Entry<String, Integer>> largestZoneAssignmentsInAscendingOrder = allocationsByNode.entrySet()
+                .stream()
+                .filter(e -> nodeIdsByZone.get(largestZone).contains(e.getKey()))
+                .sorted(Map.Entry.comparingByValue())
+                .toList();
+
+            if (largestZoneAssignmentsInAscendingOrder.isEmpty()) {
+                logger.warn("no assignments remain in the largest zone");
+                throw new IllegalStateException("no assignments remain in the largest zone");
+            }
+
+            Map.Entry<String, Integer> smallestAssignmentInLargestZone = largestZoneAssignmentsInAscendingOrder.get(0);
+            if (canAssignmentBeRemovedEntirely(
+                smallestAssignmentInLargestZone,
+                minAllocationsInOtherZones,
+                largestZoneAllocations,
+                allocationsToRemove
+            )) {
+                allocationsByNode.remove(smallestAssignmentInLargestZone.getKey());
+                allocationsByZone.computeIfPresent(largestZone, (k, v) -> v - smallestAssignmentInLargestZone.getValue());
+                totalRemainingAllocations -= smallestAssignmentInLargestZone.getValue();
+            } else {
+                allocationsByNode.computeIfPresent(smallestAssignmentInLargestZone.getKey(), (k, v) -> v - 1);
+                allocationsByZone.computeIfPresent(largestZone, (k, v) -> v - 1);
+                totalRemainingAllocations -= 1;
+            }
+        }
+
+        return buildUpdatedAssignment(numberOfAllocations, allocationsByNode);
+    }
+
+    private boolean canAssignmentBeRemovedEntirely(
+        Map.Entry<String, Integer> assignment,
+        int minAllocationsInOtherZones,
+        int zoneAllocations,
+        int allocationsToRemove
+    ) {
+        // Assignment has a single allocation so we should be able to remove it entirely
+        if (assignment.getValue() == 1) {
+            return true;
+        }
+
+        // Assignment has more allocations than we still need to remove, so we cannot remove it entirely.
+        if (assignment.getValue() > allocationsToRemove) {
+            return false;
+        }
+
+        // No allocations in other zones means we do not have to consider preserving balance of allocations across zones
+        if (minAllocationsInOtherZones == 0) {
+            return true;
+        }
+        // If we remove the allocations of the assignment from the zone and we still have as many allocations
+        // as the smallest of the other zones we're still fairly balanced.
+        return zoneAllocations - assignment.getValue() >= minAllocationsInOtherZones;
+    }
+
+    private TrainedModelAssignment.Builder buildUpdatedAssignment(int numberOfAllocations, Map<String, Integer> allocationsByNode) {
+        TrainedModelAssignment.Builder reducedAssignmentBuilder = TrainedModelAssignment.Builder.fromAssignment(assignment);
+        reducedAssignmentBuilder.setNumberOfAllocations(numberOfAllocations);
+        for (Map.Entry<String, RoutingInfo> routingEntries : assignment.getNodeRoutingTable().entrySet()) {
+            final String nodeId = routingEntries.getKey();
+            if (allocationsByNode.containsKey(nodeId)) {
+                final RoutingInfo existingRoutingInfo = routingEntries.getValue();
+                reducedAssignmentBuilder.updateExistingRoutingEntry(
+                    nodeId,
+                    new RoutingInfo(
+                        existingRoutingInfo.getCurrentAllocations(),
+                        allocationsByNode.get(nodeId),
+                        existingRoutingInfo.getState(),
+                        existingRoutingInfo.getReason()
+                    )
+                );
+            } else {
+                reducedAssignmentBuilder.removeRoutingEntry(nodeId);
+            }
+        }
+        return reducedAssignmentBuilder;
+    }
+}
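
Review note: the greedy loop in `AllocationReducer.reduceTo` can be easier to follow on plain maps. The following is a minimal sketch, not the committed implementation. It assumes zones are plain strings instead of attribute lists, replaces `TrainedModelAssignment` and `RoutingInfo` with a node-id to target-allocations map, and uses `TreeMap` so that tie-breaking is deterministic. The class name `AllocationReductionSketch` is hypothetical.

```java
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

// Hypothetical, simplified stand-in for AllocationReducer (illustration only).
class AllocationReductionSketch {

    static Map<String, Integer> reduceTo(Map<String, Integer> allocationsByNode, Map<String, Set<String>> nodeIdsByZone, int target) {
        // Sorted copies: an assumption of this sketch to make tie-breaking deterministic.
        Map<String, Integer> byNode = new TreeMap<>(allocationsByNode);
        Map<String, Integer> byZone = new TreeMap<>();
        int remaining = 0;
        for (Map.Entry<String, Set<String>> zone : nodeIdsByZone.entrySet()) {
            int sum = 0;
            for (String nodeId : zone.getValue()) {
                sum += byNode.getOrDefault(nodeId, 0);
            }
            byZone.put(zone.getKey(), sum);
            remaining += sum;
        }
        if (remaining <= target) {
            throw new IllegalArgumentException("target must be below the current number of target allocations");
        }

        while (remaining > target) {
            int toRemove = remaining - target;

            // Pick the zone with the most allocations; also track the smallest zone total.
            String largestZone = null;
            int largestZoneAllocations = -1;
            int minZoneAllocations = Integer.MAX_VALUE;
            for (Map.Entry<String, Integer> zone : byZone.entrySet()) {
                if (zone.getValue() >= largestZoneAllocations) {
                    largestZone = zone.getKey();
                    largestZoneAllocations = zone.getValue();
                }
                minZoneAllocations = Math.min(minZoneAllocations, zone.getValue());
            }
            if (largestZone == null) {
                throw new IllegalStateException("no allocations remain in any zone");
            }
            int minInOtherZones = byZone.size() <= 1 ? 0 : minZoneAllocations;

            // Find the smallest assignment inside the largest zone.
            String smallestNode = null;
            int smallestAllocations = Integer.MAX_VALUE;
            for (Map.Entry<String, Integer> node : byNode.entrySet()) {
                if (nodeIdsByZone.get(largestZone).contains(node.getKey()) && node.getValue() < smallestAllocations) {
                    smallestNode = node.getKey();
                    smallestAllocations = node.getValue();
                }
            }
            if (smallestNode == null) {
                throw new IllegalStateException("no assignments remain in the largest zone");
            }

            // Same decision as canAssignmentBeRemovedEntirely in the diff above.
            boolean removeEntirely = smallestAllocations == 1
                || (smallestAllocations <= toRemove
                    && (minInOtherZones == 0 || largestZoneAllocations - smallestAllocations >= minInOtherZones));
            int removed = removeEntirely ? smallestAllocations : 1;
            if (removeEntirely) {
                byNode.remove(smallestNode);
            } else {
                byNode.put(smallestNode, smallestAllocations - 1);
            }
            byZone.put(largestZone, largestZoneAllocations - removed);
            remaining -= removed;
        }
        return byNode;
    }

    public static void main(String[] args) {
        // Mirrors testReduceTo_GivenOneZone_MultipleAssignments_RemovableAssignments.
        Map<String, Integer> reduced = reduceTo(Map.of("n_1", 3, "n_2", 2, "n_3", 1), Map.of("z", Set.of("n_1", "n_2", "n_3")), 3);
        System.out.println(reduced); // prints {n_1=3}
    }
}
```

With one zone, the two smaller assignments are dropped whole, which frees their memory. With two equally loaded zones (3 and 3, target 4), neither assignment can be dropped, so each is decremented once and the zones stay balanced.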

+ 55 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentAction.java

@@ -0,0 +1,55 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.rest.RestRequest.Method.POST;
+
+public class RestUpdateTrainedModelDeploymentAction extends BaseRestHandler {
+
+    @Override
+    public String getName() {
+        return "xpack_ml_update_trained_model_deployment_action";
+    }
+
+    @Override
+    public List<Route> routes() {
+        return Collections.singletonList(
+            new Route(
+                POST,
+                MachineLearning.BASE_PATH
+                    + "trained_models/{"
+                    + StartTrainedModelDeploymentAction.Request.MODEL_ID.getPreferredName()
+                    + "}/deployment/_update"
+            )
+        );
+    }
+
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+        String modelId = restRequest.param(StartTrainedModelDeploymentAction.Request.MODEL_ID.getPreferredName());
+        XContentParser parser = restRequest.contentParser();
+        UpdateTrainedModelDeploymentAction.Request request = UpdateTrainedModelDeploymentAction.Request.parseRequest(modelId, parser);
+        request.timeout(restRequest.paramAsTime("timeout", request.timeout()));
+        request.masterNodeTimeout(restRequest.paramAsTime("master_timeout", request.masterNodeTimeout()));
+
+        return channel -> client.execute(UpdateTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
+    }
+}
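
Review note: for anyone exercising the new route by hand, the request the handler expects can be sketched with the JDK HTTP client. This is an illustrative sketch only. The class name `UpdateDeploymentRequestSketch`, the base URL `http://localhost:9200`, and the omission of authentication are assumptions. The path and the `number_of_allocations` field come from the commit message, and it assumes `MachineLearning.BASE_PATH` resolves to `/_ml/`.

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Hypothetical helper (not part of this commit) that builds the update request.
class UpdateDeploymentRequestSketch {

    static HttpRequest build(String baseUrl, String modelId, int numberOfAllocations) {
        // Body shape taken from the commit message example.
        String body = "{\"number_of_allocations\": " + numberOfAllocations + "}";
        // Assumes modelId needs no URL encoding; encode it first if it may contain reserved characters.
        URI uri = URI.create(baseUrl + "/_ml/trained_models/" + modelId + "/deployment/_update");
        return HttpRequest.newBuilder(uri)
            .header("Content-Type", "application/json")
            .POST(HttpRequest.BodyPublishers.ofString(body))
            .build();
    }

    public static void main(String[] args) {
        HttpRequest request = build("http://localhost:9200", "test_model", 4);
        System.out.println(request.method() + " " + request.uri());
    }
}
```

The handler also reads optional `timeout` and `master_timeout` query parameters, which could be appended to the URI in the same way.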

+ 195 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java

@@ -0,0 +1,195 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.assignment.planning;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
+import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
+import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasKey;
+
+public class AllocationReducerTests extends ESTestCase {
+
+    public void testReduceTo_ValueEqualToCurrentAllocations() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(List.of("z"), List.of(buildNode("n")));
+        TrainedModelAssignment assignment = createAssignment("m", 2, Map.of("n", 2));
+        expectThrows(IllegalArgumentException.class, () -> new AllocationReducer(assignment, nodesByZone).reduceTo(2));
+    }
+
+    public void testReduceTo_ValueLargerThanCurrentAllocations() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(List.of("z"), List.of(buildNode("n")));
+        TrainedModelAssignment assignment = createAssignment("m", 2, Map.of("n", 2));
+        expectThrows(IllegalArgumentException.class, () -> new AllocationReducer(assignment, nodesByZone).reduceTo(3));
+    }
+
+    public void testReduceTo_GivenOneZone_OneAssignment_ReductionByOne() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(List.of("z"), List.of(buildNode("n")));
+        TrainedModelAssignment assignment = createAssignment("m", 2, Map.of("n", 2));
+
+        TrainedModelAssignment updatedAssignment = new AllocationReducer(assignment, nodesByZone).reduceTo(1).build();
+
+        assertThat(updatedAssignment.getTaskParams().getNumberOfAllocations(), equalTo(1));
+
+        Map<String, RoutingInfo> routingTable = updatedAssignment.getNodeRoutingTable();
+        assertThat(routingTable, aMapWithSize(1));
+        assertThat(routingTable, hasKey("n"));
+        assertThat(routingTable.get("n").getTargetAllocations(), equalTo(1));
+    }
+
+    public void testReduceTo_GivenOneZone_OneAssignment_ReductionByMany() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(List.of("z"), List.of(buildNode("n")));
+        TrainedModelAssignment assignment = createAssignment("m", 5, Map.of("n", 5));
+
+        TrainedModelAssignment updatedAssignment = new AllocationReducer(assignment, nodesByZone).reduceTo(2).build();
+
+        assertThat(updatedAssignment.getTaskParams().getNumberOfAllocations(), equalTo(2));
+
+        Map<String, RoutingInfo> routingTable = updatedAssignment.getNodeRoutingTable();
+        assertThat(routingTable, aMapWithSize(1));
+        assertThat(routingTable, hasKey("n"));
+        assertThat(routingTable.get("n").getTargetAllocations(), equalTo(2));
+    }
+
+    public void testReduceTo_GivenOneZone_MultipleAssignments_RemovableAssignments() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(
+            List.of("z"),
+            List.of(buildNode("n_1"), buildNode("n_2"), buildNode("n_3"))
+        );
+        TrainedModelAssignment assignment = createAssignment("m", 6, Map.of("n_1", 3, "n_2", 2, "n_3", 1));
+
+        TrainedModelAssignment updatedAssignment = new AllocationReducer(assignment, nodesByZone).reduceTo(3).build();
+
+        assertThat(updatedAssignment.getTaskParams().getNumberOfAllocations(), equalTo(3));
+
+        Map<String, RoutingInfo> routingTable = updatedAssignment.getNodeRoutingTable();
+        assertThat(routingTable, aMapWithSize(1));
+        assertThat(routingTable, hasKey("n_1"));
+        assertThat(routingTable.get("n_1").getTargetAllocations(), equalTo(3));
+    }
+
+    public void testReduceTo_GivenOneZone_MultipleAssignments_NonRemovableAssignments() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(
+            List.of("z"),
+            List.of(buildNode("n_1"), buildNode("n_2"), buildNode("n_3"))
+        );
+        TrainedModelAssignment assignment = createAssignment("m", 6, Map.of("n_1", 2, "n_2", 2, "n_3", 2));
+
+        TrainedModelAssignment updatedAssignment = new AllocationReducer(assignment, nodesByZone).reduceTo(5).build();
+
+        assertThat(updatedAssignment.getTaskParams().getNumberOfAllocations(), equalTo(5));
+        assertThat(updatedAssignment.totalTargetAllocations(), equalTo(5));
+
+        Map<String, RoutingInfo> routingTable = updatedAssignment.getNodeRoutingTable();
+        assertThat(routingTable, aMapWithSize(3));
+        assertThat(routingTable, hasKey("n_1"));
+        assertThat(routingTable, hasKey("n_2"));
+        assertThat(routingTable, hasKey("n_3"));
+    }
+
+    public void testReduceTo_GivenTwoZones_RemovableAssignments() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(
+            List.of("z_1"),
+            List.of(buildNode("n_1"), buildNode("n_2")),
+            List.of("z_2"),
+            List.of(buildNode("n_3"))
+        );
+        TrainedModelAssignment assignment = createAssignment("m", 5, Map.of("n_1", 3, "n_2", 1, "n_3", 1));
+
+        TrainedModelAssignment updatedAssignment = new AllocationReducer(assignment, nodesByZone).reduceTo(4).build();
+
+        assertThat(updatedAssignment.getTaskParams().getNumberOfAllocations(), equalTo(4));
+
+        Map<String, RoutingInfo> routingTable = updatedAssignment.getNodeRoutingTable();
+        assertThat(routingTable, aMapWithSize(2));
+        assertThat(routingTable, hasKey("n_1"));
+        assertThat(routingTable.get("n_1").getTargetAllocations(), equalTo(3));
+        assertThat(routingTable, hasKey("n_3"));
+        assertThat(routingTable.get("n_3").getTargetAllocations(), equalTo(1));
+    }
+
+    public void testReduceTo_GivenTwoZones_NonRemovableAssignments() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(
+            List.of("z_1"),
+            List.of(buildNode("n_1")),
+            List.of("z_2"),
+            List.of(buildNode("n_2"))
+        );
+        TrainedModelAssignment assignment = createAssignment("m", 6, Map.of("n_1", 3, "n_2", 3));
+
+        TrainedModelAssignment updatedAssignment = new AllocationReducer(assignment, nodesByZone).reduceTo(4).build();
+
+        assertThat(updatedAssignment.getTaskParams().getNumberOfAllocations(), equalTo(4));
+
+        Map<String, RoutingInfo> routingTable = updatedAssignment.getNodeRoutingTable();
+        assertThat(routingTable, aMapWithSize(2));
+        assertThat(routingTable, hasKey("n_1"));
+        assertThat(routingTable.get("n_1").getTargetAllocations(), equalTo(2));
+        assertThat(routingTable, hasKey("n_2"));
+        assertThat(routingTable.get("n_2").getTargetAllocations(), equalTo(2));
+    }
+
+    public void testReduceTo_GivenTwoZones_WithSameAssignmentsOfOneAllocationEach() {
+        Map<List<String>, Collection<DiscoveryNode>> nodesByZone = Map.of(
+            List.of("z_1"),
+            List.of(buildNode("n_1")),
+            List.of("z_2"),
+            List.of(buildNode("n_2"))
+        );
+        TrainedModelAssignment assignment = createAssignment("m", 2, Map.of("n_1", 1, "n_2", 1));
+
+        TrainedModelAssignment updatedAssignment = new AllocationReducer(assignment, nodesByZone).reduceTo(1).build();
+
+        assertThat(updatedAssignment.getTaskParams().getNumberOfAllocations(), equalTo(1));
+
+        Map<String, RoutingInfo> routingTable = updatedAssignment.getNodeRoutingTable();
+        assertThat(routingTable, aMapWithSize(1));
+        assertThat(routingTable, hasKey("n_1"));
+        assertThat(routingTable.get(routingTable.keySet().iterator().next()).getTargetAllocations(), equalTo(1));
+    }
+
+    private static TrainedModelAssignment createAssignment(
+        String modelId,
+        int numberOfAllocations,
+        Map<String, Integer> allocationsByNode
+    ) {
+        TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(
+            new StartTrainedModelDeploymentAction.TaskParams(
+                modelId,
+                randomNonNegativeLong(),
+                randomIntBetween(1, 16),
+                numberOfAllocations,
+                1024,
+                null
+            )
+        );
+        allocationsByNode.entrySet()
+            .stream()
+            .forEach(
+                e -> builder.addRoutingEntry(
+                    e.getKey(),
+                    new RoutingInfo(randomIntBetween(1, e.getValue()), e.getValue(), RoutingState.STARTED, "")
+                )
+            );
+        return builder.build();
+    }
+
+    private static DiscoveryNode buildNode(String nodeId) {
+        return new DiscoveryNode(nodeId, nodeId, buildNewFakeTransportAddress(), Map.of(), DiscoveryNodeRole.roles(), Version.CURRENT);
+    }
+}

+ 1 - 0
x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java

@@ -163,6 +163,7 @@ public class Constants {
         "cluster:admin/xpack/ml/trained_models/deployment/clear_cache",
         "cluster:admin/xpack/ml/trained_models/deployment/start",
         "cluster:admin/xpack/ml/trained_models/deployment/stop",
+        "cluster:admin/xpack/ml/trained_models/deployment/update",
         "cluster:admin/xpack/ml/trained_models/part/put",
         "cluster:admin/xpack/ml/trained_models/vocabulary/put",
         "cluster:admin/xpack/ml/upgrade_mode",

+ 25 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

@@ -135,6 +135,31 @@ setup:
         model_id: test_model
   - match: { stopped: true }
 ---
+"Test update deployment":
+  - do:
+      ml.start_trained_model_deployment:
+        model_id: test_model
+        wait_for: started
+  - match: { assignment.assignment_state: started }
+  - match: { assignment.task_parameters.model_id: test_model }
+  - match: { assignment.task_parameters.number_of_allocations: 1 }
+
+  - do:
+      # We update to the same value of 1 because the test would otherwise fail when run on a node with just 1 processor
+      ml.update_trained_model_deployment:
+        model_id: test_model
+        body: >
+          {
+            "number_of_allocations": 1
+          }
+  - match: { assignment.task_parameters.model_id: test_model }
+  - match: { assignment.task_parameters.number_of_allocations: 1 }
+
+  - do:
+      ml.stop_trained_model_deployment:
+        model_id: test_model
+  - match: { stopped: true }
+---
 "Test clear deployment cache":
   - skip:
       features: allowed_warnings

+ 10 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/update_trained_model_deployment.yml

@@ -0,0 +1,10 @@
+---
+"Test with unknown model id":
+  - do:
+      catch: missing
+      ml.update_trained_model_deployment:
+        model_id: "missing-model"
+        body: >
+          {
+            "number_of_allocations": 4
+          }