@@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ConstructingObjectParser;
@@ -34,8 +35,10 @@ import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;

import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
import java.util.concurrent.TimeUnit;

+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription;

public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
@@ -75,6 +78,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
        public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
        public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
        public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
+        public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;

        public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

@@ -85,6 +89,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
            PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
            PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
+            PARSER.declareField(
+                Request::setCacheSize,
+                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
+                CACHE_SIZE,
+                ObjectParser.ValueType.VALUE
+            );
        }

public static Request parseRequest(String modelId, XContentParser parser) {
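Note (illustration, not part of this change): the registration above treats "cache_size" as a generic VALUE and feeds the raw text to ByteSizeValue.parseBytesSizeValue, so the field accepts the usual byte-size strings ("256mb", "1gb", ...). A minimal sketch of that parsing behaviour, assuming only the ByteSizeValue calls already used in this diff; the class name and sample value are made up:

    import org.elasticsearch.common.unit.ByteSizeValue;

    public class CacheSizeParseSketch {
        public static void main(String[] args) {
            // Same call the parser uses for the "cache_size" field above.
            ByteSizeValue cacheSize = ByteSizeValue.parseBytesSizeValue("256mb", "cache_size");
            System.out.println(cacheSize.getBytes());      // 268435456
            System.out.println(cacheSize.getStringRep());  // 256mb
        }
    }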
@@ -102,6 +112,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
        private String modelId;
        private TimeValue timeout = DEFAULT_TIMEOUT;
        private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
+        private ByteSizeValue cacheSize;
        private int numberOfAllocations = 1;
        private int threadsPerAllocation = 1;
        private int queueCapacity = 1024;
@@ -120,6 +131,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            numberOfAllocations = in.readVInt();
            threadsPerAllocation = in.readVInt();
            queueCapacity = in.readVInt();
+            if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
+                this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
+            }
        }

        public final void setModelId(String modelId) {
@@ -171,6 +185,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            this.queueCapacity = queueCapacity;
        }

+        public ByteSizeValue getCacheSize() {
+            return cacheSize;
+        }
+
+        public void setCacheSize(ByteSizeValue cacheSize) {
+            this.cacheSize = cacheSize;
+        }
+
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
@@ -180,6 +202,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            out.writeVInt(numberOfAllocations);
            out.writeVInt(threadsPerAllocation);
            out.writeVInt(queueCapacity);
+            if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
+                out.writeOptionalWriteable(cacheSize);
+            }
        }

        @Override
@@ -191,6 +216,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
            builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
+            if (cacheSize != null) {
+                builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
+            }
            builder.endObject();
            return builder;
        }
@@ -229,7 +257,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM

        @Override
        public int hashCode() {
-            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
+            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize);
        }

        @Override
@@ -244,6 +272,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            return Objects.equals(modelId, other.modelId)
                && Objects.equals(timeout, other.timeout)
                && Objects.equals(waitForState, other.waitForState)
+                && Objects.equals(cacheSize, other.cacheSize)
                && numberOfAllocations == other.numberOfAllocations
                && threadsPerAllocation == other.threadsPerAllocation
                && queueCapacity == other.queueCapacity;
@@ -273,11 +302,21 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
        // threads_per_allocation was previously named inference_threads
        public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
        public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
+        public static final ParseField CACHE_SIZE = new ParseField("cache_size");

        private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
            "trained_model_deployment_params",
            true,
-            a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
+            a -> new TaskParams(
+                (String) a[0],
+                (Long) a[1],
+                (Integer) a[2],
+                (Integer) a[3],
+                (int) a[4],
+                (ByteSizeValue) a[5],
+                (Integer) a[6],
+                (Integer) a[7]
+            )
        );

        static {
@@ -286,6 +325,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
            PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
+            PARSER.declareField(
+                optionalConstructorArg(),
+                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
+                CACHE_SIZE,
+                ObjectParser.ValueType.VALUE
+            );
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
        }
@@ -295,6 +340,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
        }

        private final String modelId;
+        private final ByteSizeValue cacheSize;
        private final long modelBytes;
        // How many threads are used by the model during inference. Used to increase inference speed.
        private final int threadsPerAllocation;
@@ -308,6 +354,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            Integer numberOfAllocations,
            Integer threadsPerAllocation,
            int queueCapacity,
+            ByteSizeValue cacheSizeValue,
            Integer legacyModelThreads,
            Integer legacyInferenceThreads
        ) {
@@ -316,16 +363,25 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
                modelBytes,
                threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
                numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
-                queueCapacity
+                queueCapacity,
+                cacheSizeValue
            );
        }

-        public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
+        public TaskParams(
+            String modelId,
+            long modelBytes,
+            int threadsPerAllocation,
+            int numberOfAllocations,
+            int queueCapacity,
+            @Nullable ByteSizeValue cacheSize
+        ) {
            this.modelId = Objects.requireNonNull(modelId);
            this.modelBytes = modelBytes;
            this.threadsPerAllocation = threadsPerAllocation;
            this.numberOfAllocations = numberOfAllocations;
            this.queueCapacity = queueCapacity;
+            this.cacheSize = cacheSize;
        }

        public TaskParams(StreamInput in) throws IOException {
@@ -334,6 +390,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            this.threadsPerAllocation = in.readVInt();
            this.numberOfAllocations = in.readVInt();
            this.queueCapacity = in.readVInt();
+            if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
+                this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
+            } else {
+                this.cacheSize = null;
+            }
        }

        public String getModelId() {
@@ -341,6 +402,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
        }

        public long estimateMemoryUsageBytes() {
+            // We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
+            // we need to take it into account when returning the estimate.
+            if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
+                return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
+            }
            return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
}
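Note (illustration, not part of this change): the estimate above only grows by the amount the configured cache exceeds the model size, because the existing estimate already accounts for 2x the model bytes. A rough sketch of that arithmetic, assuming the existing static estimateMemoryUsageBytes(modelBytes) helper behaves like "fixed native overhead + 2 * modelBytes" as the comment in the hunk implies; the overhead constant and sample sizes below are hypothetical:

    final class MemoryEstimateSketch {
        // Hypothetical stand-in for the fixed native overhead used by the real helper.
        private static final long NATIVE_OVERHEAD_BYTES = 240L * 1024 * 1024;

        // Stand-in for StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes).
        static long baseEstimate(long modelBytes) {
            return NATIVE_OVERHEAD_BYTES + 2 * modelBytes;
        }

        // Mirrors the new logic: only cache capacity beyond the model size adds to the estimate.
        static long estimate(long modelBytes, Long cacheSizeBytes) {
            if (cacheSizeBytes != null && cacheSizeBytes > modelBytes) {
                return baseEstimate(modelBytes) + (cacheSizeBytes - modelBytes);
            }
            return baseEstimate(modelBytes);
        }

        public static void main(String[] args) {
            long modelBytes = 1_000_000_000L;                          // 1 GB model
            System.out.println(estimate(modelBytes, null));            // overhead + 2 GB
            System.out.println(estimate(modelBytes, 1_500_000_000L));  // overhead + 2 GB + 0.5 GB
        }
    }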

@@ -355,6 +421,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            out.writeVInt(threadsPerAllocation);
            out.writeVInt(numberOfAllocations);
            out.writeVInt(queueCapacity);
+            if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
+                out.writeOptionalWriteable(cacheSize);
+            }
        }

        @Override
@@ -365,13 +434,16 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
            builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
+            if (cacheSize != null) {
+                builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
+            }
            builder.endObject();
            return builder;
        }

        @Override
        public int hashCode() {
-            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
+            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize);
        }

        @Override
@@ -384,6 +456,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
                && modelBytes == other.modelBytes
                && threadsPerAllocation == other.threadsPerAllocation
                && numberOfAllocations == other.numberOfAllocations
+                && Objects.equals(cacheSize, other.cacheSize)
                && queueCapacity == other.queueCapacity;
        }

@@ -408,6 +481,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
            return queueCapacity;
        }

+        public Optional<ByteSizeValue> getCacheSize() {
+            return Optional.ofNullable(cacheSize);
+        }
+
+        public long getCacheSizeBytes() {
+            return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes);
+        }
+
        @Override
        public String toString() {
return Strings.toString(this);
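Note (illustration, not part of this change): on TaskParams the cache size is optional, and getCacheSizeBytes() falls back to the model size when no cache_size was supplied, so the default cache is just large enough to hold the model itself. A small usage sketch of the new accessors built from the constructor added in this diff; the model id and sizes are made up, and the import path for the action class is assumed:

    import org.elasticsearch.common.unit.ByteSizeValue;
    import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;

    public class CacheSizeDefaultSketch {
        public static void main(String[] args) {
            long modelBytes = 100L * 1024 * 1024; // 100mb model

            // No cache_size given: getCacheSizeBytes() defaults to the model size.
            TaskParams noCache = new TaskParams("my-model", modelBytes, 1, 1, 1024, null);
            System.out.println(noCache.getCacheSize().isPresent()); // false
            System.out.println(noCache.getCacheSizeBytes());        // 104857600

            // Explicit cache_size: the configured value wins.
            TaskParams withCache = new TaskParams("my-model", modelBytes, 1, 1, 1024, ByteSizeValue.ofMb(256));
            System.out.println(withCache.getCacheSizeBytes());      // 268435456
        }
    }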