|
@@ -18,14 +18,14 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.common.io.stream.Writeable;
|
|
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
|
|
+import org.elasticsearch.core.Nullable;
|
|
|
+import org.elasticsearch.tasks.Task;
|
|
|
import org.elasticsearch.xcontent.ParseField;
|
|
|
import org.elasticsearch.xcontent.ToXContentObject;
|
|
|
import org.elasticsearch.xcontent.XContentBuilder;
|
|
|
-import org.elasticsearch.core.Nullable;
|
|
|
-import org.elasticsearch.tasks.Task;
|
|
|
import org.elasticsearch.xpack.core.action.util.QueryPage;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
@@ -233,6 +233,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
@Nullable private final ByteSizeValue modelSize;
|
|
|
@Nullable private final Integer inferenceThreads;
|
|
|
@Nullable private final Integer modelThreads;
|
|
|
+ @Nullable private final Integer queueCapacity;
|
|
|
private final List<NodeStats> nodeStats;
|
|
|
|
|
|
public AllocationStats(
|
|
@@ -240,12 +241,14 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
@Nullable ByteSizeValue modelSize,
|
|
|
@Nullable Integer inferenceThreads,
|
|
|
@Nullable Integer modelThreads,
|
|
|
+ @Nullable Integer queueCapacity,
|
|
|
List<NodeStats> nodeStats
|
|
|
) {
|
|
|
this.modelId = modelId;
|
|
|
this.modelSize = modelSize;
|
|
|
this.inferenceThreads = inferenceThreads;
|
|
|
this.modelThreads = modelThreads;
|
|
|
+ this.queueCapacity = queueCapacity;
|
|
|
this.nodeStats = nodeStats;
|
|
|
this.state = null;
|
|
|
this.reason = null;
|
|
@@ -256,6 +259,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
modelSize = in.readOptionalWriteable(ByteSizeValue::new);
|
|
|
inferenceThreads = in.readOptionalVInt();
|
|
|
modelThreads = in.readOptionalVInt();
|
|
|
+ queueCapacity = in.readOptionalVInt();
|
|
|
nodeStats = in.readList(NodeStats::new);
|
|
|
state = in.readOptionalEnum(AllocationState.class);
|
|
|
reason = in.readOptionalString();
|
|
@@ -280,6 +284,11 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
return modelThreads;
|
|
|
}
|
|
|
|
|
|
+ @Nullable
|
|
|
+ public Integer getQueueCapacity() {
|
|
|
+ return queueCapacity;
|
|
|
+ }
|
|
|
+
|
|
|
public List<NodeStats> getNodeStats() {
|
|
|
return nodeStats;
|
|
|
}
|
|
@@ -320,6 +329,9 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
if (modelThreads != null) {
|
|
|
builder.field(StartTrainedModelDeploymentAction.TaskParams.MODEL_THREADS.getPreferredName(), modelThreads);
|
|
|
}
|
|
|
+ if (queueCapacity != null) {
|
|
|
+ builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity);
|
|
|
+ }
|
|
|
if (state != null) {
|
|
|
builder.field("state", state);
|
|
|
}
|
|
@@ -344,6 +356,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
out.writeOptionalWriteable(modelSize);
|
|
|
out.writeOptionalVInt(inferenceThreads);
|
|
|
out.writeOptionalVInt(modelThreads);
|
|
|
+ out.writeOptionalVInt(queueCapacity);
|
|
|
out.writeList(nodeStats);
|
|
|
out.writeOptionalEnum(state);
|
|
|
out.writeOptionalString(reason);
|
|
@@ -359,6 +372,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
Objects.equals(modelSize, that.modelSize) &&
|
|
|
Objects.equals(inferenceThreads, that.inferenceThreads) &&
|
|
|
Objects.equals(modelThreads, that.modelThreads) &&
|
|
|
+ Objects.equals(queueCapacity, that.queueCapacity) &&
|
|
|
Objects.equals(state, that.state) &&
|
|
|
Objects.equals(reason, that.reason) &&
|
|
|
Objects.equals(allocationStatus, that.allocationStatus) &&
|
|
@@ -367,7 +381,8 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(modelId, modelSize, inferenceThreads, modelThreads, nodeStats, state, reason, allocationStatus);
|
|
|
+ return Objects.hash(modelId, modelSize, inferenceThreads, modelThreads, queueCapacity, nodeStats, state, reason,
|
|
|
+ allocationStatus);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -482,6 +497,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
stat.getModelSize(),
|
|
|
stat.getInferenceThreads(),
|
|
|
stat.getModelThreads(),
|
|
|
+ stat.getQueueCapacity(),
|
|
|
updatedNodeStats
|
|
|
)
|
|
|
);
|
|
@@ -510,7 +526,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
|
|
|
nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
|
|
|
|
|
|
updatedAllocationStats.add(new GetDeploymentStatsAction.Response.AllocationStats(
|
|
|
- modelId, null, null, null, nodeStats)
|
|
|
+ modelId, null, null, null, null, nodeStats)
|
|
|
);
|
|
|
}
|
|
|
}
|