1
0
Эх сурвалжийг харах

[ML] Fix missing deployment stats peak throughput field #85436

In an edge case peak_throughput_per_minute was not being returned 
even if the stat could be calculated for the last bucket
David Kyle 3 жил өмнө
parent
commit
7a22f39aae

+ 7 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -328,14 +328,20 @@ public class PyTorchModelIT extends ESRestTestCase {
             assertThat(nodes, hasSize(2));
             for (var node : nodes) {
                 assertThat(node.get("number_of_pending_requests"), notNullValue());
-                // last_access and average_inference_time_ms may be null if inference wasn't performed on this node
             }
+            // last_access and average_inference_time_ms may be null if inference wasn't performed on this node
+            assertAtLeastOneOfTheseIsNotNull("last_access", nodes);
+            assertAtLeastOneOfTheseIsNotNull("average_inference_time_ms", nodes);
 
             int inferenceCount = sumInferenceCountOnNodes(nodes);
             assertThat(inferenceCount, equalTo(2));
         }
     }
 
+    private void assertAtLeastOneOfTheseIsNotNull(String name, List<Map<String, Object>> nodes) {
+        assertTrue("all nodes have null value for [" + name + "]", nodes.stream().anyMatch(n -> n.get(name) != null));
+    }
+
     @SuppressWarnings("unchecked")
     public void testGetDeploymentStats_WithWildcard() throws IOException {
         String modelFoo = "foo";

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -157,6 +157,7 @@ public class PyTorchResultProcessor {
             // in this period to close off the last period stats.
             // The stats are valid return them here
             rs = new RecentStats(lastPeriodSummaryStats.getCount(), lastPeriodSummaryStats.getAverage());
+            peakThroughput = Math.max(peakThroughput, lastPeriodSummaryStats.getCount());
         }
 
         if (rs == null) {

+ 7 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java

@@ -26,6 +26,7 @@ import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.comparesEqualTo;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -150,7 +151,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.timingStats().getSum(), comparesEqualTo(2100L));
     }
 
-    public void testsRecentStats() {
+    public void testsTimeDependentStats() {
 
         long start = System.currentTimeMillis();
         // the first value is used in the ctor to set the start time.
@@ -211,12 +212,14 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         // first call has no results as is in the same period
         var stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
+        assertThat(stats.recentStats().avgInferenceTime(), nullValue());
         // 2nd time in the next period
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));
         assertThat(stats.recentStats().avgInferenceTime(), closeTo(200.0, 0.00001));
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[3])));
+        assertThat(stats.peakThroughput(), equalTo(3L));
 
         // 2nd period
         processor.processInferenceResult(new PyTorchInferenceResult("foo", null, 100L, null));
@@ -225,6 +228,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
         assertThat(stats.recentStats().avgInferenceTime(), closeTo(100.0, 0.00001));
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[6])));
+        assertThat(stats.peakThroughput(), equalTo(3L));
 
         stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
@@ -242,6 +246,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         processor.processInferenceResult(new PyTorchInferenceResult("foo", null, 390L, null));
         stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
+        assertThat(stats.recentStats().avgInferenceTime(), nullValue());
         stats = processor.getResultStats(); // called in the next period
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(2L));
@@ -257,6 +262,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));
         assertThat(stats.recentStats().avgInferenceTime(), closeTo(500.0, 0.00001));
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[17])));
+        assertThat(stats.peakThroughput(), equalTo(3L));
     }
 
     private static class TimeSupplier implements LongSupplier {