Browse Source

[ML] Add stat for non cache hit inference time (#90464)

David Kyle 3 years ago
parent
commit
17579ae1af

+ 5 - 0
docs/changelog/90464.yaml

@@ -0,0 +1,5 @@
+pr: 90464
+summary: Add measure of non cache hit inference count
+area: Machine Learning
+type: enhancement
+issues: []

+ 8 - 0
docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc

@@ -140,6 +140,14 @@ The deployment stats for each node that currently has the model allocated.
 The average time for each inference call to complete on this node.
 The average is calculated over the lifetime of the deployment.
 
+`average_inference_time_ms_excluding_cache_hits`:::
+(double)
+The average time to perform inference on the trained model excluding
+occasions where the response comes from the cache. Cached inference
+calls return very quickly as the model is not evaluated, by excluding
+cache hits this value is an accurate measure of the average time taken
+to evaluate the model.
+
 `average_inference_time_ms_last_minute`:::
 (double)
 The average time for each inference call to complete on this node

+ 28 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java

@@ -32,6 +32,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         private final DiscoveryNode node;
         private final Long inferenceCount;
         private final Double avgInferenceTime;
+        private final Double avgInferenceTimeExcludingCacheHit;
         private final Instant lastAccess;
         private final Integer pendingCount;
         private final int errorCount;
@@ -51,6 +52,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             DiscoveryNode node,
             long inferenceCount,
             Double avgInferenceTime,
+            Double avgInferenceTimeExcludingCacheHit,
             int pendingCount,
             int errorCount,
             long cacheHitCount,
@@ -69,6 +71,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 node,
                 inferenceCount,
                 avgInferenceTime,
+                avgInferenceTimeExcludingCacheHit,
                 lastAccess,
                 pendingCount,
                 errorCount,
@@ -93,6 +96,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 null,
                 null,
                 null,
+                null,
                 0,
                 null,
                 0,
@@ -112,6 +116,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             DiscoveryNode node,
             Long inferenceCount,
             Double avgInferenceTime,
+            Double avgInferenceTimeExcludingCacheHit,
             @Nullable Instant lastAccess,
             Integer pendingCount,
             int errorCount,
@@ -130,6 +135,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             this.node = node;
             this.inferenceCount = inferenceCount;
             this.avgInferenceTime = avgInferenceTime;
+            this.avgInferenceTimeExcludingCacheHit = avgInferenceTimeExcludingCacheHit;
             this.lastAccess = lastAccess;
             this.pendingCount = pendingCount;
             this.errorCount = errorCount;
@@ -186,6 +192,12 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 this.cacheHitCount = null;
                 this.cacheHitCountLastPeriod = null;
             }
+            if (in.getVersion().onOrAfter(Version.V_8_5_0)) {
+                this.avgInferenceTimeExcludingCacheHit = in.readOptionalDouble();
+            } else {
+                this.avgInferenceTimeExcludingCacheHit = null;
+            }
+
         }
 
         public DiscoveryNode getNode() {
@@ -204,6 +216,10 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             return Optional.ofNullable(avgInferenceTime);
         }
 
+        public Optional<Double> getAvgInferenceTimeExcludingCacheHit() {
+            return Optional.ofNullable(avgInferenceTimeExcludingCacheHit);
+        }
+
         public Instant getLastAccess() {
             return lastAccess;
         }
@@ -269,8 +285,13 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 builder.field("inference_count", inferenceCount);
             }
             // avoid reporting the average time as 0 if count < 1
-            if (avgInferenceTime != null && (inferenceCount != null && inferenceCount > 0)) {
-                builder.field("average_inference_time_ms", avgInferenceTime);
+            if (inferenceCount != null && inferenceCount > 0) {
+                if (avgInferenceTime != null) {
+                    builder.field("average_inference_time_ms", avgInferenceTime);
+                }
+                if (avgInferenceTimeExcludingCacheHit != null) {
+                    builder.field("average_inference_time_ms_excluding_cache_hits", avgInferenceTimeExcludingCacheHit);
+                }
             }
             if (cacheHitCount != null) {
                 builder.field("inference_cache_hit_count", cacheHitCount);
@@ -337,6 +358,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 out.writeOptionalVLong(cacheHitCount);
                 out.writeOptionalVLong(cacheHitCountLastPeriod);
             }
+            if (out.getVersion().onOrAfter(Version.V_8_5_0)) {
+                out.writeOptionalDouble(avgInferenceTimeExcludingCacheHit);
+            }
         }
 
         @Override
@@ -346,6 +370,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             AssignmentStats.NodeStats that = (AssignmentStats.NodeStats) o;
             return Objects.equals(inferenceCount, that.inferenceCount)
                 && Objects.equals(that.avgInferenceTime, avgInferenceTime)
+                && Objects.equals(that.avgInferenceTimeExcludingCacheHit, avgInferenceTimeExcludingCacheHit)
                 && Objects.equals(node, that.node)
                 && Objects.equals(lastAccess, that.lastAccess)
                 && Objects.equals(pendingCount, that.pendingCount)
@@ -369,6 +394,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 node,
                 inferenceCount,
                 avgInferenceTime,
+                avgInferenceTimeExcludingCacheHit,
                 lastAccess,
                 pendingCount,
                 errorCount,

+ 61 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -28,7 +28,8 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
 
     @Override
     protected Response createTestInstance() {
-        int listSize = randomInt(10);
+        // int listSize = randomInt(10);
+        int listSize = 1;
         List<Response.TrainedModelStats> trainedModelStats = Stream.generate(() -> randomAlphaOfLength(10))
             .limit(listSize)
             .map(
@@ -123,6 +124,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getNode(),
                                                     nodeStats.getInferenceCount().orElse(null),
                                                     nodeStats.getAvgInferenceTime().orElse(null),
+                                                    null,
                                                     nodeStats.getLastAccess(),
                                                     nodeStats.getPendingCount(),
                                                     0,
@@ -178,6 +180,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getNode(),
                                                     nodeStats.getInferenceCount().orElse(null),
                                                     nodeStats.getAvgInferenceTime().orElse(null),
+                                                    null,
                                                     nodeStats.getLastAccess(),
                                                     nodeStats.getPendingCount(),
                                                     nodeStats.getErrorCount(),
@@ -233,6 +236,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getNode(),
                                                     nodeStats.getInferenceCount().orElse(null),
                                                     nodeStats.getAvgInferenceTime().orElse(null),
+                                                    null,
                                                     nodeStats.getLastAccess(),
                                                     nodeStats.getPendingCount(),
                                                     nodeStats.getErrorCount(),
@@ -258,6 +262,62 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                     RESULTS_FIELD
                 )
             );
+        } else if (version.before(Version.V_8_5_0)) {
+            return new Response(
+                new QueryPage<>(
+                    instance.getResources()
+                        .results()
+                        .stream()
+                        .map(
+                            stats -> new Response.TrainedModelStats(
+                                stats.getModelId(),
+                                stats.getModelSizeStats(),
+                                stats.getIngestStats(),
+                                stats.getPipelineCount(),
+                                stats.getInferenceStats(),
+                                stats.getDeploymentStats() == null
+                                    ? null
+                                    : new AssignmentStats(
+                                        stats.getDeploymentStats().getModelId(),
+                                        stats.getDeploymentStats().getThreadsPerAllocation(),
+                                        stats.getDeploymentStats().getNumberOfAllocations(),
+                                        stats.getDeploymentStats().getQueueCapacity(),
+                                        stats.getDeploymentStats().getCacheSize(),
+                                        stats.getDeploymentStats().getStartTime(),
+                                        stats.getDeploymentStats()
+                                            .getNodeStats()
+                                            .stream()
+                                            .map(
+                                                nodeStats -> new AssignmentStats.NodeStats(
+                                                    nodeStats.getNode(),
+                                                    nodeStats.getInferenceCount().orElse(null),
+                                                    nodeStats.getAvgInferenceTime().orElse(null),
+                                                    null,
+                                                    nodeStats.getLastAccess(),
+                                                    nodeStats.getPendingCount(),
+                                                    nodeStats.getErrorCount(),
+                                                    nodeStats.getCacheHitCount().orElse(null),
+                                                    nodeStats.getRejectedExecutionCount(),
+                                                    nodeStats.getTimeoutCount(),
+                                                    nodeStats.getRoutingState(),
+                                                    nodeStats.getStartTime(),
+                                                    nodeStats.getThreadsPerAllocation(),
+                                                    nodeStats.getNumberOfAllocations(),
+                                                    nodeStats.getPeakThroughput(),
+                                                    nodeStats.getThroughputLastPeriod(),
+                                                    nodeStats.getAvgInferenceTimeLastPeriod(),
+                                                    nodeStats.getCacheHitCountLastPeriod().orElse(null)
+                                                )
+                                            )
+                                            .toList()
+                                    )
+                            )
+                        )
+                        .toList(),
+                    instance.getResources().count(),
+                    RESULTS_FIELD
+                )
+            );
         }
         return instance;
     }

+ 5 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java

@@ -58,6 +58,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
         var lastAccess = Instant.now();
         var inferenceCount = randomNonNegativeLong();
         Double avgInferenceTime = randomDoubleBetween(0.0, 100.0, true);
+        Double avgInferenceTimeExcludingCacheHit = randomDoubleBetween(0.0, 100.0, true);
         Double avgInferenceTimeLastPeriod = randomDoubleBetween(0.0, 100.0, true);
 
         var noInferenceCallsOnNodeYet = randomBoolean();
@@ -65,12 +66,14 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
             lastAccess = null;
             inferenceCount = 0;
             avgInferenceTime = null;
+            avgInferenceTimeExcludingCacheHit = null;
             avgInferenceTimeLastPeriod = null;
         }
         return AssignmentStats.NodeStats.forStartedState(
             node,
             inferenceCount,
             avgInferenceTime,
+            avgInferenceTimeExcludingCacheHit,
             randomIntBetween(0, 100),
             randomIntBetween(0, 100),
             randomLongBetween(0, 100),
@@ -102,6 +105,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
                     new DiscoveryNode("node_started_1", buildNewFakeTransportAddress(), Version.CURRENT),
                     10L,
                     randomDoubleBetween(0.0, 100.0, true),
+                    randomDoubleBetween(0.0, 100.0, true),
                     randomIntBetween(1, 10),
                     5,
                     4L,
@@ -120,6 +124,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
                     new DiscoveryNode("node_started_2", buildNewFakeTransportAddress(), Version.CURRENT),
                     12L,
                     randomDoubleBetween(0.0, 100.0, true),
+                    randomDoubleBetween(0.0, 100.0, true),
                     randomIntBetween(1, 10),
                     15,
                     3L,

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -309,8 +309,9 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
             nodeStats.add(
                 AssignmentStats.NodeStats.forStartedState(
                     clusterService.localNode(),
-                    presentValue.timingStats().getCount(),
-                    presentValue.timingStats().getAverage(),
+                    presentValue.inferenceCount(),
+                    presentValue.averageInferenceTime(),
+                    presentValue.averageInferenceTimeNoCacheHits(),
                     presentValue.pendingCount(),
                     presentValue.errorCount(),
                     presentValue.cacheHitCount(),

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -103,7 +103,9 @@ public class DeploymentManager {
             var recentStats = stats.recentStats();
             return new ModelStats(
                 processContext.startTime,
-                stats.timingStats(),
+                stats.timingStats().getCount(),
+                stats.timingStats().getAverage(),
+                stats.timingStatsExcludingCacheHits().getAverage(),
                 stats.lastUsed(),
                 processContext.executorService.queueSize() + stats.numberOfPendingResults(),
                 stats.errorCount(),

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

@@ -8,11 +8,12 @@
 package org.elasticsearch.xpack.ml.inference.deployment;
 
 import java.time.Instant;
-import java.util.LongSummaryStatistics;
 
 public record ModelStats(
     Instant startTime,
-    LongSummaryStatistics timingStats,
+    long inferenceCount,
+    Double averageInferenceTime,
+    Double averageInferenceTimeNoCacheHits,
     Instant lastUsed,
     int pendingCount,
     int errorCount,

+ 14 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -35,6 +35,7 @@ public class PyTorchResultProcessor {
 
     public record ResultStats(
         LongSummaryStatistics timingStats,
+        LongSummaryStatistics timingStatsExcludingCacheHits,
         int errorCount,
         long cacheHitCount,
         int numberOfPendingResults,
@@ -51,6 +52,7 @@ public class PyTorchResultProcessor {
     private final Consumer<ThreadSettings> threadSettingsConsumer;
     private volatile boolean isStopping;
     private final LongSummaryStatistics timingStats;
+    private final LongSummaryStatistics timingStatsExcludingCacheHits;
     private int errorCount;
     private long cacheHitCount;
     private long peakThroughput;
@@ -71,6 +73,7 @@ public class PyTorchResultProcessor {
     PyTorchResultProcessor(String deploymentId, Consumer<ThreadSettings> threadSettingsConsumer, LongSupplier currentTimeSupplier) {
         this.deploymentId = Objects.requireNonNull(deploymentId);
         this.timingStats = new LongSummaryStatistics();
+        this.timingStatsExcludingCacheHits = new LongSummaryStatistics();
         this.lastPeriodSummaryStats = new LongSummaryStatistics();
         this.threadSettingsConsumer = Objects.requireNonNull(threadSettingsConsumer);
         this.currentTimeMsSupplier = currentTimeSupplier;
@@ -157,7 +160,7 @@ public class PyTorchResultProcessor {
         }
 
         logger.trace(() -> format("[%s] Parsed inference result with id [%s]", deploymentId, result.requestId()));
-        processResult(inferenceResult, timeMs, Boolean.TRUE.equals(result.isCacheHit()));
+        updateStats(timeMs, Boolean.TRUE.equals(result.isCacheHit()));
         PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
             logger.debug(() -> format("[%s] no pending result for inference [%s]", deploymentId, result.requestId()));
@@ -235,7 +238,8 @@ public class PyTorchResultProcessor {
         }
 
         return new ResultStats(
-            new LongSummaryStatistics(timingStats.getCount(), timingStats.getMin(), timingStats.getMax(), timingStats.getSum()),
+            cloneSummaryStats(timingStats),
+            cloneSummaryStats(timingStatsExcludingCacheHits),
             errorCount,
             cacheHitCount,
             pendingResults.size(),
@@ -245,7 +249,11 @@ public class PyTorchResultProcessor {
         );
     }
 
-    private synchronized void processResult(PyTorchInferenceResult result, long timeMs, boolean isCacheHit) {
+    private LongSummaryStatistics cloneSummaryStats(LongSummaryStatistics stats) {
+        return new LongSummaryStatistics(stats.getCount(), stats.getMin(), stats.getMax(), stats.getSum());
+    }
+
+    private synchronized void updateStats(long timeMs, boolean isCacheHit) {
         timingStats.accept(timeMs);
 
         lastResultTimeMs = currentTimeMsSupplier.getAsLong();
@@ -278,6 +286,9 @@ public class PyTorchResultProcessor {
         if (isCacheHit) {
             cacheHitCount++;
             lastPeriodCacheHitCount++;
+        } else {
+            // don't include cache hits when recording inference time
+            timingStatsExcludingCacheHits.accept(timeMs);
         }
     }
 

+ 2 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -381,6 +381,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         new DiscoveryNode("foo", new TransportAddress(TransportAddress.META_ADDRESS, 2), Version.CURRENT),
                                         5,
                                         42.0,
+                                        42.0,
                                         0,
                                         1,
                                         3L,
@@ -399,6 +400,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         new DiscoveryNode("bar", new TransportAddress(TransportAddress.META_ADDRESS, 3), Version.CURRENT),
                                         4,
                                         50.0,
+                                        50.0,
                                         0,
                                         1,
                                         1L,

+ 21 - 15
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java

@@ -153,8 +153,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         }
     }
 
-    private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, long timeMs, PyTorchInferenceResult result) {
-        return new PyTorchResult(requestId, isCacheHit, timeMs, result, null, null, null);
+    private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, long timeMs) {
+        return new PyTorchResult(requestId, isCacheHit, timeMs, new PyTorchInferenceResult(null), null, null, null);
     }
 
     public void testsStats() {
@@ -168,9 +168,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         processor.registerRequest("b", pendingB);
         processor.registerRequest("c", pendingC);
 
-        var a = wrapInferenceResult("a", false, 1000L, new PyTorchInferenceResult(null));
-        var b = wrapInferenceResult("b", false, 900L, new PyTorchInferenceResult(null));
-        var c = wrapInferenceResult("c", true, 200L, new PyTorchInferenceResult(null));
+        var a = wrapInferenceResult("a", false, 1000L);
+        var b = wrapInferenceResult("b", false, 900L);
+        var c = wrapInferenceResult("c", true, 200L); // cache hit
 
         processor.processInferenceResult(a);
         var stats = processor.getResultStats();
@@ -179,6 +179,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.numberOfPendingResults(), equalTo(2));
         assertThat(stats.timingStats().getCount(), equalTo(1L));
         assertThat(stats.timingStats().getSum(), equalTo(1000L));
+        assertThat(stats.timingStatsExcludingCacheHits().getCount(), equalTo(1L));
+        assertThat(stats.timingStatsExcludingCacheHits().getSum(), equalTo(1000L));
 
         processor.processInferenceResult(b);
         stats = processor.getResultStats();
@@ -187,6 +189,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.numberOfPendingResults(), equalTo(1));
         assertThat(stats.timingStats().getCount(), equalTo(2L));
         assertThat(stats.timingStats().getSum(), equalTo(1900L));
+        assertThat(stats.timingStatsExcludingCacheHits().getCount(), equalTo(2L));
+        assertThat(stats.timingStatsExcludingCacheHits().getSum(), equalTo(1900L));
 
         processor.processInferenceResult(c);
         stats = processor.getResultStats();
@@ -195,6 +199,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.numberOfPendingResults(), equalTo(0));
         assertThat(stats.timingStats().getCount(), equalTo(3L));
         assertThat(stats.timingStats().getSum(), equalTo(2100L));
+        assertThat(stats.timingStatsExcludingCacheHits().getCount(), equalTo(2L));
+        assertThat(stats.timingStatsExcludingCacheHits().getSum(), equalTo(1900L));
     }
 
     public void testsTimeDependentStats() {
@@ -234,9 +240,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier);
 
         // 1st period
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L));
         // first call has no results as is in the same period
         var stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
@@ -250,7 +256,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.peakThroughput(), equalTo(3L));
 
         // 2nd period
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 100L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 100L));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -262,7 +268,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
 
         // 4th period
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 300L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 300L));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -270,8 +276,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9])));
 
         // 7th period
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 410L, new PyTorchInferenceResult(null)));
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 390L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 410L));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 390L));
         stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
         assertThat(stats.recentStats().avgInferenceTime(), nullValue());
@@ -282,9 +288,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12])));
 
         // 8th period
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 510L, new PyTorchInferenceResult(null)));
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 500L, new PyTorchInferenceResult(null)));
-        processor.processInferenceResult(wrapInferenceResult("foo", false, 490L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 510L));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 500L));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 490L));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));