|  | @@ -69,8 +69,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |          public static final ParseField MODEL_ID = new ParseField("model_id");
 |  |          public static final ParseField MODEL_ID = new ParseField("model_id");
 | 
											
												
													
														|  |          public static final ParseField TIMEOUT = new ParseField("timeout");
 |  |          public static final ParseField TIMEOUT = new ParseField("timeout");
 | 
											
												
													
														|  |          public static final ParseField WAIT_FOR = new ParseField("wait_for");
 |  |          public static final ParseField WAIT_FOR = new ParseField("wait_for");
 | 
											
												
													
														|  | -        public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
 |  | 
 | 
											
												
													
														|  | -        public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
 | 
											
												
													
														|  | 
 |  | +        public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
 | 
											
												
													
														|  |          public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
 |  |          public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
 |  |          public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
 | 
											
										
											
												
													
														|  | @@ -79,8 +79,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              PARSER.declareString(Request::setModelId, MODEL_ID);
 |  |              PARSER.declareString(Request::setModelId, MODEL_ID);
 | 
											
												
													
														|  |              PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT);
 |  |              PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT);
 | 
											
												
													
														|  |              PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
 |  |              PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
 | 
											
												
													
														|  | -            PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
 |  | 
 | 
											
												
													
														|  | -            PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
 | 
											
												
													
														|  | 
 |  | +            PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
 | 
											
												
													
														|  |              PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
 |  |              PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -99,8 +99,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |          private String modelId;
 |  |          private String modelId;
 | 
											
												
													
														|  |          private TimeValue timeout = DEFAULT_TIMEOUT;
 |  |          private TimeValue timeout = DEFAULT_TIMEOUT;
 | 
											
												
													
														|  |          private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
 |  |          private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
 | 
											
												
													
														|  | -        private int modelThreads = 1;
 |  | 
 | 
											
												
													
														|  | -        private int inferenceThreads = 1;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        private int numberOfAllocations = 1;
 | 
											
												
													
														|  | 
 |  | +        private int threadsPerAllocation = 1;
 | 
											
												
													
														|  |          private int queueCapacity = 1024;
 |  |          private int queueCapacity = 1024;
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          private Request() {}
 |  |          private Request() {}
 | 
											
										
											
												
													
														|  | @@ -114,8 +114,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              modelId = in.readString();
 |  |              modelId = in.readString();
 | 
											
												
													
														|  |              timeout = in.readTimeValue();
 |  |              timeout = in.readTimeValue();
 | 
											
												
													
														|  |              waitForState = in.readEnum(AllocationStatus.State.class);
 |  |              waitForState = in.readEnum(AllocationStatus.State.class);
 | 
											
												
													
														|  | -            modelThreads = in.readVInt();
 |  | 
 | 
											
												
													
														|  | -            inferenceThreads = in.readVInt();
 |  | 
 | 
											
												
													
														|  | 
 |  | +            numberOfAllocations = in.readVInt();
 | 
											
												
													
														|  | 
 |  | +            threadsPerAllocation = in.readVInt();
 | 
											
												
													
														|  |              queueCapacity = in.readVInt();
 |  |              queueCapacity = in.readVInt();
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -144,20 +144,20 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              return this;
 |  |              return this;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        public int getModelThreads() {
 |  | 
 | 
											
												
													
														|  | -            return modelThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public int getNumberOfAllocations() {
 | 
											
												
													
														|  | 
 |  | +            return numberOfAllocations;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        public void setModelThreads(int modelThreads) {
 |  | 
 | 
											
												
													
														|  | -            this.modelThreads = modelThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public void setNumberOfAllocations(int numberOfAllocations) {
 | 
											
												
													
														|  | 
 |  | +            this.numberOfAllocations = numberOfAllocations;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        public int getInferenceThreads() {
 |  | 
 | 
											
												
													
														|  | -            return inferenceThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public int getThreadsPerAllocation() {
 | 
											
												
													
														|  | 
 |  | +            return threadsPerAllocation;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        public void setInferenceThreads(int inferenceThreads) {
 |  | 
 | 
											
												
													
														|  | -            this.inferenceThreads = inferenceThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public void setThreadsPerAllocation(int threadsPerAllocation) {
 | 
											
												
													
														|  | 
 |  | +            this.threadsPerAllocation = threadsPerAllocation;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          public int getQueueCapacity() {
 |  |          public int getQueueCapacity() {
 | 
											
										
											
												
													
														|  | @@ -174,8 +174,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              out.writeString(modelId);
 |  |              out.writeString(modelId);
 | 
											
												
													
														|  |              out.writeTimeValue(timeout);
 |  |              out.writeTimeValue(timeout);
 | 
											
												
													
														|  |              out.writeEnum(waitForState);
 |  |              out.writeEnum(waitForState);
 | 
											
												
													
														|  | -            out.writeVInt(modelThreads);
 |  | 
 | 
											
												
													
														|  | -            out.writeVInt(inferenceThreads);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            out.writeVInt(numberOfAllocations);
 | 
											
												
													
														|  | 
 |  | +            out.writeVInt(threadsPerAllocation);
 | 
											
												
													
														|  |              out.writeVInt(queueCapacity);
 |  |              out.writeVInt(queueCapacity);
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -185,8 +185,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              builder.field(MODEL_ID.getPreferredName(), modelId);
 |  |              builder.field(MODEL_ID.getPreferredName(), modelId);
 | 
											
												
													
														|  |              builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep());
 |  |              builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep());
 | 
											
												
													
														|  |              builder.field(WAIT_FOR.getPreferredName(), waitForState);
 |  |              builder.field(WAIT_FOR.getPreferredName(), waitForState);
 | 
											
												
													
														|  | -            builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
 |  | 
 | 
											
												
													
														|  | -            builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
 | 
											
												
													
														|  | 
 |  | +            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
 | 
											
												
													
														|  |              builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
 |  |              builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
 | 
											
												
													
														|  |              builder.endObject();
 |  |              builder.endObject();
 | 
											
												
													
														|  |              return builder;
 |  |              return builder;
 | 
											
										
											
												
													
														|  | @@ -203,11 +203,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |                          + Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES)
 |  |                          + Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES)
 | 
											
												
													
														|  |                  );
 |  |                  );
 | 
											
												
													
														|  |              }
 |  |              }
 | 
											
												
													
														|  | -            if (modelThreads < 1) {
 |  | 
 | 
											
												
													
														|  | -                validationException.addValidationError("[" + MODEL_THREADS + "] must be a positive integer");
 |  | 
 | 
											
												
													
														|  | 
 |  | +            if (numberOfAllocations < 1) {
 | 
											
												
													
														|  | 
 |  | +                validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
 | 
											
												
													
														|  |              }
 |  |              }
 | 
											
												
													
														|  | -            if (inferenceThreads < 1) {
 |  | 
 | 
											
												
													
														|  | -                validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
 |  | 
 | 
											
												
													
														|  | 
 |  | +            if (threadsPerAllocation < 1) {
 | 
											
												
													
														|  | 
 |  | +                validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer");
 | 
											
												
													
														|  |              }
 |  |              }
 | 
											
												
													
														|  |              if (queueCapacity < 1) {
 |  |              if (queueCapacity < 1) {
 | 
											
												
													
														|  |                  validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
 |  |                  validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
 | 
											
										
											
												
													
														|  | @@ -217,7 +217,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          @Override
 |  |          @Override
 | 
											
												
													
														|  |          public int hashCode() {
 |  |          public int hashCode() {
 | 
											
												
													
														|  | -            return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          @Override
 |  |          @Override
 | 
											
										
											
												
													
														|  | @@ -232,8 +232,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              return Objects.equals(modelId, other.modelId)
 |  |              return Objects.equals(modelId, other.modelId)
 | 
											
												
													
														|  |                  && Objects.equals(timeout, other.timeout)
 |  |                  && Objects.equals(timeout, other.timeout)
 | 
											
												
													
														|  |                  && Objects.equals(waitForState, other.waitForState)
 |  |                  && Objects.equals(waitForState, other.waitForState)
 | 
											
												
													
														|  | -                && modelThreads == other.modelThreads
 |  | 
 | 
											
												
													
														|  | -                && inferenceThreads == other.inferenceThreads
 |  | 
 | 
											
												
													
														|  | 
 |  | +                && numberOfAllocations == other.numberOfAllocations
 | 
											
												
													
														|  | 
 |  | +                && threadsPerAllocation == other.threadsPerAllocation
 | 
											
												
													
														|  |                  && queueCapacity == other.queueCapacity;
 |  |                  && queueCapacity == other.queueCapacity;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -254,22 +254,28 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
 |  |          public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
 | 
											
												
													
														|  |          private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
 |  |          private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
 | 
											
												
													
														|  | -        public static final ParseField MODEL_THREADS = new ParseField("model_threads");
 |  | 
 | 
											
												
													
														|  | -        public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations");
 | 
											
												
													
														|  | 
 |  | +        public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation");
 | 
											
												
													
														|  | 
 |  | +        // number_of_allocations was previously named model_threads
 | 
											
												
													
														|  | 
 |  | +        private static final ParseField LEGACY_MODEL_THREADS = new ParseField("model_threads");
 | 
											
												
													
														|  | 
 |  | +        // threads_per_allocation was previously named inference_threads
 | 
											
												
													
														|  | 
 |  | +        public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
 | 
											
												
													
														|  |          public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
 |  |          public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
 |  |          private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
 | 
											
												
													
														|  |              "trained_model_deployment_params",
 |  |              "trained_model_deployment_params",
 | 
											
												
													
														|  |              true,
 |  |              true,
 | 
											
												
													
														|  | -            a -> new TaskParams((String) a[0], (Long) a[1], (int) a[2], (int) a[3], (int) a[4])
 |  | 
 | 
											
												
													
														|  | 
 |  | +            a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
 | 
											
												
													
														|  |          );
 |  |          );
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          static {
 |  |          static {
 | 
											
												
													
														|  |              PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
 |  |              PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
 | 
											
												
													
														|  |              PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
 |  |              PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
 | 
											
												
													
														|  | -            PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
 |  | 
 | 
											
												
													
														|  | -            PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
 | 
											
												
													
														|  | 
 |  | +            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
 | 
											
												
													
														|  |              PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
 |  |              PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
 | 
											
												
													
														|  | 
 |  | +            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
 | 
											
												
													
														|  | 
 |  | +            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          public static TaskParams fromXContent(XContentParser parser) {
 |  |          public static TaskParams fromXContent(XContentParser parser) {
 | 
											
										
											
												
													
														|  | @@ -279,24 +285,42 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |          private final String modelId;
 |  |          private final String modelId;
 | 
											
												
													
														|  |          private final long modelBytes;
 |  |          private final long modelBytes;
 | 
											
												
													
														|  |          // How many threads are used by the model during inference. Used to increase inference speed.
 |  |          // How many threads are used by the model during inference. Used to increase inference speed.
 | 
											
												
													
														|  | -        private final int inferenceThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        private final int threadsPerAllocation;
 | 
											
												
													
														|  |          // How many threads are used when forwarding the request to the model. Used to increase throughput.
 |  |          // How many threads are used when forwarding the request to the model. Used to increase throughput.
 | 
											
												
													
														|  | -        private final int modelThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        private final int numberOfAllocations;
 | 
											
												
													
														|  |          private final int queueCapacity;
 |  |          private final int queueCapacity;
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
 |  | 
 | 
											
												
													
														|  | 
 |  | +        private TaskParams(
 | 
											
												
													
														|  | 
 |  | +            String modelId,
 | 
											
												
													
														|  | 
 |  | +            long modelBytes,
 | 
											
												
													
														|  | 
 |  | +            Integer numberOfAllocations,
 | 
											
												
													
														|  | 
 |  | +            Integer threadsPerAllocation,
 | 
											
												
													
														|  | 
 |  | +            int queueCapacity,
 | 
											
												
													
														|  | 
 |  | +            Integer legacyModelThreads,
 | 
											
												
													
														|  | 
 |  | +            Integer legacyInferenceThreads
 | 
											
												
													
														|  | 
 |  | +        ) {
 | 
											
												
													
														|  | 
 |  | +            this(
 | 
											
												
													
														|  | 
 |  | +                modelId,
 | 
											
												
													
														|  | 
 |  | +                modelBytes,
 | 
											
												
													
														|  | 
 |  | +                threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
 | 
											
												
													
														|  | 
 |  | +                numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
 | 
											
												
													
														|  | 
 |  | +                queueCapacity
 | 
											
												
													
														|  | 
 |  | +            );
 | 
											
												
													
														|  | 
 |  | +        }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
 | 
											
												
													
														|  |              this.modelId = Objects.requireNonNull(modelId);
 |  |              this.modelId = Objects.requireNonNull(modelId);
 | 
											
												
													
														|  |              this.modelBytes = modelBytes;
 |  |              this.modelBytes = modelBytes;
 | 
											
												
													
														|  | -            this.inferenceThreads = inferenceThreads;
 |  | 
 | 
											
												
													
														|  | -            this.modelThreads = modelThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +            this.threadsPerAllocation = threadsPerAllocation;
 | 
											
												
													
														|  | 
 |  | +            this.numberOfAllocations = numberOfAllocations;
 | 
											
												
													
														|  |              this.queueCapacity = queueCapacity;
 |  |              this.queueCapacity = queueCapacity;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          public TaskParams(StreamInput in) throws IOException {
 |  |          public TaskParams(StreamInput in) throws IOException {
 | 
											
												
													
														|  |              this.modelId = in.readString();
 |  |              this.modelId = in.readString();
 | 
											
												
													
														|  |              this.modelBytes = in.readLong();
 |  |              this.modelBytes = in.readLong();
 | 
											
												
													
														|  | -            this.inferenceThreads = in.readVInt();
 |  | 
 | 
											
												
													
														|  | -            this.modelThreads = in.readVInt();
 |  | 
 | 
											
												
													
														|  | 
 |  | +            this.threadsPerAllocation = in.readVInt();
 | 
											
												
													
														|  | 
 |  | +            this.numberOfAllocations = in.readVInt();
 | 
											
												
													
														|  |              this.queueCapacity = in.readVInt();
 |  |              this.queueCapacity = in.readVInt();
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -316,8 +340,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |          public void writeTo(StreamOutput out) throws IOException {
 |  |          public void writeTo(StreamOutput out) throws IOException {
 | 
											
												
													
														|  |              out.writeString(modelId);
 |  |              out.writeString(modelId);
 | 
											
												
													
														|  |              out.writeLong(modelBytes);
 |  |              out.writeLong(modelBytes);
 | 
											
												
													
														|  | -            out.writeVInt(inferenceThreads);
 |  | 
 | 
											
												
													
														|  | -            out.writeVInt(modelThreads);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            out.writeVInt(threadsPerAllocation);
 | 
											
												
													
														|  | 
 |  | +            out.writeVInt(numberOfAllocations);
 | 
											
												
													
														|  |              out.writeVInt(queueCapacity);
 |  |              out.writeVInt(queueCapacity);
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -326,8 +350,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              builder.startObject();
 |  |              builder.startObject();
 | 
											
												
													
														|  |              builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
 |  |              builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
 | 
											
												
													
														|  |              builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
 |  |              builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
 | 
											
												
													
														|  | -            builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
 |  | 
 | 
											
												
													
														|  | -            builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
 | 
											
												
													
														|  | 
 |  | +            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
 | 
											
												
													
														|  |              builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
 |  |              builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
 | 
											
												
													
														|  |              builder.endObject();
 |  |              builder.endObject();
 | 
											
												
													
														|  |              return builder;
 |  |              return builder;
 | 
											
										
											
												
													
														|  | @@ -335,7 +359,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          @Override
 |  |          @Override
 | 
											
												
													
														|  |          public int hashCode() {
 |  |          public int hashCode() {
 | 
											
												
													
														|  | -            return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          @Override
 |  |          @Override
 | 
											
										
											
												
													
														|  | @@ -346,8 +370,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              TaskParams other = (TaskParams) o;
 |  |              TaskParams other = (TaskParams) o;
 | 
											
												
													
														|  |              return Objects.equals(modelId, other.modelId)
 |  |              return Objects.equals(modelId, other.modelId)
 | 
											
												
													
														|  |                  && modelBytes == other.modelBytes
 |  |                  && modelBytes == other.modelBytes
 | 
											
												
													
														|  | -                && inferenceThreads == other.inferenceThreads
 |  | 
 | 
											
												
													
														|  | -                && modelThreads == other.modelThreads
 |  | 
 | 
											
												
													
														|  | 
 |  | +                && threadsPerAllocation == other.threadsPerAllocation
 | 
											
												
													
														|  | 
 |  | +                && numberOfAllocations == other.numberOfAllocations
 | 
											
												
													
														|  |                  && queueCapacity == other.queueCapacity;
 |  |                  && queueCapacity == other.queueCapacity;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
										
											
												
													
														|  | @@ -360,12 +384,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
											
												
													
														|  |              return modelBytes;
 |  |              return modelBytes;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        public int getInferenceThreads() {
 |  | 
 | 
											
												
													
														|  | -            return inferenceThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public int getThreadsPerAllocation() {
 | 
											
												
													
														|  | 
 |  | +            return threadsPerAllocation;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -        public int getModelThreads() {
 |  | 
 | 
											
												
													
														|  | -            return modelThreads;
 |  | 
 | 
											
												
													
														|  | 
 |  | +        public int getNumberOfAllocations() {
 | 
											
												
													
														|  | 
 |  | +            return numberOfAllocations;
 | 
											
												
													
														|  |          }
 |  |          }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          public int getQueueCapacity() {
 |  |          public int getQueueCapacity() {
 |