Parcourir la source

[ML] Add the number of queued requests per node to deployment stats (#80098)

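Adds the number of pending inference requests to the
node deployment stats as `number_of_pending_requests`.
The average inference time is now reported as null, rather than 0,
when no inferences have been performed on the node.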
David Kyle il y a 4 ans
Parent
commit
901eb5b28f

+ 14 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java

@@ -127,12 +127,14 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                 private final Long inferenceCount;
                 private final Double avgInferenceTime;
                 private final Instant lastAccess;
+                private final Integer pendingCount;
                 private final RoutingStateAndReason routingState;
 
                 public static NodeStats forStartedState(
                     DiscoveryNode node,
                     long inferenceCount,
-                    double avgInferenceTime,
+                    Double avgInferenceTime,
+                    int pendingCount,
                     Instant lastAccess
                 ) {
                     return new NodeStats(
@@ -140,12 +142,13 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                         inferenceCount,
                         avgInferenceTime,
                         lastAccess,
+                        pendingCount,
                         new RoutingStateAndReason(RoutingState.STARTED, null)
                     );
                 }
 
                 public static NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
-                    return new NodeStats(node, null, null, null, new RoutingStateAndReason(state, reason));
+                    return new NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason));
                 }
 
                 private NodeStats(
@@ -153,12 +156,14 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     Long inferenceCount,
                     Double avgInferenceTime,
                     Instant lastAccess,
+                    Integer pendingCount,
                     RoutingStateAndReason routingState
                 ) {
                     this.node = node;
                     this.inferenceCount = inferenceCount;
                     this.avgInferenceTime = avgInferenceTime;
                     this.lastAccess = lastAccess;
+                    this.pendingCount = pendingCount;
                     this.routingState = routingState;
 
                     // if lastAccess time is null there have been no inferences
@@ -170,6 +175,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     this.inferenceCount = in.readOptionalLong();
                     this.avgInferenceTime = in.readOptionalDouble();
                     this.lastAccess = in.readOptionalInstant();
+                    this.pendingCount = in.readOptionalVInt();
                     this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
                 }
 
@@ -199,6 +205,9 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     if (lastAccess != null) {
                         builder.timeField("last_access", "last_access_string", lastAccess.toEpochMilli());
                     }
+                    if (pendingCount != null) {
+                        builder.field("number_of_pending_requests", pendingCount);
+                    }
                     builder.endObject();
                     return builder;
                 }
@@ -209,6 +218,7 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     out.writeOptionalLong(inferenceCount);
                     out.writeOptionalDouble(avgInferenceTime);
                     out.writeOptionalInstant(lastAccess);
+                    out.writeOptionalVInt(pendingCount);
                     out.writeOptionalWriteable(routingState);
                 }
 
@@ -221,12 +231,13 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                         && Objects.equals(that.avgInferenceTime, avgInferenceTime)
                         && Objects.equals(node, that.node)
                         && Objects.equals(lastAccess, that.lastAccess)
+                        && Objects.equals(pendingCount, that.pendingCount)
                         && Objects.equals(routingState, that.routingState);
                 }
 
                 @Override
                 public int hashCode() {
-                    return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, routingState);
+                    return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState);
                 }
             }
 

+ 7 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsActionResponseTests.java

@@ -42,7 +42,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
         for (var i = 0; i < numStats; i++) {
             stats.add(randomDeploymentStats());
         }
-        stats.sort(Comparator.comparing(s -> s.getModelId()));
+        stats.sort(Comparator.comparing(GetDeploymentStatsAction.Response.AllocationStats::getModelId));
         return new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), stats, stats.size());
     }
 
@@ -90,6 +90,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
                 nodes.get("node1"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
+                randomIntBetween(0, 100),
                 Instant.now()
             )
         );
@@ -98,6 +99,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
                 nodes.get("node2"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
+                randomIntBetween(0, 100),
                 Instant.now()
             )
         );
@@ -139,6 +141,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
                 nodes.get("node1"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
+                randomIntBetween(0, 100),
                 Instant.now()
             )
         );
@@ -147,6 +150,7 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
                 nodes.get("node2"),
                 randomNonNegativeLong(),
                 randomDoubleBetween(0.0, 100.0, true),
+                randomIntBetween(0, 100),
                 Instant.now()
             )
         );
@@ -198,7 +202,8 @@ public class GetDeploymentStatsActionResponseTests extends AbstractWireSerializi
                     GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
                         node,
                         randomNonNegativeLong(),
-                        randomDoubleBetween(0.0, 100.0, true),
+                        randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
+                        randomIntBetween(0, 100),
                         Instant.now()
                     )
                 );

+ 5 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
 
 /**
@@ -278,6 +279,10 @@ public class PyTorchModelIT extends ESRestTestCase {
         // 2 of the 3 nodes in the cluster are ML nodes
         assertThat(nodes, hasSize(2));
         int inferenceCount = sumInferenceCountOnNodes(nodes);
+        for (var node : nodes) {
+            assertThat(node.get("number_of_pending_requests"), notNullValue());
+            // last_access and average_inference_time_ms may be null if inference wasn't performed on this node
+        }
         assertThat(inferenceCount, equalTo(2));
     }
 

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -189,7 +189,9 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
                     clusterService.localNode(),
                     stats.get().getTimingStats().getCount(),
-                    stats.get().getTimingStats().getAverage(),
+                    // avoid reporting the average time as 0 if count < 1
+                    (stats.get().getTimingStats().getCount() > 0) ? stats.get().getTimingStats().getAverage() : null,
+                    stats.get().getPendingCount(),
                     stats.get().getLastUsed()
                 )
             );

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -103,7 +103,8 @@ public class DeploymentManager {
             .map(
                 processContext -> new ModelStats(
                     processContext.getResultProcessor().getTimingStats(),
-                    processContext.getResultProcessor().getLastUsed()
+                    processContext.getResultProcessor().getLastUsed(),
+                    processContext.executorService.queueSize() + processContext.getResultProcessor().numberOfPendingResults()
                 )
             );
     }

+ 7 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

@@ -14,10 +14,12 @@ public class ModelStats {
 
     private final LongSummaryStatistics timingStats;
     private final Instant lastUsed;
+    private final int pendingCount;
 
-    ModelStats(LongSummaryStatistics timingStats, Instant lastUsed) {
+    ModelStats(LongSummaryStatistics timingStats, Instant lastUsed, int pendingCount) {
         this.timingStats = timingStats;
         this.lastUsed = lastUsed;
+        this.pendingCount = pendingCount;
     }
 
     public LongSummaryStatistics getTimingStats() {
@@ -27,4 +29,8 @@ public class ModelStats {
     public Instant getLastUsed() {
         return lastUsed;
     }
+
+    public int getPendingCount() {
+        return pendingCount;
+    }
 }

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -108,6 +108,10 @@ public class PyTorchResultProcessor {
         return lastUsed;
     }
 
+    public int numberOfPendingResults() {
+        return pendingResults.size();
+    }
+
     public void stop() {
         isStopping = true;
     }

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java

@@ -55,6 +55,10 @@ public class ProcessWorkerExecutorService extends AbstractExecutorService {
         this.queue = new LinkedBlockingQueue<>(queueCapacity);
     }
 
+    public int queueSize() {
+        return queue.size();
+    }
+
     @Override
     public void shutdown() {
         running = false;