|  | @@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.io.stream.StreamOutput;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.io.stream.Writeable;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.unit.ByteSizeValue;
 | 
	
		
			
				|  |  | +import org.elasticsearch.core.Nullable;
 | 
	
		
			
				|  |  |  import org.elasticsearch.core.TimeValue;
 | 
	
		
			
				|  |  |  import org.elasticsearch.tasks.Task;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xcontent.ConstructingObjectParser;
 | 
	
	
		
			
				|  | @@ -34,8 +35,10 @@ import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.io.IOException;
 | 
	
		
			
				|  |  |  import java.util.Objects;
 | 
	
		
			
				|  |  | +import java.util.Optional;
 | 
	
		
			
				|  |  |  import java.util.concurrent.TimeUnit;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
 | 
	
		
			
				|  |  |  import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
 | 
	
	
		
			
				|  | @@ -75,6 +78,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |          public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
 | 
	
		
			
				|  |  |          public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
 | 
	
		
			
				|  |  |          public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
 | 
	
		
			
				|  |  | +        public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -85,6 +89,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
 | 
	
		
			
				|  |  |              PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
 | 
	
		
			
				|  |  |              PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
 | 
	
		
			
				|  |  | +            PARSER.declareField(
 | 
	
		
			
				|  |  | +                Request::setCacheSize,
 | 
	
		
			
				|  |  | +                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
 | 
	
		
			
				|  |  | +                CACHE_SIZE,
 | 
	
		
			
				|  |  | +                ObjectParser.ValueType.VALUE
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          public static Request parseRequest(String modelId, XContentParser parser) {
 | 
	
	
		
			
				|  | @@ -102,6 +112,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |          private String modelId;
 | 
	
		
			
				|  |  |          private TimeValue timeout = DEFAULT_TIMEOUT;
 | 
	
		
			
				|  |  |          private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
 | 
	
		
			
				|  |  | +        private ByteSizeValue cacheSize;
 | 
	
		
			
				|  |  |          private int numberOfAllocations = 1;
 | 
	
		
			
				|  |  |          private int threadsPerAllocation = 1;
 | 
	
		
			
				|  |  |          private int queueCapacity = 1024;
 | 
	
	
		
			
				|  | @@ -120,6 +131,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              numberOfAllocations = in.readVInt();
 | 
	
		
			
				|  |  |              threadsPerAllocation = in.readVInt();
 | 
	
		
			
				|  |  |              queueCapacity = in.readVInt();
 | 
	
		
			
				|  |  | +            if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
 | 
	
		
			
				|  |  | +                this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          public final void setModelId(String modelId) {
 | 
	
	
		
			
				|  | @@ -171,6 +185,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              this.queueCapacity = queueCapacity;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        public ByteSizeValue getCacheSize() {
 | 
	
		
			
				|  |  | +            return cacheSize;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        public void setCacheSize(ByteSizeValue cacheSize) {
 | 
	
		
			
				|  |  | +            this.cacheSize = cacheSize;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  |          public void writeTo(StreamOutput out) throws IOException {
 | 
	
		
			
				|  |  |              super.writeTo(out);
 | 
	
	
		
			
				|  | @@ -180,6 +202,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              out.writeVInt(numberOfAllocations);
 | 
	
		
			
				|  |  |              out.writeVInt(threadsPerAllocation);
 | 
	
		
			
				|  |  |              out.writeVInt(queueCapacity);
 | 
	
		
			
				|  |  | +            if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
 | 
	
		
			
				|  |  | +                out.writeOptionalWriteable(cacheSize);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
	
		
			
				|  | @@ -191,6 +216,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
 | 
	
		
			
				|  |  |              builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
 | 
	
		
			
				|  |  |              builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
 | 
	
		
			
				|  |  | +            if (cacheSize != null) {
 | 
	
		
			
				|  |  | +                builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |              builder.endObject();
 | 
	
		
			
				|  |  |              return builder;
 | 
	
		
			
				|  |  |          }
 | 
	
	
		
			
				|  | @@ -229,7 +257,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  |          public int hashCode() {
 | 
	
		
			
				|  |  | -            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
 | 
	
		
			
				|  |  | +            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
	
		
			
				|  | @@ -244,6 +272,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              return Objects.equals(modelId, other.modelId)
 | 
	
		
			
				|  |  |                  && Objects.equals(timeout, other.timeout)
 | 
	
		
			
				|  |  |                  && Objects.equals(waitForState, other.waitForState)
 | 
	
		
			
				|  |  | +                && Objects.equals(cacheSize, other.cacheSize)
 | 
	
		
			
				|  |  |                  && numberOfAllocations == other.numberOfAllocations
 | 
	
		
			
				|  |  |                  && threadsPerAllocation == other.threadsPerAllocation
 | 
	
		
			
				|  |  |                  && queueCapacity == other.queueCapacity;
 | 
	
	
		
			
				|  | @@ -273,11 +302,21 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |          // threads_per_allocation was previously named inference_threads
 | 
	
		
			
				|  |  |          public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
 | 
	
		
			
				|  |  |          public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
 | 
	
		
			
				|  |  | +        public static final ParseField CACHE_SIZE = new ParseField("cache_size");
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
 | 
	
		
			
				|  |  |              "trained_model_deployment_params",
 | 
	
		
			
				|  |  |              true,
 | 
	
		
			
				|  |  | -            a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
 | 
	
		
			
				|  |  | +            a -> new TaskParams(
 | 
	
		
			
				|  |  | +                (String) a[0],
 | 
	
		
			
				|  |  | +                (Long) a[1],
 | 
	
		
			
				|  |  | +                (Integer) a[2],
 | 
	
		
			
				|  |  | +                (Integer) a[3],
 | 
	
		
			
				|  |  | +                (int) a[4],
 | 
	
		
			
				|  |  | +                (ByteSizeValue) a[5],
 | 
	
		
			
				|  |  | +                (Integer) a[6],
 | 
	
		
			
				|  |  | +                (Integer) a[7]
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          static {
 | 
	
	
		
			
				|  | @@ -286,6 +325,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
 | 
	
		
			
				|  |  |              PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
 | 
	
		
			
				|  |  |              PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
 | 
	
		
			
				|  |  | +            PARSER.declareField(
 | 
	
		
			
				|  |  | +                optionalConstructorArg(),
 | 
	
		
			
				|  |  | +                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
 | 
	
		
			
				|  |  | +                CACHE_SIZE,
 | 
	
		
			
				|  |  | +                ObjectParser.ValueType.VALUE
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  |              PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
 | 
	
		
			
				|  |  |              PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
 | 
	
		
			
				|  |  |          }
 | 
	
	
		
			
				|  | @@ -295,6 +340,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          private final String modelId;
 | 
	
		
			
				|  |  | +        private final ByteSizeValue cacheSize;
 | 
	
		
			
				|  |  |          private final long modelBytes;
 | 
	
		
			
				|  |  |          // How many threads are used by the model during inference. Used to increase inference speed.
 | 
	
		
			
				|  |  |          private final int threadsPerAllocation;
 | 
	
	
		
			
				|  | @@ -308,6 +354,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              Integer numberOfAllocations,
 | 
	
		
			
				|  |  |              Integer threadsPerAllocation,
 | 
	
		
			
				|  |  |              int queueCapacity,
 | 
	
		
			
				|  |  | +            ByteSizeValue cacheSizeValue,
 | 
	
		
			
				|  |  |              Integer legacyModelThreads,
 | 
	
		
			
				|  |  |              Integer legacyInferenceThreads
 | 
	
		
			
				|  |  |          ) {
 | 
	
	
		
			
				|  | @@ -316,16 +363,25 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |                  modelBytes,
 | 
	
		
			
				|  |  |                  threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
 | 
	
		
			
				|  |  |                  numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
 | 
	
		
			
				|  |  | -                queueCapacity
 | 
	
		
			
				|  |  | +                queueCapacity,
 | 
	
		
			
				|  |  | +                cacheSizeValue
 | 
	
		
			
				|  |  |              );
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
 | 
	
		
			
				|  |  | +        public TaskParams(
 | 
	
		
			
				|  |  | +            String modelId,
 | 
	
		
			
				|  |  | +            long modelBytes,
 | 
	
		
			
				|  |  | +            int threadsPerAllocation,
 | 
	
		
			
				|  |  | +            int numberOfAllocations,
 | 
	
		
			
				|  |  | +            int queueCapacity,
 | 
	
		
			
				|  |  | +            @Nullable ByteSizeValue cacheSize
 | 
	
		
			
				|  |  | +        ) {
 | 
	
		
			
				|  |  |              this.modelId = Objects.requireNonNull(modelId);
 | 
	
		
			
				|  |  |              this.modelBytes = modelBytes;
 | 
	
		
			
				|  |  |              this.threadsPerAllocation = threadsPerAllocation;
 | 
	
		
			
				|  |  |              this.numberOfAllocations = numberOfAllocations;
 | 
	
		
			
				|  |  |              this.queueCapacity = queueCapacity;
 | 
	
		
			
				|  |  | +            this.cacheSize = cacheSize;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          public TaskParams(StreamInput in) throws IOException {
 | 
	
	
		
			
				|  | @@ -334,6 +390,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              this.threadsPerAllocation = in.readVInt();
 | 
	
		
			
				|  |  |              this.numberOfAllocations = in.readVInt();
 | 
	
		
			
				|  |  |              this.queueCapacity = in.readVInt();
 | 
	
		
			
				|  |  | +            if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
 | 
	
		
			
				|  |  | +                this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
 | 
	
		
			
				|  |  | +            } else {
 | 
	
		
			
				|  |  | +                this.cacheSize = null;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          public String getModelId() {
 | 
	
	
		
			
				|  | @@ -341,6 +402,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          public long estimateMemoryUsageBytes() {
 | 
	
		
			
				|  |  | +            // We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
 | 
	
		
			
				|  |  | +            // we need to take it into account when returning the estimate.
 | 
	
		
			
				|  |  | +            if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
 | 
	
		
			
				|  |  | +                return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |              return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -355,6 +421,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              out.writeVInt(threadsPerAllocation);
 | 
	
		
			
				|  |  |              out.writeVInt(numberOfAllocations);
 | 
	
		
			
				|  |  |              out.writeVInt(queueCapacity);
 | 
	
		
			
				|  |  | +            if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
 | 
	
		
			
				|  |  | +                out.writeOptionalWriteable(cacheSize);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
	
		
			
				|  | @@ -365,13 +434,16 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
 | 
	
		
			
				|  |  |              builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
 | 
	
		
			
				|  |  |              builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
 | 
	
		
			
				|  |  | +            if (cacheSize != null) {
 | 
	
		
			
				|  |  | +                builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |              builder.endObject();
 | 
	
		
			
				|  |  |              return builder;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  |          public int hashCode() {
 | 
	
		
			
				|  |  | -            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
 | 
	
		
			
				|  |  | +            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          @Override
 | 
	
	
		
			
				|  | @@ -384,6 +456,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |                  && modelBytes == other.modelBytes
 | 
	
		
			
				|  |  |                  && threadsPerAllocation == other.threadsPerAllocation
 | 
	
		
			
				|  |  |                  && numberOfAllocations == other.numberOfAllocations
 | 
	
		
			
				|  |  | +                && Objects.equals(cacheSize, other.cacheSize)
 | 
	
		
			
				|  |  |                  && queueCapacity == other.queueCapacity;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -408,6 +481,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 | 
	
		
			
				|  |  |              return queueCapacity;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        public Optional<ByteSizeValue> getCacheSize() {
 | 
	
		
			
				|  |  | +            return Optional.ofNullable(cacheSize);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        public long getCacheSizeBytes() {
 | 
	
		
			
				|  |  | +            return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          @Override
 | 
	
		
			
				|  |  |          public String toString() {
 | 
	
		
			
				|  |  |              return Strings.toString(this);
 |