|
@@ -69,8 +69,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
public static final ParseField MODEL_ID = new ParseField("model_id");
|
|
|
public static final ParseField TIMEOUT = new ParseField("timeout");
|
|
|
public static final ParseField WAIT_FOR = new ParseField("wait_for");
|
|
|
- public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
|
|
|
- public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
|
|
|
+ public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
|
|
|
+ public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
|
|
|
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
|
|
|
|
|
|
public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
|
|
@@ -79,8 +79,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
PARSER.declareString(Request::setModelId, MODEL_ID);
|
|
|
PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT);
|
|
|
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
|
|
|
- PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
|
|
|
- PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
|
|
|
+ PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
|
|
|
+ PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
|
|
|
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
|
|
|
}
|
|
|
|
|
@@ -99,8 +99,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
private String modelId;
|
|
|
private TimeValue timeout = DEFAULT_TIMEOUT;
|
|
|
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
|
|
|
- private int modelThreads = 1;
|
|
|
- private int inferenceThreads = 1;
|
|
|
+ private int numberOfAllocations = 1;
|
|
|
+ private int threadsPerAllocation = 1;
|
|
|
private int queueCapacity = 1024;
|
|
|
|
|
|
private Request() {}
|
|
@@ -114,8 +114,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
modelId = in.readString();
|
|
|
timeout = in.readTimeValue();
|
|
|
waitForState = in.readEnum(AllocationStatus.State.class);
|
|
|
- modelThreads = in.readVInt();
|
|
|
- inferenceThreads = in.readVInt();
|
|
|
+ numberOfAllocations = in.readVInt();
|
|
|
+ threadsPerAllocation = in.readVInt();
|
|
|
queueCapacity = in.readVInt();
|
|
|
}
|
|
|
|
|
@@ -144,20 +144,20 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
return this;
|
|
|
}
|
|
|
|
|
|
- public int getModelThreads() {
|
|
|
- return modelThreads;
|
|
|
+ public int getNumberOfAllocations() {
|
|
|
+ return numberOfAllocations;
|
|
|
}
|
|
|
|
|
|
- public void setModelThreads(int modelThreads) {
|
|
|
- this.modelThreads = modelThreads;
|
|
|
+ public void setNumberOfAllocations(int numberOfAllocations) {
|
|
|
+ this.numberOfAllocations = numberOfAllocations;
|
|
|
}
|
|
|
|
|
|
- public int getInferenceThreads() {
|
|
|
- return inferenceThreads;
|
|
|
+ public int getThreadsPerAllocation() {
|
|
|
+ return threadsPerAllocation;
|
|
|
}
|
|
|
|
|
|
- public void setInferenceThreads(int inferenceThreads) {
|
|
|
- this.inferenceThreads = inferenceThreads;
|
|
|
+ public void setThreadsPerAllocation(int threadsPerAllocation) {
|
|
|
+ this.threadsPerAllocation = threadsPerAllocation;
|
|
|
}
|
|
|
|
|
|
public int getQueueCapacity() {
|
|
@@ -174,8 +174,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
out.writeString(modelId);
|
|
|
out.writeTimeValue(timeout);
|
|
|
out.writeEnum(waitForState);
|
|
|
- out.writeVInt(modelThreads);
|
|
|
- out.writeVInt(inferenceThreads);
|
|
|
+ out.writeVInt(numberOfAllocations);
|
|
|
+ out.writeVInt(threadsPerAllocation);
|
|
|
out.writeVInt(queueCapacity);
|
|
|
}
|
|
|
|
|
@@ -185,8 +185,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
builder.field(MODEL_ID.getPreferredName(), modelId);
|
|
|
builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep());
|
|
|
builder.field(WAIT_FOR.getPreferredName(), waitForState);
|
|
|
- builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
|
|
|
- builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
|
|
|
+ builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
|
|
|
+ builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
|
|
|
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
@@ -203,11 +203,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
+ Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES)
|
|
|
);
|
|
|
}
|
|
|
- if (modelThreads < 1) {
|
|
|
- validationException.addValidationError("[" + MODEL_THREADS + "] must be a positive integer");
|
|
|
+ if (numberOfAllocations < 1) {
|
|
|
+ validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
|
|
|
}
|
|
|
- if (inferenceThreads < 1) {
|
|
|
- validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
|
|
|
+ if (threadsPerAllocation < 1) {
|
|
|
+ validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer");
|
|
|
}
|
|
|
if (queueCapacity < 1) {
|
|
|
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
|
|
@@ -217,7 +217,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
|
|
|
+ return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -232,8 +232,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
return Objects.equals(modelId, other.modelId)
|
|
|
&& Objects.equals(timeout, other.timeout)
|
|
|
&& Objects.equals(waitForState, other.waitForState)
|
|
|
- && modelThreads == other.modelThreads
|
|
|
- && inferenceThreads == other.inferenceThreads
|
|
|
+ && numberOfAllocations == other.numberOfAllocations
|
|
|
+ && threadsPerAllocation == other.threadsPerAllocation
|
|
|
&& queueCapacity == other.queueCapacity;
|
|
|
}
|
|
|
|
|
@@ -254,22 +254,28 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
|
|
|
public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
|
|
|
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
|
|
|
- public static final ParseField MODEL_THREADS = new ParseField("model_threads");
|
|
|
- public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
|
|
|
+ public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations");
|
|
|
+ public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation");
|
|
|
+ // number_of_allocations was previously named model_threads
|
|
|
+ private static final ParseField LEGACY_MODEL_THREADS = new ParseField("model_threads");
|
|
|
+ // threads_per_allocation was previously named inference_threads
|
|
|
+ public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
|
|
|
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
|
|
|
|
|
|
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
|
|
|
"trained_model_deployment_params",
|
|
|
true,
|
|
|
- a -> new TaskParams((String) a[0], (Long) a[1], (int) a[2], (int) a[3], (int) a[4])
|
|
|
+ a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
|
|
|
);
|
|
|
|
|
|
static {
|
|
|
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
|
|
|
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
|
|
|
- PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
|
|
|
- PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
|
|
|
+ PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
|
|
|
+ PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
|
|
|
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
|
|
|
+ PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
|
|
|
+ PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
|
|
|
}
|
|
|
|
|
|
public static TaskParams fromXContent(XContentParser parser) {
|
|
@@ -279,24 +285,42 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
private final String modelId;
|
|
|
private final long modelBytes;
|
|
|
// How many threads are used by the model during inference. Used to increase inference speed.
|
|
|
- private final int inferenceThreads;
|
|
|
+ private final int threadsPerAllocation;
|
|
|
// How many threads are used when forwarding the request to the model. Used to increase throughput.
|
|
|
- private final int modelThreads;
|
|
|
+ private final int numberOfAllocations;
|
|
|
private final int queueCapacity;
|
|
|
|
|
|
- public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
|
|
|
+ private TaskParams(
|
|
|
+ String modelId,
|
|
|
+ long modelBytes,
|
|
|
+ Integer numberOfAllocations,
|
|
|
+ Integer threadsPerAllocation,
|
|
|
+ int queueCapacity,
|
|
|
+ Integer legacyModelThreads,
|
|
|
+ Integer legacyInferenceThreads
|
|
|
+ ) {
|
|
|
+ this(
|
|
|
+ modelId,
|
|
|
+ modelBytes,
|
|
|
+ threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
|
|
|
+ numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
|
|
|
+ queueCapacity
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
|
|
|
this.modelId = Objects.requireNonNull(modelId);
|
|
|
this.modelBytes = modelBytes;
|
|
|
- this.inferenceThreads = inferenceThreads;
|
|
|
- this.modelThreads = modelThreads;
|
|
|
+ this.threadsPerAllocation = threadsPerAllocation;
|
|
|
+ this.numberOfAllocations = numberOfAllocations;
|
|
|
this.queueCapacity = queueCapacity;
|
|
|
}
|
|
|
|
|
|
public TaskParams(StreamInput in) throws IOException {
|
|
|
this.modelId = in.readString();
|
|
|
this.modelBytes = in.readLong();
|
|
|
- this.inferenceThreads = in.readVInt();
|
|
|
- this.modelThreads = in.readVInt();
|
|
|
+ this.threadsPerAllocation = in.readVInt();
|
|
|
+ this.numberOfAllocations = in.readVInt();
|
|
|
this.queueCapacity = in.readVInt();
|
|
|
}
|
|
|
|
|
@@ -316,8 +340,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
out.writeString(modelId);
|
|
|
out.writeLong(modelBytes);
|
|
|
- out.writeVInt(inferenceThreads);
|
|
|
- out.writeVInt(modelThreads);
|
|
|
+ out.writeVInt(threadsPerAllocation);
|
|
|
+ out.writeVInt(numberOfAllocations);
|
|
|
out.writeVInt(queueCapacity);
|
|
|
}
|
|
|
|
|
@@ -326,8 +350,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
builder.startObject();
|
|
|
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
|
|
|
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
|
|
|
- builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
|
|
|
- builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
|
|
|
+ builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
|
|
|
+ builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
|
|
|
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
@@ -335,7 +359,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
|
|
|
+ return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -346,8 +370,8 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
TaskParams other = (TaskParams) o;
|
|
|
return Objects.equals(modelId, other.modelId)
|
|
|
&& modelBytes == other.modelBytes
|
|
|
- && inferenceThreads == other.inferenceThreads
|
|
|
- && modelThreads == other.modelThreads
|
|
|
+ && threadsPerAllocation == other.threadsPerAllocation
|
|
|
+ && numberOfAllocations == other.numberOfAllocations
|
|
|
&& queueCapacity == other.queueCapacity;
|
|
|
}
|
|
|
|
|
@@ -360,12 +384,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
|
|
|
return modelBytes;
|
|
|
}
|
|
|
|
|
|
- public int getInferenceThreads() {
|
|
|
- return inferenceThreads;
|
|
|
+ public int getThreadsPerAllocation() {
|
|
|
+ return threadsPerAllocation;
|
|
|
}
|
|
|
|
|
|
- public int getModelThreads() {
|
|
|
- return modelThreads;
|
|
|
+ public int getNumberOfAllocations() {
|
|
|
+ return numberOfAllocations;
|
|
|
}
|
|
|
|
|
|
public int getQueueCapacity() {
|