
[ML] Rename threading params in _start trained model deployment API (#86597)

When starting a trained model deployment the user can tweak performance
by setting the `model_threads` and `inference_threads` parameters.
These parameters are hard to understand and cause confusion.

This commit renames these as well as the fields where their values are
reported in the stats API.

- `model_threads` => `number_of_allocations`
- `inference_threads` => `threads_per_allocation`

Now the terminology is as follows.

A model deployment starts with a requested `number_of_allocations`.
Each allocation means the model gets another thread for executing
parallel inference requests. Thus, more allocations should increase
throughput. In turn, each allocation may use a number
of threads to parallelize each individual inference request.
This is the `threads_per_allocation` setting and increases inference
speed (which might also result in improved throughput).
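
As an illustration of the new terminology, here is a minimal sketch using the renamed `Request` setters from the Java changes below (the single-argument constructor appears in the request tests; the model id is a placeholder):

    import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request;

    class StartDeploymentExample {
        // Sketch only: two allocations, each parallelizing an inference across four threads.
        static Request exampleRequest() {
            Request request = new Request("my-model");
            request.setNumberOfAllocations(2);   // more allocations -> more parallel requests, higher throughput
            request.setThreadsPerAllocation(4);  // more threads per allocation -> faster individual inference
            return request;
        }
    }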
Dimitris Athanasiou · 3 years ago
commit 68c51f3ada
15 changed files with 232 additions and 177 deletions
  1. +21 -20 docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc
  2. +14 -12 docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc
  3. +74 -50 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
  4. +59 -49 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java
  5. +6 -6 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java
  6. +16 -16 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java
  7. +1 -1 x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  8. +6 -6 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java
  9. +2 -2 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java
  10. +2 -2 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java
  11. +2 -2 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java
  12. +7 -7 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java
  13. +20 -4 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java
  14. +1 -0 x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java
  15. +1 -0 x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java

+ 21 - 20
docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc

@@ -120,18 +120,10 @@ The sum of `error_count` for all nodes in the deployment.
 (integer)
 The sum of `inference_count` for all nodes in the deployment.
 
-`inference_threads`:::
-(integer)
-The number of threads used by the inference process.
-
 `model_id`:::
 (string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
-`model_threads`:::
-(integer)
-The number of threads used when sending inference requests to the model.
-
 `nodes`:::
 (array of objects)
 The deployment stats for each node that currently has the model allocated.
@@ -157,22 +149,10 @@ The number of errors when evaluating the trained model.
 (integer)
 The total number of inference calls made against this node for this model.
 
-`inference_threads`:::
-(integer)
-The number of threads used by the inference process.
-This value is limited by the number of hardware threads on the node;
-it might therefore differ from the `inference_threads` value in the <<start-trained-model-deployment>> API.
-
 `last_access`:::
 (long)
 The epoch time stamp of the last inference call for the model on this node.
 
-`model_threads`:::
-(integer)
-The number of threads used when sending inference requests to the model.
-This value is limited by the number of hardware threads on the node;
-it might therefore differ from the `model_threads` value in the <<start-trained-model-deployment>> API.
-
 `node`:::
 (object)
 Information pertaining to the node.
@@ -200,6 +180,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-id]
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-transport-address]
 ========
 
+`number_of_allocations`:::
+(integer)
+The number of allocations assigned to this node.
+This value is limited by the number of hardware threads on the node;
+it might therefore differ from the `number_of_allocations` value in the <<start-trained-model-deployment>> API.
+
 `number_of_pending_requests`:::
 (integer)
 The number of inference requests queued to be processed.
@@ -240,6 +226,12 @@ queue was full.
 (long)
 The epoch timestamp when the allocation started.
 
+`threads_per_allocation`:::
+(integer)
+The number of threads for each allocation during inference.
+This value is limited by the number of hardware threads on the node;
+it might therefore differ from the `threads_per_allocation` value in the <<start-trained-model-deployment>> API.
+
 `timeout_count`:::
 (integer)
 The number of inference requests that timed out before being processed.
@@ -248,6 +240,11 @@ The number of inference requests that timed out before being processed.
 (integer)
 The number of requests processed in the last 1 minute.
 ======
+
+`number_of_allocations`:::
+(integer)
+The requested number of allocations for the trained model deployment.
+
 `peak_throughput_per_minute`:::
 (integer)
 The peak number of requests processed in a 1 minute period for
@@ -280,6 +277,10 @@ The overall state of the deployment. The values may be:
 * `stopping`: The deployment is preparing to stop and deallocate the model from the relevant nodes.
 --
 
+`threads_per_allocation`:::
+(integer)
+The number of threads per allocation used by the inference process.
+
 `timeout_count`:::
 (integer)
 The sum of `timeout_count` for all nodes in the deployment.
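
The renamed stats fields surface through the `AssignmentStats` getters changed further down. A hedged sketch of reading them from client-side Java (the `getNodeStats()` accessor is an assumption; it is not shown in this diff):

    import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;

    class ThreadingStatsExample {
        static void printThreading(AssignmentStats stats) {
            // Deployment-wide values as requested at start time (may be null on mixed-version wires).
            System.out.println("number_of_allocations: " + stats.getNumberOfAllocations());
            System.out.println("threads_per_allocation: " + stats.getThreadsPerAllocation());
            // Per-node values can be lower than requested on nodes with fewer hardware threads.
            for (AssignmentStats.NodeStats nodeStats : stats.getNodeStats()) { // getNodeStats() assumed
                System.out.println("  node allocations: " + nodeStats.getNumberOfAllocations());
                System.out.println("  node threads/allocation: " + nodeStats.getThreadsPerAllocation());
            }
        }
    }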

+ 14 - 12
docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc

@@ -38,18 +38,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 [[start-trained-model-deployment-query-params]]
 == {api-query-parms-title}
 
-`inference_threads`::
+`number_of_allocations`::
 (Optional, integer)
-Sets the number of threads used by the inference process. This generally increases
-the inference speed. The inference process is a compute-bound process; any number
-greater than the number of available hardware threads on the machine does not increase the
-inference speed. If this setting is greater than the number of hardware threads
-it will automatically be changed to a value less than the number of hardware threads.
-Defaults to 1.
-
-`model_threads`::
-(Optional, integer)
-The number of threads used when sending inference requests to the model.
+The number of model allocations on each node where the model is deployed.
+All allocations on a node share the same copy of the model in memory but use
+a separate set of threads to evaluate the model. 
 Increasing this value generally increases the throughput.
 If this setting is greater than the number of hardware threads
 it will automatically be changed to a value less than the number of hardware threads.
@@ -57,7 +50,7 @@ Defaults to 1.
 
 [NOTE]
 =============================================
-If the sum of `inference_threads` and `model_threads` is greater than the number of
-hardware threads then the number of `inference_threads` will be reduced.
+If the sum of `threads_per_allocation` and `number_of_allocations` is greater than the number of
+hardware threads then the number of `threads_per_allocation` will be reduced.
 =============================================
 
@@ -68,6 +61,15 @@ Every machine learning node in the cluster where the model can be allocated
 has a queue of this size; when the number of requests exceeds the total value,
 new requests are rejected with a 429 error. Defaults to 1024.
 
+`threads_per_allocation`::
+(Optional, integer)
+Sets the number of threads used by each model allocation during inference. This generally increases
+the inference speed. The inference process is a compute-bound process; any number
+greater than the number of available hardware threads on the machine does not increase the
+inference speed. If this setting is greater than the number of hardware threads
+it will automatically be changed to a value less than the number of hardware threads.
+Defaults to 1.
+
 `timeout`::
 (Optional, time)
 Controls the amount of time to wait for the model to deploy. Defaults
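
For comparison, a sketch of starting a deployment over REST with the new query parameters, modeled on the `PyTorchModelIT` change further down (model id, parameter values, and client wiring are placeholders):

    import java.io.IOException;

    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.Response;
    import org.elasticsearch.client.RestClient;

    class StartDeploymentRestExample {
        static Response startDeployment(RestClient client) throws IOException {
            // The old inference_threads/model_threads names still work but draw deprecation warnings.
            Request request = new Request(
                "POST",
                "/_ml/trained_models/my-model/deployment/_start"
                    + "?number_of_allocations=2&threads_per_allocation=4&queue_capacity=1024"
            );
            return client.performRequest(request);
        }
    }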

+ 74 - 50
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -69,8 +69,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         public static final ParseField MODEL_ID = new ParseField("model_id");
         public static final ParseField TIMEOUT = new ParseField("timeout");
         public static final ParseField WAIT_FOR = new ParseField("wait_for");
-        public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
-        public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
+        public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
+        public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
         public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
 
         public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
@@ -79,8 +79,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             PARSER.declareString(Request::setModelId, MODEL_ID);
             PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT);
             PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
-            PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
-            PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
+            PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
+            PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
             PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
         }
 
@@ -99,8 +99,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         private String modelId;
         private TimeValue timeout = DEFAULT_TIMEOUT;
         private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
-        private int modelThreads = 1;
-        private int inferenceThreads = 1;
+        private int numberOfAllocations = 1;
+        private int threadsPerAllocation = 1;
         private int queueCapacity = 1024;
 
         private Request() {}
@@ -114,8 +114,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             modelId = in.readString();
             timeout = in.readTimeValue();
             waitForState = in.readEnum(AllocationStatus.State.class);
-            modelThreads = in.readVInt();
-            inferenceThreads = in.readVInt();
+            numberOfAllocations = in.readVInt();
+            threadsPerAllocation = in.readVInt();
             queueCapacity = in.readVInt();
         }
 
@@ -144,20 +144,20 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return this;
         }
 
-        public int getModelThreads() {
-            return modelThreads;
+        public int getNumberOfAllocations() {
+            return numberOfAllocations;
         }
 
-        public void setModelThreads(int modelThreads) {
-            this.modelThreads = modelThreads;
+        public void setNumberOfAllocations(int numberOfAllocations) {
+            this.numberOfAllocations = numberOfAllocations;
         }
 
-        public int getInferenceThreads() {
-            return inferenceThreads;
+        public int getThreadsPerAllocation() {
+            return threadsPerAllocation;
         }
 
-        public void setInferenceThreads(int inferenceThreads) {
-            this.inferenceThreads = inferenceThreads;
+        public void setThreadsPerAllocation(int threadsPerAllocation) {
+            this.threadsPerAllocation = threadsPerAllocation;
         }
 
         public int getQueueCapacity() {
@@ -174,8 +174,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             out.writeString(modelId);
             out.writeTimeValue(timeout);
             out.writeEnum(waitForState);
-            out.writeVInt(modelThreads);
-            out.writeVInt(inferenceThreads);
+            out.writeVInt(numberOfAllocations);
+            out.writeVInt(threadsPerAllocation);
             out.writeVInt(queueCapacity);
         }
 
@@ -185,8 +185,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             builder.field(MODEL_ID.getPreferredName(), modelId);
             builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep());
             builder.field(WAIT_FOR.getPreferredName(), waitForState);
-            builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
-            builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
+            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
+            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
             builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
             builder.endObject();
             return builder;
@@ -203,11 +203,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
                         + Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES)
                 );
             }
-            if (modelThreads < 1) {
-                validationException.addValidationError("[" + MODEL_THREADS + "] must be a positive integer");
+            if (numberOfAllocations < 1) {
+                validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
             }
-            if (inferenceThreads < 1) {
-                validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
+            if (threadsPerAllocation < 1) {
+                validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer");
             }
             if (queueCapacity < 1) {
                 validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
@@ -217,7 +217,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 
         @Override
         public int hashCode() {
-            return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
+            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
         }
 
         @Override
@@ -232,8 +232,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return Objects.equals(modelId, other.modelId)
                 && Objects.equals(timeout, other.timeout)
                 && Objects.equals(waitForState, other.waitForState)
-                && modelThreads == other.modelThreads
-                && inferenceThreads == other.inferenceThreads
+                && numberOfAllocations == other.numberOfAllocations
+                && threadsPerAllocation == other.threadsPerAllocation
                 && queueCapacity == other.queueCapacity;
         }
 
@@ -254,22 +254,28 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 
         public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
         private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
-        public static final ParseField MODEL_THREADS = new ParseField("model_threads");
-        public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
+        public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations");
+        public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation");
+        // number_of_allocations was previously named model_threads
+        private static final ParseField LEGACY_MODEL_THREADS = new ParseField("model_threads");
+        // threads_per_allocation was previously named inference_threads
+        public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
         public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
 
         private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
             "trained_model_deployment_params",
             true,
-            a -> new TaskParams((String) a[0], (Long) a[1], (int) a[2], (int) a[3], (int) a[4])
+            a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
         );
 
         static {
             PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
             PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
-            PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
-            PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
+            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
+            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
             PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
+            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
+            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
         }
 
         public static TaskParams fromXContent(XContentParser parser) {
@@ -279,24 +285,42 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         private final String modelId;
         private final long modelBytes;
         // How many threads are used by the model during inference. Used to increase inference speed.
-        private final int inferenceThreads;
+        private final int threadsPerAllocation;
         // How many threads are used when forwarding the request to the model. Used to increase throughput.
-        private final int modelThreads;
+        private final int numberOfAllocations;
         private final int queueCapacity;
 
-        public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
+        private TaskParams(
+            String modelId,
+            long modelBytes,
+            Integer numberOfAllocations,
+            Integer threadsPerAllocation,
+            int queueCapacity,
+            Integer legacyModelThreads,
+            Integer legacyInferenceThreads
+        ) {
+            this(
+                modelId,
+                modelBytes,
+                threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
+                numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
+                queueCapacity
+            );
+        }
+
+        public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
             this.modelId = Objects.requireNonNull(modelId);
             this.modelBytes = modelBytes;
-            this.inferenceThreads = inferenceThreads;
-            this.modelThreads = modelThreads;
+            this.threadsPerAllocation = threadsPerAllocation;
+            this.numberOfAllocations = numberOfAllocations;
             this.queueCapacity = queueCapacity;
         }
 
         public TaskParams(StreamInput in) throws IOException {
             this.modelId = in.readString();
             this.modelBytes = in.readLong();
-            this.inferenceThreads = in.readVInt();
-            this.modelThreads = in.readVInt();
+            this.threadsPerAllocation = in.readVInt();
+            this.numberOfAllocations = in.readVInt();
             this.queueCapacity = in.readVInt();
         }
 
@@ -316,8 +340,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         public void writeTo(StreamOutput out) throws IOException {
             out.writeString(modelId);
             out.writeLong(modelBytes);
-            out.writeVInt(inferenceThreads);
-            out.writeVInt(modelThreads);
+            out.writeVInt(threadsPerAllocation);
+            out.writeVInt(numberOfAllocations);
             out.writeVInt(queueCapacity);
         }
 
@@ -326,8 +350,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             builder.startObject();
             builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
             builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
-            builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
-            builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
+            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
+            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
             builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
             builder.endObject();
             return builder;
@@ -335,7 +359,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 
         @Override
         public int hashCode() {
-            return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
+            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
         }
 
         @Override
@@ -346,8 +370,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             TaskParams other = (TaskParams) o;
             return Objects.equals(modelId, other.modelId)
                 && modelBytes == other.modelBytes
-                && inferenceThreads == other.inferenceThreads
-                && modelThreads == other.modelThreads
+                && threadsPerAllocation == other.threadsPerAllocation
+                && numberOfAllocations == other.numberOfAllocations
                 && queueCapacity == other.queueCapacity;
         }
 
@@ -360,12 +384,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return modelBytes;
         }
 
-        public int getInferenceThreads() {
-            return inferenceThreads;
+        public int getThreadsPerAllocation() {
+            return threadsPerAllocation;
         }
 
-        public int getModelThreads() {
-            return modelThreads;
+        public int getNumberOfAllocations() {
+            return numberOfAllocations;
         }
 
         public int getQueueCapacity() {
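
The delegating `TaskParams` constructor above resolves each new field against its legacy counterpart. A standalone sketch of that fallback (simplified; in the real code the values arrive through `ConstructingObjectParser` positions):

    class LegacyParamFallback {
        // Mirrors the delegating constructor: prefer the new field, else the legacy one.
        // Like the original, this throws an NPE on unboxing if both values are absent.
        static int resolve(Integer preferred, Integer legacy) {
            return preferred == null ? legacy : preferred;
        }

        public static void main(String[] args) {
            System.out.println(resolve(null, 4)); // legacy document {"inference_threads": 4} -> 4
            System.out.println(resolve(2, 8));    // the new field wins when both are present -> 2
        }
    }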

+ 59 - 49
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java

@@ -37,8 +37,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         private final int timeoutCount;
         private final RoutingStateAndReason routingState;
         private final Instant startTime;
-        private final Integer inferenceThreads;
-        private final Integer modelThreads;
+        private final Integer threadsPerAllocation;
+        private final Integer numberOfAllocations;
         private final long peakThroughput;
         private final long throughputLastPeriod;
         private final Double avgInferenceTimeLastPeriod;
@@ -53,8 +53,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             int timeoutCount,
             Instant lastAccess,
             Instant startTime,
-            Integer inferenceThreads,
-            Integer modelThreads,
+            Integer threadsPerAllocation,
+            Integer numberOfAllocations,
             long peakThroughput,
             long throughputLastPeriod,
             Double avgInferenceTimeLastPeriod
@@ -70,8 +70,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 timeoutCount,
                 new RoutingStateAndReason(RoutingState.STARTED, null),
                 Objects.requireNonNull(startTime),
-                inferenceThreads,
-                modelThreads,
+                threadsPerAllocation,
+                numberOfAllocations,
                 peakThroughput,
                 throughputLastPeriod,
                 avgInferenceTimeLastPeriod
@@ -109,8 +109,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             int timeoutCount,
             RoutingStateAndReason routingState,
             @Nullable Instant startTime,
-            @Nullable Integer inferenceThreads,
-            @Nullable Integer modelThreads,
+            @Nullable Integer threadsPerAllocation,
+            @Nullable Integer numberOfAllocations,
             long peakThroughput,
             long throughputLastPeriod,
             Double avgInferenceTimeLastPeriod
@@ -125,8 +125,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             this.timeoutCount = timeoutCount;
             this.routingState = routingState;
             this.startTime = startTime;
-            this.inferenceThreads = inferenceThreads;
-            this.modelThreads = modelThreads;
+            this.threadsPerAllocation = threadsPerAllocation;
+            this.numberOfAllocations = numberOfAllocations;
             this.peakThroughput = peakThroughput;
             this.throughputLastPeriod = throughputLastPeriod;
             this.avgInferenceTimeLastPeriod = avgInferenceTimeLastPeriod;
@@ -144,14 +144,14 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
             this.startTime = in.readOptionalInstant();
             if (in.getVersion().onOrAfter(Version.V_8_1_0)) {
-                this.inferenceThreads = in.readOptionalVInt();
-                this.modelThreads = in.readOptionalVInt();
+                this.threadsPerAllocation = in.readOptionalVInt();
+                this.numberOfAllocations = in.readOptionalVInt();
                 this.errorCount = in.readVInt();
                 this.rejectedExecutionCount = in.readVInt();
                 this.timeoutCount = in.readVInt();
             } else {
-                this.inferenceThreads = null;
-                this.modelThreads = null;
+                this.threadsPerAllocation = null;
+                this.numberOfAllocations = null;
                 this.errorCount = 0;
                 this.rejectedExecutionCount = 0;
                 this.timeoutCount = 0;
@@ -207,12 +207,12 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             return startTime;
         }
 
-        public Integer getInferenceThreads() {
-            return inferenceThreads;
+        public Integer getThreadsPerAllocation() {
+            return threadsPerAllocation;
         }
 
-        public Integer getModelThreads() {
-            return modelThreads;
+        public Integer getNumberOfAllocations() {
+            return numberOfAllocations;
         }
 
         public long getPeakThroughput() {
@@ -261,11 +261,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             if (startTime != null) {
                 builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
             }
-            if (inferenceThreads != null) {
-                builder.field("inference_threads", inferenceThreads);
+            if (threadsPerAllocation != null) {
+                builder.field("threads_per_allocation", threadsPerAllocation);
             }
-            if (modelThreads != null) {
-                builder.field("model_threads", modelThreads);
+            if (numberOfAllocations != null) {
+                builder.field("number_of_allocations", numberOfAllocations);
             }
             builder.field("peak_throughput_per_minute", peakThroughput);
             builder.field("throughput_last_minute", throughputLastPeriod);
@@ -287,8 +287,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             out.writeOptionalWriteable(routingState);
             out.writeOptionalInstant(startTime);
             if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
-                out.writeOptionalVInt(inferenceThreads);
-                out.writeOptionalVInt(modelThreads);
+                out.writeOptionalVInt(threadsPerAllocation);
+                out.writeOptionalVInt(numberOfAllocations);
                 out.writeVInt(errorCount);
                 out.writeVInt(rejectedExecutionCount);
                 out.writeVInt(timeoutCount);
@@ -315,8 +315,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 && Objects.equals(timeoutCount, that.timeoutCount)
                 && Objects.equals(routingState, that.routingState)
                 && Objects.equals(startTime, that.startTime)
-                && Objects.equals(inferenceThreads, that.inferenceThreads)
-                && Objects.equals(modelThreads, that.modelThreads)
+                && Objects.equals(threadsPerAllocation, that.threadsPerAllocation)
+                && Objects.equals(numberOfAllocations, that.numberOfAllocations)
                 && Objects.equals(peakThroughput, that.peakThroughput)
                 && Objects.equals(throughputLastPeriod, that.throughputLastPeriod)
                 && Objects.equals(avgInferenceTimeLastPeriod, that.avgInferenceTimeLastPeriod);
@@ -335,8 +335,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 timeoutCount,
                 routingState,
                 startTime,
-                inferenceThreads,
-                modelThreads,
+                threadsPerAllocation,
+                numberOfAllocations,
                 peakThroughput,
                 throughputLastPeriod,
                 avgInferenceTimeLastPeriod
@@ -349,9 +349,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
     private AllocationStatus allocationStatus;
     private String reason;
     @Nullable
-    private final Integer inferenceThreads;
+    private final Integer threadsPerAllocation;
     @Nullable
-    private final Integer modelThreads;
+    private final Integer numberOfAllocations;
     @Nullable
     private final Integer queueCapacity;
     private final Instant startTime;
@@ -359,15 +359,15 @@ public class AssignmentStats implements ToXContentObject, Writeable {
 
     public AssignmentStats(
         String modelId,
-        @Nullable Integer inferenceThreads,
-        @Nullable Integer modelThreads,
+        @Nullable Integer threadsPerAllocation,
+        @Nullable Integer numberOfAllocations,
         @Nullable Integer queueCapacity,
         Instant startTime,
         List<AssignmentStats.NodeStats> nodeStats
     ) {
         this.modelId = modelId;
-        this.inferenceThreads = inferenceThreads;
-        this.modelThreads = modelThreads;
+        this.threadsPerAllocation = threadsPerAllocation;
+        this.numberOfAllocations = numberOfAllocations;
         this.queueCapacity = queueCapacity;
         this.startTime = Objects.requireNonNull(startTime);
         this.nodeStats = nodeStats;
@@ -377,8 +377,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
 
     public AssignmentStats(StreamInput in) throws IOException {
         modelId = in.readString();
-        inferenceThreads = in.readOptionalVInt();
-        modelThreads = in.readOptionalVInt();
+        threadsPerAllocation = in.readOptionalVInt();
+        numberOfAllocations = in.readOptionalVInt();
         queueCapacity = in.readOptionalVInt();
         startTime = in.readInstant();
         nodeStats = in.readList(AssignmentStats.NodeStats::new);
@@ -392,13 +392,13 @@ public class AssignmentStats implements ToXContentObject, Writeable {
     }
 
     @Nullable
-    public Integer getInferenceThreads() {
-        return inferenceThreads;
+    public Integer getThreadsPerAllocation() {
+        return threadsPerAllocation;
     }
 
     @Nullable
-    public Integer getModelThreads() {
-        return modelThreads;
+    public Integer getNumberOfAllocations() {
+        return numberOfAllocations;
     }
 
     @Nullable
@@ -441,11 +441,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field("model_id", modelId);
-        if (inferenceThreads != null) {
-            builder.field(StartTrainedModelDeploymentAction.TaskParams.INFERENCE_THREADS.getPreferredName(), inferenceThreads);
+        if (threadsPerAllocation != null) {
+            builder.field(StartTrainedModelDeploymentAction.TaskParams.THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
         }
-        if (modelThreads != null) {
-            builder.field(StartTrainedModelDeploymentAction.TaskParams.MODEL_THREADS.getPreferredName(), modelThreads);
+        if (numberOfAllocations != null) {
+            builder.field(StartTrainedModelDeploymentAction.TaskParams.NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
         }
         if (queueCapacity != null) {
             builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity);
@@ -496,8 +496,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(modelId);
-        out.writeOptionalVInt(inferenceThreads);
-        out.writeOptionalVInt(modelThreads);
+        out.writeOptionalVInt(threadsPerAllocation);
+        out.writeOptionalVInt(numberOfAllocations);
         out.writeOptionalVInt(queueCapacity);
         out.writeInstant(startTime);
         out.writeList(nodeStats);
@@ -512,8 +512,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         if (o == null || getClass() != o.getClass()) return false;
         AssignmentStats that = (AssignmentStats) o;
         return Objects.equals(modelId, that.modelId)
-            && Objects.equals(inferenceThreads, that.inferenceThreads)
-            && Objects.equals(modelThreads, that.modelThreads)
+            && Objects.equals(threadsPerAllocation, that.threadsPerAllocation)
+            && Objects.equals(numberOfAllocations, that.numberOfAllocations)
             && Objects.equals(queueCapacity, that.queueCapacity)
             && Objects.equals(startTime, that.startTime)
             && Objects.equals(state, that.state)
@@ -524,7 +524,17 @@ public class AssignmentStats implements ToXContentObject, Writeable {
 
     @Override
     public int hashCode() {
-        return Objects.hash(modelId, inferenceThreads, modelThreads, queueCapacity, startTime, nodeStats, state, reason, allocationStatus);
+        return Objects.hash(
+            modelId,
+            threadsPerAllocation,
+            numberOfAllocations,
+            queueCapacity,
+            startTime,
+            nodeStats,
+            state,
+            reason,
+            allocationStatus
+        );
     }
 
     @Override
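
The stream constructor and `writeTo` above follow the usual wire-compatibility pattern: the renamed optional fields are only read or written when the peer is on 8.1 or later. The same pattern in isolation (a sketch; stream types as in `org.elasticsearch.common.io.stream`):

    import java.io.IOException;

    import org.elasticsearch.Version;
    import org.elasticsearch.common.io.stream.StreamInput;
    import org.elasticsearch.common.io.stream.StreamOutput;

    class VersionGateExample {
        // Peers older than 8.1 never sent the field; default it to null on read.
        static Integer readThreadsPerAllocation(StreamInput in) throws IOException {
            return in.getVersion().onOrAfter(Version.V_8_1_0) ? in.readOptionalVInt() : null;
        }

        // The writer must mirror the reader's gate exactly, or the stream desynchronizes.
        static void writeThreadsPerAllocation(StreamOutput out, Integer value) throws IOException {
            if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
                out.writeOptionalVInt(value);
            }
        }
    }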

+ 6 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -112,8 +112,8 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                     ? null
                                     : new AssignmentStats(
                                         stats.getDeploymentStats().getModelId(),
-                                        stats.getDeploymentStats().getInferenceThreads(),
-                                        stats.getDeploymentStats().getModelThreads(),
+                                        stats.getDeploymentStats().getThreadsPerAllocation(),
+                                        stats.getDeploymentStats().getNumberOfAllocations(),
                                         stats.getDeploymentStats().getQueueCapacity(),
                                         stats.getDeploymentStats().getStartTime(),
                                         stats.getDeploymentStats()
@@ -164,8 +164,8 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                     ? null
                                     : new AssignmentStats(
                                         stats.getDeploymentStats().getModelId(),
-                                        stats.getDeploymentStats().getInferenceThreads(),
-                                        stats.getDeploymentStats().getModelThreads(),
+                                        stats.getDeploymentStats().getThreadsPerAllocation(),
+                                        stats.getDeploymentStats().getNumberOfAllocations(),
                                         stats.getDeploymentStats().getQueueCapacity(),
                                         stats.getDeploymentStats().getStartTime(),
                                         stats.getDeploymentStats()
@@ -183,8 +183,8 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getTimeoutCount(),
                                                     nodeStats.getRoutingState(),
                                                     nodeStats.getStartTime(),
-                                                    nodeStats.getInferenceThreads(),
-                                                    nodeStats.getModelThreads(),
+                                                    nodeStats.getThreadsPerAllocation(),
+                                                    nodeStats.getNumberOfAllocations(),
                                                     0L,
                                                     0L,
                                                     null

+ 16 - 16
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java

@@ -49,10 +49,10 @@ public class StartTrainedModelDeploymentRequestTests extends AbstractSerializing
             request.setWaitForState(randomFrom(AllocationStatus.State.values()));
         }
         if (randomBoolean()) {
-            request.setInferenceThreads(randomIntBetween(1, 8));
+            request.setThreadsPerAllocation(randomIntBetween(1, 8));
         }
         if (randomBoolean()) {
-            request.setModelThreads(randomIntBetween(1, 8));
+            request.setNumberOfAllocations(randomIntBetween(1, 8));
         }
         if (randomBoolean()) {
             request.setQueueCapacity(randomIntBetween(1, 10000));
@@ -60,44 +60,44 @@ public class StartTrainedModelDeploymentRequestTests extends AbstractSerializing
         return request;
     }
 
-    public void testValidate_GivenInferenceThreadsIsZero() {
+    public void testValidate_GivenThreadsPerAllocationIsZero() {
         Request request = createRandom();
-        request.setInferenceThreads(0);
+        request.setThreadsPerAllocation(0);
 
         ActionRequestValidationException e = request.validate();
 
         assertThat(e, is(not(nullValue())));
-        assertThat(e.getMessage(), containsString("[inference_threads] must be a positive integer"));
+        assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a positive integer"));
     }
 
-    public void testValidate_GivenInferenceThreadsIsNegative() {
+    public void testValidate_GivenThreadsPerAllocationIsNegative() {
         Request request = createRandom();
-        request.setInferenceThreads(randomIntBetween(-100, -1));
+        request.setThreadsPerAllocation(randomIntBetween(-100, -1));
 
         ActionRequestValidationException e = request.validate();
 
         assertThat(e, is(not(nullValue())));
-        assertThat(e.getMessage(), containsString("[inference_threads] must be a positive integer"));
+        assertThat(e.getMessage(), containsString("[threads_per_allocation] must be a positive integer"));
     }
 
-    public void testValidate_GivenModelThreadsIsZero() {
+    public void testValidate_GivenNumberOfAllocationsIsZero() {
         Request request = createRandom();
-        request.setModelThreads(0);
+        request.setNumberOfAllocations(0);
 
         ActionRequestValidationException e = request.validate();
 
         assertThat(e, is(not(nullValue())));
-        assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
+        assertThat(e.getMessage(), containsString("[number_of_allocations] must be a positive integer"));
     }
 
-    public void testValidate_GivenModelThreadsIsNegative() {
+    public void testValidate_GivenNumberOfAllocationsIsNegative() {
         Request request = createRandom();
-        request.setModelThreads(randomIntBetween(-100, -1));
+        request.setNumberOfAllocations(randomIntBetween(-100, -1));
 
         ActionRequestValidationException e = request.validate();
 
         assertThat(e, is(not(nullValue())));
-        assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
+        assertThat(e.getMessage(), containsString("[number_of_allocations] must be a positive integer"));
     }
 
     public void testValidate_GivenQueueCapacityIsZero() {
@@ -124,8 +124,8 @@ public class StartTrainedModelDeploymentRequestTests extends AbstractSerializing
         Request request = new Request(randomAlphaOfLength(10));
         assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
         assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
-        assertThat(request.getInferenceThreads(), equalTo(1));
-        assertThat(request.getModelThreads(), equalTo(1));
+        assertThat(request.getNumberOfAllocations(), equalTo(1));
+        assertThat(request.getThreadsPerAllocation(), equalTo(1));
         assertThat(request.getQueueCapacity(), equalTo(1024));
     }
 }

+ 1 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -750,7 +750,7 @@ public class PyTorchModelIT extends ESRestTestCase {
                 + modelId
                 + "/deployment/_start?timeout=40s&wait_for="
                 + waitForState
-                + "&inference_threads=1&model_threads=1"
+                + "&threads_per_allocation=1&number_of_allocations=1"
         );
         return client().performRequest(request);
     }

+ 6 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -234,8 +234,8 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 updatedAssignmentStats.add(
                     new AssignmentStats(
                         stat.getModelId(),
-                        stat.getInferenceThreads(),
-                        stat.getModelThreads(),
+                        stat.getThreadsPerAllocation(),
+                        stat.getNumberOfAllocations(),
                         stat.getQueueCapacity(),
                         stat.getStartTime(),
                         updatedNodeStats
@@ -304,8 +304,8 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                     presentValue.timeoutCount(),
                     presentValue.lastUsed(),
                     presentValue.startTime(),
-                    presentValue.inferenceThreads(),
-                    presentValue.modelThreads(),
+                    presentValue.threadsPerAllocation(),
+                    presentValue.numberOfAllocations(),
                     presentValue.peakThroughput(),
                     presentValue.throughputLastPeriod(),
                     presentValue.avgInferenceTimeLastPeriod()
@@ -320,8 +320,8 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
         listener.onResponse(
             new AssignmentStats(
                 task.getModelId(),
-                task.getParams().getInferenceThreads(),
-                task.getParams().getModelThreads(),
+                task.getParams().getThreadsPerAllocation(),
+                task.getParams().getNumberOfAllocations(),
                 task.getParams().getQueueCapacity(),
                 TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(),
                 nodeStats

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

@@ -223,8 +223,8 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
                     TaskParams taskParams = new TaskParams(
                         trainedModelConfig.getModelId(),
                         modelBytes,
-                        request.getInferenceThreads(),
-                        request.getModelThreads(),
+                        request.getThreadsPerAllocation(),
+                        request.getNumberOfAllocations(),
                         request.getQueueCapacity()
                     );
                     PersistentTasksCustomMetadata persistentTasks = clusterService.state()

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

@@ -18,8 +18,8 @@ public record ModelStats(
     int errorCount,
     int rejectedExecutionCount,
     int timeoutCount,
-    Integer inferenceThreads,
-    Integer modelThreads,
+    Integer threadsPerAllocation,
+    Integer numberOfAllocations,
     long peakThroughput,
     long throughputLastPeriod,
     Double avgInferenceTimeLastPeriod

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java

@@ -102,8 +102,8 @@ public class NativePyTorchProcessFactory implements PyTorchProcessFactory {
         PyTorchBuilder pyTorchBuilder = new PyTorchBuilder(
             nativeController,
             processPipes,
-            task.getParams().getInferenceThreads(),
-            task.getParams().getModelThreads()
+            task.getParams().getThreadsPerAllocation(),
+            task.getParams().getNumberOfAllocations()
         );
         try {
             pyTorchBuilder.build();

+ 7 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java

@@ -26,14 +26,14 @@ public class PyTorchBuilder {
 
     private final NativeController nativeController;
     private final ProcessPipes processPipes;
-    private final int inferenceThreads;
-    private final int modelThreads;
+    private final int threadsPerAllocation;
+    private final int numberOfAllocations;
 
-    public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, int inferenceThreads, int modelThreads) {
+    public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, int threadPerAllocation, int numberOfAllocations) {
         this.nativeController = Objects.requireNonNull(nativeController);
         this.processPipes = Objects.requireNonNull(processPipes);
-        this.inferenceThreads = inferenceThreads;
-        this.modelThreads = modelThreads;
+        this.threadsPerAllocation = threadPerAllocation;
+        this.numberOfAllocations = numberOfAllocations;
     }
 
     public void build() throws IOException, InterruptedException {
@@ -49,8 +49,8 @@ public class PyTorchBuilder {
         // License was validated when the trained model was started
         command.add(LICENSE_KEY_VALIDATED_ARG + true);
 
-        command.add(NUM_THREADS_PER_ALLOCATION_ARG + inferenceThreads);
-        command.add(NUM_ALLOCATIONS_ARG + modelThreads);
+        command.add(NUM_THREADS_PER_ALLOCATION_ARG + threadsPerAllocation);
+        command.add(NUM_ALLOCATIONS_ARG + numberOfAllocations);
 
         return command;
     }
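
The native process already used the new terminology: the builder maps both values onto `pytorch_inference` command-line flags. Roughly as below; the flag spellings are assumptions inferred from the constant names, since their values are not part of this diff:

    import java.util.ArrayList;
    import java.util.List;

    class NativeArgsSketch {
        // Hypothetical rendering of the two threading flags passed to the native process.
        static List<String> threadingArgs(int threadsPerAllocation, int numberOfAllocations) {
            List<String> command = new ArrayList<>();
            command.add("--numThreadsPerAllocation=" + threadsPerAllocation); // assumed flag name
            command.add("--numAllocations=" + numberOfAllocations);           // assumed flag name
            return command;
        }
    }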

+ 20 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.rest.inference;
 
 import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.core.RestApiVersion;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
@@ -15,15 +16,16 @@ import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.rest.RestCompatibilityChecker;
 
 import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.POST;
-import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
-import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
+import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
+import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.THREADS_PER_ALLOCATION;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;
 
@@ -65,8 +67,22 @@ public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {
             request.setWaitForState(
                 AllocationStatus.State.fromString(restRequest.param(WAIT_FOR.getPreferredName(), AllocationStatus.State.STARTED.toString()))
             );
-            request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
-            request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
+            RestCompatibilityChecker.checkAndSetDeprecatedParam(
+                NUMBER_OF_ALLOCATIONS.getDeprecatedNames()[0],
+                NUMBER_OF_ALLOCATIONS.getPreferredName(),
+                RestApiVersion.V_8,
+                restRequest,
+                (r, s) -> r.paramAsInt(s, request.getNumberOfAllocations()),
+                request::setNumberOfAllocations
+            );
+            RestCompatibilityChecker.checkAndSetDeprecatedParam(
+                THREADS_PER_ALLOCATION.getDeprecatedNames()[0],
+                THREADS_PER_ALLOCATION.getPreferredName(),
+                RestApiVersion.V_8,
+                restRequest,
+                (r, s) -> r.paramAsInt(s, request.getThreadsPerAllocation()),
+                request::setThreadsPerAllocation
+            );
             request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
         }
 

+ 1 - 0
x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java

@@ -183,6 +183,7 @@ public class MLModelDeploymentFullClusterRestartIT extends AbstractFullClusterRe
                 + waitForState
                 + "&inference_threads=1&model_threads=1"
         );
+        request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
         var response = client().performRequest(request);
         assertOK(response);
         return response;

+ 1 - 0
x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java

@@ -207,6 +207,7 @@ public class MLModelDeploymentsUpgradeIT extends AbstractUpgradeTestCase {
                 + waitForState
                 + "&inference_threads=1&model_threads=1"
         );
+        request.setOptions(request.getOptions().toBuilder().setWarningsHandler(PERMISSIVE).build());
         var response = client().performRequest(request);
         assertOK(response);
         return response;
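
The one-line additions to the two BWC test suites above are needed because these tests still send the deprecated `inference_threads`/`model_threads` parameters, which now draw deprecation warning headers that would otherwise fail the request in these tests. The opt-out in isolation:

    import org.elasticsearch.client.Request;
    import org.elasticsearch.client.RequestOptions;
    import org.elasticsearch.client.WarningsHandler;

    class PermissiveWarningsExample {
        static void tolerateDeprecationWarnings(Request request) {
            // PERMISSIVE accepts any Warning response headers instead of failing the request.
            RequestOptions options = request.getOptions().toBuilder()
                .setWarningsHandler(WarningsHandler.PERMISSIVE)
                .build();
            request.setOptions(options);
        }
    }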