Kaynağa Gözat

[ML] Add queue_capacity to deployment stats (#79905)

Adds the `queue_capacity` field to the response of the
get trained model deployment stats API. The field is
omitted from the response when its value is null.
Dimitris Athanasiou 4 yıl önce
ebeveyn
işleme
ee97aa3016

+ 21 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java

@@ -18,14 +18,14 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
-import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -233,6 +233,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
             @Nullable private final ByteSizeValue modelSize;
             @Nullable private final Integer inferenceThreads;
             @Nullable private final Integer modelThreads;
+            @Nullable private final Integer queueCapacity;
             private final List<NodeStats> nodeStats;
 
             public AllocationStats(
@@ -240,12 +241,14 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 @Nullable ByteSizeValue modelSize,
                 @Nullable Integer inferenceThreads,
                 @Nullable Integer modelThreads,
+                @Nullable Integer queueCapacity,
                 List<NodeStats> nodeStats
             ) {
                 this.modelId = modelId;
                 this.modelSize = modelSize;
                 this.inferenceThreads = inferenceThreads;
                 this.modelThreads = modelThreads;
+                this.queueCapacity = queueCapacity;
                 this.nodeStats = nodeStats;
                 this.state = null;
                 this.reason = null;
@@ -256,6 +259,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 modelSize = in.readOptionalWriteable(ByteSizeValue::new);
                 inferenceThreads = in.readOptionalVInt();
                 modelThreads = in.readOptionalVInt();
+                queueCapacity = in.readOptionalVInt();
                 nodeStats = in.readList(NodeStats::new);
                 state = in.readOptionalEnum(AllocationState.class);
                 reason = in.readOptionalString();
@@ -280,6 +284,11 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 return modelThreads;
             }
 
+            @Nullable
+            public Integer getQueueCapacity() {
+                return queueCapacity;
+            }
+
             public List<NodeStats> getNodeStats() {
                 return nodeStats;
             }
@@ -320,6 +329,9 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 if (modelThreads != null) {
                     builder.field(StartTrainedModelDeploymentAction.TaskParams.MODEL_THREADS.getPreferredName(), modelThreads);
                 }
+                if (queueCapacity != null) {
+                    builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity);
+                }
                 if (state != null) {
                     builder.field("state", state);
                 }
@@ -344,6 +356,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 out.writeOptionalWriteable(modelSize);
                 out.writeOptionalVInt(inferenceThreads);
                 out.writeOptionalVInt(modelThreads);
+                out.writeOptionalVInt(queueCapacity);
                 out.writeList(nodeStats);
                 out.writeOptionalEnum(state);
                 out.writeOptionalString(reason);
@@ -359,6 +372,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     Objects.equals(modelSize, that.modelSize) &&
                     Objects.equals(inferenceThreads, that.inferenceThreads) &&
                     Objects.equals(modelThreads, that.modelThreads) &&
+                    Objects.equals(queueCapacity, that.queueCapacity) &&
                     Objects.equals(state, that.state) &&
                     Objects.equals(reason, that.reason) &&
                     Objects.equals(allocationStatus, that.allocationStatus) &&
@@ -367,7 +381,8 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
 
             @Override
             public int hashCode() {
-                return Objects.hash(modelId, modelSize, inferenceThreads, modelThreads, nodeStats, state, reason, allocationStatus);
+                return Objects.hash(modelId, modelSize, inferenceThreads, modelThreads, queueCapacity, nodeStats, state, reason,
+                    allocationStatus);
             }
         }
 
@@ -482,6 +497,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                             stat.getModelSize(),
                             stat.getInferenceThreads(),
                             stat.getModelThreads(),
+                            stat.getQueueCapacity(),
                             updatedNodeStats
                         )
                     );
@@ -510,7 +526,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
 
                     updatedAllocationStats.add(new GetDeploymentStatsAction.Response.AllocationStats(
-                        modelId, null, null, null, nodeStats)
+                        modelId, null, null, null, null, nodeStats)
                     );
                 }
             }

+ 3 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsActionResponseTests.java

@@ -101,6 +101,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
             ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 10000),
             nodeStatsList);
 
         Map<String, Map<String, RoutingStateAndReason>> badRoutes = new HashMap<>();
@@ -145,6 +146,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
             ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 10000),
             nodeStatsList);
         var response = new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(),
             List.of(model1), 1);
@@ -202,6 +204,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
             ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
+            randomBoolean() ? null : randomIntBetween(1, 10000),
             nodeStatsList);
     }
 }

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -185,6 +185,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<Trai
             ByteSizeValue.ofBytes(task.getParams().getModelBytes()),
             task.getParams().getInferenceThreads(),
             task.getParams().getModelThreads(),
+            task.getParams().getQueueCapacity(),
             nodeStats)
         );
     }