
[ML] Add inference cache hit count to inference node stats (#88807)

The inference node stats for deployed PyTorch inference
models now contain two new fields: `inference_cache_hit_count`
and `inference_cache_hit_count_last_minute`.

These indicate how many inferences on that node were served
from the C++-side response cache that was added in
https://github.com/elastic/ml-cpp/pull/2305. Cache hits
occur when exactly the same inference request is sent to the
same node more than once.

The `average_inference_time_ms` and
`average_inference_time_ms_last_minute` fields now refer to
the time taken to do the cache lookup, plus, if necessary,
the time to do the inference. We would expect average inference
time to be vastly reduced in situations where the cache hit
rate is high.
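To make the relationship between the counters concrete, here is a minimal Java sketch of how a caller of the stats API might derive a per-node cache hit rate from `inference_count` and the new `inference_cache_hit_count`. The field names are the real ones; the map-based parsing and the `cacheHitRate` helper are illustrative assumptions, not part of this change.

import java.util.Map;

public final class CacheHitRateExample {

    // Fraction of inferences on one node that were answered from the native
    // response cache; 0.0 if the node has not performed any inference yet.
    static double cacheHitRate(Map<String, Object> nodeStats) {
        long inferences = ((Number) nodeStats.getOrDefault("inference_count", 0)).longValue();
        long cacheHits = ((Number) nodeStats.getOrDefault("inference_cache_hit_count", 0)).longValue();
        return inferences == 0 ? 0.0 : (double) cacheHits / inferences;
    }

    public static void main(String[] args) {
        // Hypothetical node stats: 5 inference calls, 1 of them a cache hit.
        System.out.println(cacheHitRate(Map.<String, Object>of("inference_count", 5, "inference_cache_hit_count", 1))); // 0.2
    }
}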
David Roberts 3 years ago
parent
commit
15e7b06b79
19 changed files with 223 additions and 71 deletions
  1. docs/changelog/88807.yaml (+5 -0)
  2. docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc (+10 -0)
  3. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java (+52 -8)
  4. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java (+11 -6)
  5. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java (+9 -3)
  6. x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java (+24 -5)
  7. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java (+3 -1)
  8. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java (+5 -2)
  9. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java (+3 -1)
  10. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java (+15 -7)
  11. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java (+8 -10)
  12. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java (+7 -3)
  13. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java (+2 -2)
  14. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java (+5 -5)
  15. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java (+1 -1)
  16. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java (+2 -2)
  17. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java (+14 -14)
  18. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java (+1 -1)
  19. x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml (+46 -0)

+ 5 - 0
docs/changelog/88807.yaml

@@ -0,0 +1,5 @@
+pr: 88807
+summary: Add inference cache hit count to inference node stats
+area: Machine Learning
+type: enhancement
+issues: []

+ 10 - 0
docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc

@@ -149,6 +149,16 @@ in the last minute.
 (integer)
 The number of errors when evaluating the trained model.
 
+`inference_cache_hit_count`:::
+(integer)
+The total number of inference calls made against this node for this
+model that were served from the inference cache.
+
+`inference_cache_hit_count_last_minute`:::
+(integer)
+The number of inference calls made against this node for this model
+in the last minute that were served from the inference cache.
+
 `inference_count`:::
 (integer)
 The total number of inference calls made against this node for this model.

+ 52 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java

@@ -35,6 +35,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         private final Instant lastAccess;
         private final Integer pendingCount;
         private final int errorCount;
+        private final Long cacheHitCount;
         private final int rejectedExecutionCount;
         private final int timeoutCount;
         private final RoutingStateAndReason routingState;
@@ -44,6 +45,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         private final long peakThroughput;
         private final long throughputLastPeriod;
         private final Double avgInferenceTimeLastPeriod;
+        private final Long cacheHitCountLastPeriod;
 
         public static AssignmentStats.NodeStats forStartedState(
             DiscoveryNode node,
@@ -51,6 +53,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             Double avgInferenceTime,
             int pendingCount,
             int errorCount,
+            long cacheHitCount,
             int rejectedExecutionCount,
             int timeoutCount,
             Instant lastAccess,
@@ -59,7 +62,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             Integer numberOfAllocations,
             long peakThroughput,
             long throughputLastPeriod,
-            Double avgInferenceTimeLastPeriod
+            Double avgInferenceTimeLastPeriod,
+            long cacheHitCountLastPeriod
         ) {
             return new AssignmentStats.NodeStats(
                 node,
@@ -68,6 +72,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 lastAccess,
                 pendingCount,
                 errorCount,
+                cacheHitCount,
                 rejectedExecutionCount,
                 timeoutCount,
                 new RoutingStateAndReason(RoutingState.STARTED, null),
@@ -76,7 +81,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 numberOfAllocations,
                 peakThroughput,
                 throughputLastPeriod,
-                avgInferenceTimeLastPeriod
+                avgInferenceTimeLastPeriod,
+                cacheHitCountLastPeriod
             );
         }
 
@@ -88,14 +94,16 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 null,
                 null,
                 0,
+                null,
                 0,
                 0,
                 new RoutingStateAndReason(state, reason),
                 null,
                 null,
                 null,
-                0,
-                0,
+                0L,
+                0L,
+                null,
                 null
             );
         }
@@ -107,6 +115,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             @Nullable Instant lastAccess,
             Integer pendingCount,
             int errorCount,
+            Long cacheHitCount,
             int rejectedExecutionCount,
             int timeoutCount,
             RoutingStateAndReason routingState,
@@ -115,7 +124,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             @Nullable Integer numberOfAllocations,
             long peakThroughput,
             long throughputLastPeriod,
-            Double avgInferenceTimeLastPeriod
+            Double avgInferenceTimeLastPeriod,
+            Long cacheHitCountLastPeriod
         ) {
             this.node = node;
             this.inferenceCount = inferenceCount;
@@ -123,6 +133,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             this.lastAccess = lastAccess;
             this.pendingCount = pendingCount;
             this.errorCount = errorCount;
+            this.cacheHitCount = cacheHitCount;
             this.rejectedExecutionCount = rejectedExecutionCount;
             this.timeoutCount = timeoutCount;
             this.routingState = routingState;
@@ -132,6 +143,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             this.peakThroughput = peakThroughput;
             this.throughputLastPeriod = throughputLastPeriod;
             this.avgInferenceTimeLastPeriod = avgInferenceTimeLastPeriod;
+            this.cacheHitCountLastPeriod = cacheHitCountLastPeriod;
 
             // if lastAccess time is null there have been no inferences
             assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
@@ -167,6 +179,13 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 this.throughputLastPeriod = 0;
                 this.avgInferenceTimeLastPeriod = null;
             }
+            if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
+                this.cacheHitCount = in.readOptionalVLong();
+                this.cacheHitCountLastPeriod = in.readOptionalVLong();
+            } else {
+                this.cacheHitCount = null;
+                this.cacheHitCountLastPeriod = null;
+            }
         }
 
         public DiscoveryNode getNode() {
@@ -197,6 +216,10 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             return errorCount;
         }
 
+        public Optional<Long> getCacheHitCount() {
+            return Optional.ofNullable(cacheHitCount);
+        }
+
         public int getRejectedExecutionCount() {
             return rejectedExecutionCount;
         }
@@ -229,6 +252,10 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             return avgInferenceTimeLastPeriod;
         }
 
+        public Optional<Long> getCacheHitCountLastPeriod() {
+            return Optional.ofNullable(cacheHitCountLastPeriod);
+        }
+
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
@@ -245,6 +272,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             if (avgInferenceTime != null && (inferenceCount != null && inferenceCount > 0)) {
                 builder.field("average_inference_time_ms", avgInferenceTime);
             }
+            if (cacheHitCount != null) {
+                builder.field("inference_cache_hit_count", cacheHitCount);
+            }
             if (lastAccess != null) {
                 builder.timeField("last_access", "last_access_string", lastAccess.toEpochMilli());
             }
@@ -274,6 +304,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             if (avgInferenceTimeLastPeriod != null) {
                 builder.field("average_inference_time_ms_last_minute", avgInferenceTimeLastPeriod);
             }
+            if (cacheHitCountLastPeriod != null) {
+                builder.field("inference_cache_hit_count_last_minute", cacheHitCountLastPeriod);
+            }
 
             builder.endObject();
             return builder;
@@ -300,6 +333,10 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 out.writeVLong(throughputLastPeriod);
                 out.writeOptionalDouble(avgInferenceTimeLastPeriod);
             }
+            if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
+                out.writeOptionalVLong(cacheHitCount);
+                out.writeOptionalVLong(cacheHitCountLastPeriod);
+            }
         }
 
         @Override
@@ -313,6 +350,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 && Objects.equals(lastAccess, that.lastAccess)
                 && Objects.equals(pendingCount, that.pendingCount)
                 && Objects.equals(errorCount, that.errorCount)
+                && Objects.equals(cacheHitCount, that.cacheHitCount)
                 && Objects.equals(rejectedExecutionCount, that.rejectedExecutionCount)
                 && Objects.equals(timeoutCount, that.timeoutCount)
                 && Objects.equals(routingState, that.routingState)
@@ -321,7 +359,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 && Objects.equals(numberOfAllocations, that.numberOfAllocations)
                 && Objects.equals(peakThroughput, that.peakThroughput)
                 && Objects.equals(throughputLastPeriod, that.throughputLastPeriod)
-                && Objects.equals(avgInferenceTimeLastPeriod, that.avgInferenceTimeLastPeriod);
+                && Objects.equals(avgInferenceTimeLastPeriod, that.avgInferenceTimeLastPeriod)
+                && Objects.equals(cacheHitCountLastPeriod, that.cacheHitCountLastPeriod);
         }
 
         @Override
@@ -333,6 +372,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 lastAccess,
                 pendingCount,
                 errorCount,
+                cacheHitCount,
                 rejectedExecutionCount,
                 timeoutCount,
                 routingState,
@@ -341,7 +381,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
                 numberOfAllocations,
                 peakThroughput,
                 throughputLastPeriod,
-                avgInferenceTimeLastPeriod
+                avgInferenceTimeLastPeriod,
+                cacheHitCountLastPeriod
             );
         }
     }
@@ -462,7 +503,10 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             nodeStats.stream().filter(n -> n.getInferenceCount().isPresent()).mapToLong(n -> n.getInferenceCount().get()).sum(),
             // This is for ALL failures, so sum the error counts, timeouts, and rejections
             nodeStats.stream().mapToLong(n -> n.getErrorCount() + n.getTimeoutCount() + n.getRejectedExecutionCount()).sum(),
-            // TODO Update when we actually have cache miss/hit values
+            // The number below is a cache miss count for the JVM model cache. We know the cache hit count for
+            // the inference cache in the native process, but that's completely different, so it doesn't make
+            // sense to reuse the same field here.
+            // TODO: consider adding another field here for inference cache hits, but be mindful of the naming collision
             0L,
             modelId,
             null,
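
The wire-format handling above follows the standard backwards-compatibility pattern: the new counters are only written when the remote node is on 8.4.0 or later, and they travel as optional values so a missing counter round-trips as null. Below is a self-contained sketch of that idea using plain java.io streams as stand-ins for Elasticsearch's StreamInput/StreamOutput; the version constant and the boolean presence flag are illustrative assumptions, not the actual transport encoding.

import java.io.*;

public final class OptionalCounterWire {

    static final int V_8_4_0 = 8_04_00; // stand-in for the real transport version check

    // Write the counter only for peers that understand it; a boolean presence
    // flag lets null ("not tracked") survive the round trip.
    static void writeOptionalCacheHitCount(DataOutputStream out, int peerVersion, Long cacheHitCount) throws IOException {
        if (peerVersion >= V_8_4_0) {
            out.writeBoolean(cacheHitCount != null);
            if (cacheHitCount != null) {
                out.writeLong(cacheHitCount);
            }
        }
    }

    // Older peers never sent the field, so the value stays null for them.
    static Long readOptionalCacheHitCount(DataInputStream in, int peerVersion) throws IOException {
        if (peerVersion < V_8_4_0) {
            return null;
        }
        return in.readBoolean() ? in.readLong() : null;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        writeOptionalCacheHitCount(new DataOutputStream(bytes), V_8_4_0, 42L);
        Long roundTripped = readOptionalCacheHitCount(
            new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())), V_8_4_0);
        System.out.println(roundTripped); // 42
    }
}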

+ 11 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -88,7 +88,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                 null
                             )
                         )
-                        .collect(Collectors.toList()),
+                        .toList(),
                     instance.getResources().count(),
                     RESULTS_FIELD
                 )
@@ -126,6 +126,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getLastAccess(),
                                                     nodeStats.getPendingCount(),
                                                     0,
+                                                    null,
                                                     0,
                                                     0,
                                                     nodeStats.getRoutingState(),
@@ -134,6 +135,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     null,
                                                     0L,
                                                     0L,
+                                                    null,
                                                     null
                                                 )
                                             )
@@ -141,7 +143,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                     )
                             )
                         )
-                        .collect(Collectors.toList()),
+                        .toList(),
                     instance.getResources().count(),
                     RESULTS_FIELD
                 )
@@ -179,6 +181,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getLastAccess(),
                                                     nodeStats.getPendingCount(),
                                                     nodeStats.getErrorCount(),
+                                                    null,
                                                     nodeStats.getRejectedExecutionCount(),
                                                     nodeStats.getTimeoutCount(),
                                                     nodeStats.getRoutingState(),
@@ -187,6 +190,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getNumberOfAllocations(),
                                                     0L,
                                                     0L,
+                                                    null,
                                                     null
                                                 )
                                             )
@@ -194,7 +198,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                     )
                             )
                         )
-                        .collect(Collectors.toList()),
+                        .toList(),
                     instance.getResources().count(),
                     RESULTS_FIELD
                 )
@@ -232,6 +236,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getLastAccess(),
                                                     nodeStats.getPendingCount(),
                                                     nodeStats.getErrorCount(),
+                                                    null,
                                                     nodeStats.getRejectedExecutionCount(),
                                                     nodeStats.getTimeoutCount(),
                                                     nodeStats.getRoutingState(),
@@ -240,14 +245,15 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getNumberOfAllocations(),
                                                     nodeStats.getPeakThroughput(),
                                                     nodeStats.getThroughputLastPeriod(),
-                                                    nodeStats.getAvgInferenceTimeLastPeriod()
+                                                    nodeStats.getAvgInferenceTimeLastPeriod(),
+                                                    null
                                                 )
                                             )
                                             .toList()
                                     )
                             )
                         )
-                        .collect(Collectors.toList()),
+                        .toList(),
                     instance.getResources().count(),
                     RESULTS_FIELD
                 )
@@ -255,5 +261,4 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
         }
         return instance;
     }
-
 }

+ 9 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java

@@ -73,6 +73,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
             avgInferenceTime,
             randomIntBetween(0, 100),
             randomIntBetween(0, 100),
+            randomLongBetween(0, 100),
             randomIntBetween(0, 100),
             randomIntBetween(0, 100),
             lastAccess,
@@ -81,7 +82,8 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
             randomIntBetween(1, 16),
             randomIntBetween(0, 100),
             randomIntBetween(0, 100),
-            avgInferenceTimeLastPeriod
+            avgInferenceTimeLastPeriod,
+            randomLongBetween(0, 100)
         );
     }
 
@@ -102,6 +104,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
                     randomDoubleBetween(0.0, 100.0, true),
                     randomIntBetween(1, 10),
                     5,
+                    4L,
                     12,
                     3,
                     Instant.now(),
@@ -110,7 +113,8 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
                     randomIntBetween(1, 2),
                     randomNonNegativeLong(),
                     randomNonNegativeLong(),
-                    null
+                    null,
+                    1L
                 ),
                 AssignmentStats.NodeStats.forStartedState(
                     new DiscoveryNode("node_started_2", buildNewFakeTransportAddress(), Version.CURRENT),
@@ -118,6 +122,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
                     randomDoubleBetween(0.0, 100.0, true),
                     randomIntBetween(1, 10),
                     15,
+                    3L,
                     4,
                     2,
                     Instant.now(),
@@ -126,7 +131,8 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
                     randomIntBetween(1, 2),
                     randomNonNegativeLong(),
                     randomNonNegativeLong(),
-                    null
+                    null,
+                    1L
                 ),
                 AssignmentStats.NodeStats.forNotStartedState(
                     new DiscoveryNode("node_not_started_3", buildNewFakeTransportAddress(), Version.CURRENT),

+ 24 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -268,7 +268,7 @@ public class PyTorchModelIT extends ESRestTestCase {
             assertThat(stats, hasSize(1));
             String stringModelSizeBytes = (String) XContentMapValues.extractValue("model_size_stats.model_size", stats.get(0));
             assertThat(
-                "stats response: " + responseMap + " human stats response" + humanResponseMap,
+                "stats response: " + responseMap + " human stats response " + humanResponseMap,
                 stringModelSizeBytes,
                 is(not(nullValue()))
             );
@@ -323,6 +323,10 @@ public class PyTorchModelIT extends ESRestTestCase {
 
         infer("once", modelA);
         infer("twice", modelA);
+        // By making this request 3 times at least one of the responses must come from the cache because the cluster has 2 ML nodes
+        infer("three times", modelA);
+        infer("three times", modelA);
+        infer("three times", modelA);
         {
             Response postInferStatsResponse = getTrainedModelStats(modelA);
             List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(postInferStatsResponse).get("trained_model_stats");
@@ -338,13 +342,17 @@ public class PyTorchModelIT extends ESRestTestCase {
             for (var node : nodes) {
                 assertThat(node.get("number_of_pending_requests"), notNullValue());
             }
-            // last_access and average_inference_time_ms may be null if inference wasn't performed on this node
+            assertAtLeastOneOfTheseIsNonZero("inference_cache_hit_count", nodes);
+            // last_access, average_inference_time_ms and inference_cache_hit_count_last_minute
+            // may be null if inference wasn't performed on a node. Also, in this test they'll
+            // be zero even when they are present because we don't have a full minute of history.
             assertAtLeastOneOfTheseIsNotNull("last_access", nodes);
             assertAtLeastOneOfTheseIsNotNull("average_inference_time_ms", nodes);
+            assertAtLeastOneOfTheseIsNotNull("inference_cache_hit_count_last_minute", nodes);
 
-            assertThat((Integer) XContentMapValues.extractValue("inference_stats.inference_count", stats.get(0)), equalTo(2));
+            assertThat((Integer) XContentMapValues.extractValue("inference_stats.inference_count", stats.get(0)), equalTo(5));
             int inferenceCount = sumInferenceCountOnNodes(nodes);
-            assertThat(inferenceCount, equalTo(2));
+            assertThat(inferenceCount, equalTo(5));
         }
     }
 
@@ -373,7 +381,18 @@ public class PyTorchModelIT extends ESRestTestCase {
     }
 
     private void assertAtLeastOneOfTheseIsNotNull(String name, List<Map<String, Object>> nodes) {
-        assertTrue("all nodes have null value for [" + name + "]", nodes.stream().anyMatch(n -> n.get(name) != null));
+        assertTrue("all nodes have null value for [" + name + "] in " + nodes, nodes.stream().anyMatch(n -> n.get(name) != null));
+    }
+
+    private void assertAtLeastOneOfTheseIsNonZero(String name, List<Map<String, Object>> nodes) {
+        assertTrue("all nodes have null or zero value for [" + name + "] in " + nodes, nodes.stream().anyMatch(n -> {
+            Object o = n.get(name);
+            if (o instanceof Number) {
+                return ((Number) o).longValue() != 0;
+            } else {
+                return false;
+            }
+        }));
     }
 
     @SuppressWarnings("unchecked")

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -303,6 +303,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                     presentValue.timingStats().getAverage(),
                     presentValue.pendingCount(),
                     presentValue.errorCount(),
+                    presentValue.cacheHitCount(),
                     presentValue.rejectedExecutionCount(),
                     presentValue.timeoutCount(),
                     presentValue.lastUsed(),
@@ -311,7 +312,8 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                     presentValue.numberOfAllocations(),
                     presentValue.peakThroughput(),
                     presentValue.throughputLastPeriod(),
-                    presentValue.avgInferenceTimeLastPeriod()
+                    presentValue.avgInferenceTimeLastPeriod(),
+                    presentValue.cacheHitCountLastPeriod()
                 )
             );
         } else {

+ 5 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -98,19 +98,22 @@ public class DeploymentManager {
     public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
         return Optional.ofNullable(processContextByAllocation.get(task.getId())).map(processContext -> {
             var stats = processContext.getResultProcessor().getResultStats();
+            var recentStats = stats.recentStats();
             return new ModelStats(
                 processContext.startTime,
                 stats.timingStats(),
                 stats.lastUsed(),
                 processContext.executorService.queueSize() + stats.numberOfPendingResults(),
                 stats.errorCount(),
+                stats.cacheHitCount(),
                 processContext.rejectedExecutionCount.intValue(),
                 processContext.timeoutCount.intValue(),
                 processContext.numThreadsPerAllocation,
                 processContext.numAllocations,
                 stats.peakThroughput(),
-                stats.recentStats().requestsProcessed(),
-                stats.recentStats().avgInferenceTime()
+                recentStats.requestsProcessed(),
+                recentStats.avgInferenceTime(),
+                recentStats.cacheHitCount()
             );
         });
     }

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

@@ -16,11 +16,13 @@ public record ModelStats(
     Instant lastUsed,
     int pendingCount,
     int errorCount,
+    long cacheHitCount,
     int rejectedExecutionCount,
     int timeoutCount,
     Integer threadsPerAllocation,
     Integer numberOfAllocations,
     long peakThroughput,
     long throughputLastPeriod,
-    Double avgInferenceTimeLastPeriod
+    Double avgInferenceTimeLastPeriod,
+    long cacheHitCountLastPeriod
 ) {}

+ 15 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -30,7 +30,7 @@ import static org.elasticsearch.core.Strings.format;
 
 public class PyTorchResultProcessor {
 
-    public record RecentStats(long requestsProcessed, Double avgInferenceTime) {}
+    public record RecentStats(long requestsProcessed, Double avgInferenceTime, long cacheHitCount) {}
 
     public record ResultStats(
         LongSummaryStatistics timingStats,
@@ -55,6 +55,7 @@ public class PyTorchResultProcessor {
     private long peakThroughput;
 
     private LongSummaryStatistics lastPeriodSummaryStats;
+    private long lastPeriodCacheHitCount;
     private RecentStats lastPeriodStats;
     private long currentPeriodEndTimeMs;
     private long lastResultTimeMs;
@@ -197,13 +198,13 @@ public class PyTorchResultProcessor {
             // there was a result in the last period but not one
             // in this period to close off the last period stats.
             // The stats are valid return them here
-            rs = new RecentStats(lastPeriodSummaryStats.getCount(), lastPeriodSummaryStats.getAverage());
+            rs = new RecentStats(lastPeriodSummaryStats.getCount(), lastPeriodSummaryStats.getAverage(), lastPeriodCacheHitCount);
             peakThroughput = Math.max(peakThroughput, lastPeriodSummaryStats.getCount());
         }
 
         if (rs == null) {
             // no results processed in the previous period
-            rs = new RecentStats(0L, null);
+            rs = new RecentStats(0L, null, 0L);
         }
 
         return new ResultStats(
@@ -219,9 +220,6 @@ public class PyTorchResultProcessor {
 
     private synchronized void processResult(PyTorchInferenceResult result) {
         timingStats.accept(result.getTimeMs());
-        if (result.isCacheHit()) {
-            cacheHitCount++;
-        }
 
         lastResultTimeMs = currentTimeMsSupplier.getAsLong();
         if (lastResultTimeMs > currentPeriodEndTimeMs) {
@@ -233,9 +231,14 @@ public class PyTorchResultProcessor {
                 // there is no data for the last period
                 lastPeriodStats = null;
             } else {
-                lastPeriodStats = new RecentStats(lastPeriodSummaryStats.getCount(), lastPeriodSummaryStats.getAverage());
+                lastPeriodStats = new RecentStats(
+                    lastPeriodSummaryStats.getCount(),
+                    lastPeriodSummaryStats.getAverage(),
+                    lastPeriodCacheHitCount
+                );
             }
 
+            lastPeriodCacheHitCount = 0;
             lastPeriodSummaryStats = new LongSummaryStatistics();
             lastPeriodSummaryStats.accept(result.getTimeMs());
 
@@ -244,6 +247,11 @@ public class PyTorchResultProcessor {
         } else {
             lastPeriodSummaryStats.accept(result.getTimeMs());
         }
+
+        if (result.isCacheHit()) {
+            cacheHitCount++;
+            lastPeriodCacheHitCount++;
+        }
     }
 
     public void stop() {
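
The processor now tracks cache hits at two granularities: a lifetime cacheHitCount and a per-period lastPeriodCacheHitCount that is frozen into the RecentStats snapshot and reset whenever a result arrives in a new period. The sketch below is a stripped-down, self-contained illustration of that rollover; the class and field names are simplified stand-ins, and it ignores the real processor's handling of gaps longer than one period (the `lastPeriodStats = null` branch above).

import java.util.function.LongSupplier;

final class RollingCacheHitCounter {

    private static final long PERIOD_MS = 60_000; // the "last minute" bucket

    private final LongSupplier clock;
    private long periodEndMs;
    private long totalHits;          // lifetime cache hits
    private long currentPeriodHits;  // hits in the bucket still being filled
    private long lastPeriodHits;     // hits reported as "last minute"

    RollingCacheHitCounter(LongSupplier clock) {
        this.clock = clock;
        this.periodEndMs = clock.getAsLong() + PERIOD_MS;
    }

    synchronized void onResult(boolean cacheHit) {
        long now = clock.getAsLong();
        if (now > periodEndMs) {
            // Close the previous bucket: it becomes the reported "last minute" value.
            lastPeriodHits = currentPeriodHits;
            currentPeriodHits = 0;
            periodEndMs = now + PERIOD_MS;
        }
        if (cacheHit) {
            totalHits++;
            currentPeriodHits++;
        }
    }

    synchronized long totalHits() {
        return totalHits;
    }

    synchronized long lastPeriodHits() {
        return lastPeriodHits;
    }

    public static void main(String[] args) {
        long[] now = { 0 };
        RollingCacheHitCounter counter = new RollingCacheHitCounter(() -> now[0]);
        counter.onResult(true);   // hit in the first minute
        now[0] = 61_000;          // move past the period boundary
        counter.onResult(false);  // rolls the bucket over
        System.out.println(counter.totalHits() + " total, " + counter.lastPeriodHits() + " last minute"); // 1 total, 1 last minute
    }
}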

+ 8 - 10
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java

@@ -44,8 +44,8 @@ public class PyTorchInferenceResult implements ToXContentObject {
             INFERENCE,
             ObjectParser.ValueType.VALUE_ARRAY
         );
-        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), TIME_MS);
-        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CACHE_HIT);
+        PARSER.declareLong(ConstructingObjectParser.constructorArg(), TIME_MS);
+        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), CACHE_HIT);
     }
 
     public static PyTorchInferenceResult fromXContent(XContentParser parser) throws IOException {
@@ -54,14 +54,14 @@ public class PyTorchInferenceResult implements ToXContentObject {
 
     private final String requestId;
     private final double[][][] inference;
-    private final Long timeMs;
+    private final long timeMs;
     private final boolean cacheHit;
 
-    public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, @Nullable Long timeMs, @Nullable Boolean cacheHit) {
+    public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, long timeMs, boolean cacheHit) {
         this.requestId = Objects.requireNonNull(requestId);
         this.inference = inference;
         this.timeMs = timeMs;
-        this.cacheHit = cacheHit != null && cacheHit;
+        this.cacheHit = cacheHit;
     }
 
     public String getRequestId() {
@@ -72,7 +72,7 @@ public class PyTorchInferenceResult implements ToXContentObject {
         return inference;
     }
 
-    public Long getTimeMs() {
+    public long getTimeMs() {
         return timeMs;
     }
 
@@ -95,9 +95,7 @@ public class PyTorchInferenceResult implements ToXContentObject {
             }
             builder.endArray();
         }
-        if (timeMs != null) {
-            builder.field(TIME_MS.getPreferredName(), timeMs);
-        }
+        builder.field(TIME_MS.getPreferredName(), timeMs);
         builder.field(CACHE_HIT.getPreferredName(), cacheHit);
         builder.endObject();
         return builder;
@@ -116,7 +114,7 @@ public class PyTorchInferenceResult implements ToXContentObject {
         PyTorchInferenceResult that = (PyTorchInferenceResult) other;
         return Objects.equals(requestId, that.requestId)
             && Arrays.deepEquals(inference, that.inference)
-            && Objects.equals(timeMs, that.timeMs)
+            && timeMs == that.timeMs
             && cacheHit == that.cacheHit;
     }
 }

+ 7 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -182,7 +182,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             enabled = randomBoolean();
             settings.put("xpack.ml.enabled", enabled);
         }
-        boolean expected = enabled || useDefault;
+        boolean expected = enabled;
         MachineLearningInfoTransportAction featureSet = new MachineLearningInfoTransportAction(
             mock(TransportService.class),
             mock(ActionFilters.class),
@@ -383,6 +383,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         42.0,
                                         0,
                                         1,
+                                        3L,
                                         2,
                                         3,
                                         Instant.now(),
@@ -391,7 +392,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         randomIntBetween(1, 16),
                                         1L,
                                         2L,
-                                        33.0
+                                        33.0,
+                                        1L
                                     ),
                                     AssignmentStats.NodeStats.forStartedState(
                                         new DiscoveryNode("bar", new TransportAddress(TransportAddress.META_ADDRESS, 3), Version.CURRENT),
@@ -399,6 +401,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         50.0,
                                         0,
                                         1,
+                                        1L,
                                         2,
                                         3,
                                         Instant.now(),
@@ -407,7 +410,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         randomIntBetween(1, 16),
                                         2L,
                                         4L,
-                                        34.0
+                                        34.0,
+                                        1L
                                     )
                                 )
                             ).setState(AssignmentState.STARTED).setAllocationStatus(new AllocationStatus(2, 2))

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -66,7 +66,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         String resultsField = randomAlphaOfLength(10);
         FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult(
             tokenization,
-            new PyTorchInferenceResult("1", scores, 0L, null),
+            new PyTorchInferenceResult("1", scores, 0L, false),
             tokenizer,
             4,
             resultsField
@@ -93,7 +93,7 @@ public class FillMaskProcessorTests extends ESTestCase {
             0
         );
 
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, false);
         expectThrows(
             ElasticsearchStatusException.class,
             () -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10))

+ 5 - 5
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -72,7 +72,7 @@ public class NerProcessorTests extends ESTestCase {
 
         var e = expectThrows(
             ElasticsearchStatusException.class,
-            () -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null))
+            () -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, false))
         );
         assertThat(e, instanceOf(ElasticsearchStatusException.class));
     }
@@ -113,7 +113,7 @@ public class NerProcessorTests extends ESTestCase {
                     { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
                     { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
                 } };
-            NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
+            NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
 
             assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
             assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -141,7 +141,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
                 { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
 
         assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -178,7 +178,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
                 { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
 
         assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -211,7 +211,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 5 }, // in
                 { 6, 0, 0, 0, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
 
         assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java

@@ -87,7 +87,7 @@ public class QuestionAnsweringProcessorTests extends ESTestCase {
         assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7));
         double[][][] scores = { { START_TOKEN_SCORES }, { END_TOKEN_SCORES } };
         NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, null);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false);
         QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
             tokenizationResult,
             pyTorchResult

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -32,7 +32,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
 
     public void testInvalidResult() {
         {
-            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, null);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, false);
             var e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
@@ -41,7 +41,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
             assertThat(e.getMessage(), containsString("Text classification result has no data"));
         }
         {
-            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, false);
             var e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))

+ 14 - 14
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java

@@ -46,7 +46,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
     }
 
     public void testResultsProcessing() {
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, null);
+        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
         var threadSettings = new ThreadSettings(1, 1, "b");
         var errorResult = new ErrorResult("c", "a bad thing has happened");
 
@@ -86,7 +86,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         );
         processor.registerRequest("b", calledOnShutdown);
 
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, null);
+        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
 
         processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
         assertSame(inferenceResult, resultHolder.get());
@@ -100,7 +100,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
 
         processor.ignoreResponseWithoutNotifying("a");
 
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, null);
+        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
         processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
     }
 
@@ -161,7 +161,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         processor.registerRequest("b", pendingB);
         processor.registerRequest("c", pendingC);
 
-        var a = wrapInferenceResult(new PyTorchInferenceResult("a", null, 1000L, null));
+        var a = wrapInferenceResult(new PyTorchInferenceResult("a", null, 1000L, false));
         var b = wrapInferenceResult(new PyTorchInferenceResult("b", null, 900L, false));
         var c = wrapInferenceResult(new PyTorchInferenceResult("c", null, 200L, true));
 
@@ -227,9 +227,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier);
 
         // 1st period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, null)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, null)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, null)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
         // first call has no results as is in the same period
         var stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
@@ -243,7 +243,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.peakThroughput(), equalTo(3L));
 
         // 2nd period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 100L, null)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 100L, false)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -255,7 +255,7 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
 
         // 4th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 300L, null)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 300L, false)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -263,8 +263,8 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9])));
 
         // 7th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 410L, null)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 390L, null)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 410L, false)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 390L, false)));
         stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
         assertThat(stats.recentStats().avgInferenceTime(), nullValue());
@@ -275,9 +275,9 @@ public class PyTorchResultProcessorTests extends ESTestCase {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12])));
 
         // 8th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 510L, null)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 500L, null)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 490L, null)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 510L, false)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 500L, false)));
+        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 490L, false)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java

@@ -43,6 +43,6 @@ public class PyTorchInferenceResultTests extends AbstractXContentTestCase<PyTorc
                 }
             }
         }
-        return new PyTorchInferenceResult(id, arr, randomLong(), randomBoolean() ? null : randomBoolean());
+        return new PyTorchInferenceResult(id, arr, randomLong(), randomBoolean());
     }
 }

+ 46 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

@@ -75,6 +75,9 @@ setup:
   - match: { stopped: true }
 ---
 "Test start and stop deployment with cache":
+  - skip:
+      features: allowed_warnings
+
   - do:
       ml.start_trained_model_deployment:
         model_id: test_model
@@ -84,6 +87,49 @@ setup:
   - match: {assignment.task_parameters.model_id: test_model}
   - match: {assignment.task_parameters.cache_size: 10kb}
 
+  - do:
+      allowed_warnings:
+        - '[POST /_ml/trained_models/{model_id}/deployment/_infer] is deprecated! Use [POST /_ml/trained_models/{model_id}/_infer] instead.'
+      ml.infer_trained_model:
+        model_id: "test_model"
+        body: >
+          {
+            "docs": [
+              { "input": "words" }
+            ]
+          }
+
+  - do:
+      allowed_warnings:
+        - '[POST /_ml/trained_models/{model_id}/deployment/_infer] is deprecated! Use [POST /_ml/trained_models/{model_id}/_infer] instead.'
+      ml.infer_trained_model:
+        model_id: "test_model"
+        body: >
+          {
+            "docs": [
+              { "input": "are" }
+            ]
+          }
+
+  - do:
+      allowed_warnings:
+        - '[POST /_ml/trained_models/{model_id}/deployment/_infer] is deprecated! Use [POST /_ml/trained_models/{model_id}/_infer] instead.'
+      ml.infer_trained_model:
+        model_id: "test_model"
+        body: >
+          {
+            "docs": [
+              { "input": "words" }
+            ]
+          }
+
+  - do:
+      ml.get_trained_models_stats:
+        model_id: "test_model"
+  - match: { count: 1 }
+  - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_count: 3 }
+  - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_cache_hit_count: 1 }
+
   - do:
       ml.stop_trained_model_deployment:
         model_id: test_model