Browse Source

[ML] adding new model types and deployments to xpack usage (#80282)

This adds new model types + deployment information to xpack/usage under ml.inference

closes: #80200
Benjamin Trent 3 years ago
parent
commit
4557d5f797

+ 19 - 3
docs/reference/rest-api/usage.asciidoc

@@ -155,12 +155,10 @@ GET /_xpack/usage
       },
       "trained_models" : {
         "_all" : {
-          "count" : 0
+          "count": 1
         },
         "count": {
           "total": 1,
-          "classification": 0,
-          "regression": 0,
           "prepackaged": 1,
           "other": 0
         },
@@ -176,6 +174,24 @@ GET /_xpack/usage
           "avg": 0.0,
           "total": 0.0
         }
+      },
+      "deployments": {
+        "count": 0,
+        "inference_counts": {
+          "total": 0.0,
+          "min": 0.0,
+          "avg": 0.0,
+          "max": 0.0
+        },
+        "model_sizes_bytes": {
+          "total": 0.0,
+          "min": 0.0,
+          "avg": 0.0,
+          "max": 0.0
+        },
+        "time_ms": {
+          "avg": 0.0
+        }
       }
     },
     "node_count" : 1

+ 9 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDeploymentStatsAction.java

@@ -34,6 +34,7 @@ import java.time.Instant;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
 
 public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsAction.Response> {
 
@@ -187,6 +188,14 @@ public class GetDeploymentStatsAction extends ActionType<GetDeploymentStatsActio
                     return routingState;
                 }
 
+                public Optional<Long> getInferenceCount() {
+                    return Optional.ofNullable(inferenceCount);
+                }
+
+                public Optional<Double> getAvgInferenceTime() {
+                    return Optional.ofNullable(avgInferenceTime);
+                }
+
                 @Override
                 public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
                     builder.startObject();

+ 69 - 29
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
 import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.client.OriginSettingClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
@@ -37,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
@@ -44,9 +46,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
@@ -65,6 +65,8 @@ import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+
 public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransportAction {
 
     private final Client client;
@@ -92,7 +94,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
             actionFilters,
             indexNameExpressionResolver
         );
-        this.client = client;
+        this.client = new OriginSettingClient(client, ML_ORIGIN);
         this.licenseState = licenseState;
         this.jobManagerHolder = jobManagerHolder;
         this.enabled = XPackSettings.MACHINE_LEARNING_ENABLED.get(environment.settings());
@@ -125,19 +127,32 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
         Map<String, Object> inferenceUsage = new LinkedHashMap<>();
         int nodeCount = mlNodeCount(state);
 
-        // Step 6. extract trained model config count and then return results
+        // Step 7. extract deployment stats and then return results
+        ActionListener<GetDeploymentStatsAction.Response> trainedModelDeploymentsListener = ActionListener.wrap(response -> {
+            addDeploymentStats(response, inferenceUsage);
+            listener.onResponse(
+                new XPackUsageFeatureResponse(
+                    new MachineLearningFeatureSetUsage(
+                        MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
+                        enabled,
+                        jobsUsage,
+                        datafeedsUsage,
+                        analyticsUsage,
+                        inferenceUsage,
+                        nodeCount
+                    )
+                )
+            );
+        }, listener::onFailure);
+
+        // Step 6. extract trained model config count and gather deployment stats then return results
         ActionListener<GetTrainedModelsAction.Response> trainedModelsListener = ActionListener.wrap(response -> {
             addTrainedModelStats(response, inferenceUsage);
-            MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(
-                MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
-                enabled,
-                jobsUsage,
-                datafeedsUsage,
-                analyticsUsage,
-                inferenceUsage,
-                nodeCount
+            client.execute(
+                GetDeploymentStatsAction.INSTANCE,
+                new GetDeploymentStatsAction.Request("_all"),
+                trainedModelDeploymentsListener
             );
-            listener.onResponse(new XPackUsageFeatureResponse(usage));
         }, listener::onFailure);
 
         // Step 5. Extract usage from ingest statistics and gather trained model config count
@@ -181,13 +196,14 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
 
         // Step 1. Extract usage from jobs stats and then request stats for all datafeeds
         GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(Metadata.ALL);
-        ActionListener<GetJobsStatsAction.Response> jobStatsListener = ActionListener.wrap(response -> {
-            jobManagerHolder.getJobManager().expandJobs(Metadata.ALL, true, ActionListener.wrap(jobs -> {
+        ActionListener<GetJobsStatsAction.Response> jobStatsListener = ActionListener.wrap(
+            response -> jobManagerHolder.getJobManager().expandJobs(Metadata.ALL, true, ActionListener.wrap(jobs -> {
                 addJobsUsage(response, jobs.results(), jobsUsage);
                 GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request(Metadata.ALL);
                 client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest, datafeedStatsListener);
-            }, listener::onFailure));
-        }, listener::onFailure);
+            }, listener::onFailure)),
+            listener::onFailure
+        );
 
         // Step 0. Kick off the chain of callbacks by requesting jobs stats
         client.execute(GetJobsStatsAction.INSTANCE, jobStatsRequest, jobStatsListener);
@@ -229,7 +245,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
             jobCountByState.computeIfAbsent(jobState, js -> Counter.newCounter()).addAndGet(1);
             detectorStatsByState.computeIfAbsent(jobState, js -> new StatsAccumulator()).add(detectorsCount);
             modelSizeStatsByState.computeIfAbsent(jobState, js -> new StatsAccumulator()).add(modelSize);
-            forecastStatsByState.merge(jobState, jobStats.getForecastStats(), (f1, f2) -> f1.merge(f2));
+            forecastStatsByState.merge(jobState, jobStats.getForecastStats(), ForecastStats::merge);
             createdByByState.computeIfAbsent(jobState, js -> new HashMap<>())
                 .compute(jobCreatedBy(job), (k, v) -> (v == null) ? 1L : (v + 1));
         }
@@ -346,9 +362,37 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
     }
 
     private static void updateStats(Map<String, Long> statsMap, Long value) {
-        statsMap.compute("sum", (k, v) -> v + value);
-        statsMap.compute("min", (k, v) -> Math.min(v, value));
-        statsMap.compute("max", (k, v) -> Math.max(v, value));
+        statsMap.computeIfPresent("sum", (k, v) -> v + value);
+        statsMap.computeIfPresent("min", (k, v) -> Math.min(v, value));
+        statsMap.computeIfPresent("max", (k, v) -> Math.max(v, value));
+    }
+
+    private void addDeploymentStats(GetDeploymentStatsAction.Response response, Map<String, Object> inferenceUsage) {
+        StatsAccumulator modelSizes = new StatsAccumulator();
+        double avgTimeSum = 0.0;
+        StatsAccumulator nodeDistribution = new StatsAccumulator();
+        for (var stats : response.getStats().results()) {
+            modelSizes.add(stats.getModelSize().getBytes());
+            for (var nodeStats : stats.getNodeStats()) {
+                long nodeInferenceCount = nodeStats.getInferenceCount().orElse(0L);
+                avgTimeSum += nodeStats.getAvgInferenceTime().orElse(0.0) * nodeInferenceCount;
+                nodeDistribution.add(nodeInferenceCount);
+            }
+        }
+
+        inferenceUsage.put(
+            "deployments",
+            Map.of(
+                "count",
+                response.getStats().count(),
+                "time_ms",
+                Map.of(StatsAccumulator.Fields.AVG, nodeDistribution.getTotal() == 0.0 ? 0.0 : avgTimeSum / nodeDistribution.getTotal()),
+                "model_sizes_bytes",
+                modelSizes.asMap(),
+                "inference_counts",
+                nodeDistribution.asMap()
+            )
+        );
     }
 
     private void addTrainedModelStats(GetTrainedModelsAction.Response response, Map<String, Object> inferenceUsage) {
@@ -359,8 +403,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
         StatsAccumulator estimatedOperations = new StatsAccumulator();
         StatsAccumulator estimatedMemoryUsageBytes = new StatsAccumulator();
         int createdByAnalyticsCount = 0;
-        int regressionCount = 0;
-        int classificationCount = 0;
+        Map<String, Counter> inferenceConfigCounts = new LinkedHashMap<>();
         int prepackagedCount = 0;
         for (TrainedModelConfig trainedModelConfig : trainedModelConfigs) {
             if (trainedModelConfig.getTags().contains("prepackaged")) {
@@ -368,10 +411,8 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
                 continue;
             }
             InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig();
-            if (inferenceConfig instanceof RegressionConfig) {
-                regressionCount++;
-            } else if (inferenceConfig instanceof ClassificationConfig) {
-                classificationCount++;
+            if (inferenceConfig != null) {
+                inferenceConfigCounts.computeIfAbsent(inferenceConfig.getName(), s -> Counter.newCounter()).addAndGet(1);
             }
             if (trainedModelConfig.getMetadata() != null && trainedModelConfig.getMetadata().containsKey("analytics_config")) {
                 createdByAnalyticsCount++;
@@ -382,8 +423,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
 
         Map<String, Object> counts = new HashMap<>();
         counts.put("total", trainedModelConfigs.size());
-        counts.put("classification", classificationCount);
-        counts.put("regression", regressionCount);
+        inferenceConfigCounts.forEach((configName, count) -> counts.put(configName, count.get()));
         counts.put("prepackaged", prepackagedCount);
         counts.put("other", trainedModelConfigs.size() - createdByAnalyticsCount - prepackagedCount);
 

+ 89 - 17
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -25,12 +25,15 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.ingest.IngestStats;
 import org.elasticsearch.license.MockLicenseState;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -44,6 +47,7 @@ import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
@@ -54,7 +58,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -68,6 +75,7 @@ import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.job.JobManager;
 import org.elasticsearch.xpack.ml.job.JobManagerHolder;
+import org.junit.After;
 import org.junit.Before;
 
 import java.time.Instant;
@@ -82,6 +90,7 @@ import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
@@ -104,12 +113,14 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
 
     @Before
     public void init() {
+        ThreadPool threadpool = new TestThreadPool("test");
         commonSettings = Settings.builder()
             .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toAbsolutePath())
             .put(MachineLearningField.AUTODETECT_PROCESS.getKey(), false)
             .build();
         clusterService = mock(ClusterService.class);
         client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadpool);
         jobManager = mock(JobManager.class);
         jobManagerHolder = new JobManagerHolder(jobManager);
         licenseState = mock(MockLicenseState.class);
@@ -120,6 +131,12 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
         givenDataFrameAnalytics(Collections.emptyList(), Collections.emptyList());
         givenProcessorStats(Collections.emptyList());
         givenTrainedModels(Collections.emptyList());
+        givenDeploymentStats(new GetDeploymentStatsAction.Response(List.of(), List.of(), List.of(), 0L));
+    }
+
+    @After
+    public void close() {
+        client.threadPool().shutdown();
     }
 
     private MachineLearningUsageTransportAction newUsageAction(Settings settings) {
@@ -295,15 +312,18 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             .setEstimatedHeapMemory(100)
             .setEstimatedOperations(200)
             .setMetadata(Collections.singletonMap("analytics_config", "anything"))
+            .setInferenceConfig(ClassificationConfig.EMPTY_PARAMS)
             .build();
         TrainedModelConfig trainedModel2 = TrainedModelConfigTests.createTestInstance("model_2")
             .setEstimatedHeapMemory(200)
             .setEstimatedOperations(400)
             .setMetadata(Collections.singletonMap("analytics_config", "anything"))
+            .setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
             .build();
         TrainedModelConfig trainedModel3 = TrainedModelConfigTests.createTestInstance("model_3")
             .setEstimatedHeapMemory(300)
             .setEstimatedOperations(600)
+            .setInferenceConfig(new NerConfig(null, null, null, null))
             .build();
         TrainedModelConfig trainedModel4 = TrainedModelConfigTests.createTestInstance("model_4")
             .setTags(Collections.singletonList("prepackaged"))
@@ -312,16 +332,52 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             .build();
         givenTrainedModels(Arrays.asList(trainedModel1, trainedModel2, trainedModel3, trainedModel4));
 
-        Map<String, Integer> trainedModelsCountByAnalysis = new HashMap<>();
-        trainedModelsCountByAnalysis.put("classification", 0);
-        trainedModelsCountByAnalysis.put("regression", 0);
-        for (TrainedModelConfig trainedModel : Arrays.asList(trainedModel1, trainedModel2, trainedModel3)) {
-            if (trainedModel.getInferenceConfig() instanceof ClassificationConfig) {
-                trainedModelsCountByAnalysis.put("classification", trainedModelsCountByAnalysis.get("classification") + 1);
-            } else if (trainedModel.getInferenceConfig() instanceof RegressionConfig) {
-                trainedModelsCountByAnalysis.put("regression", trainedModelsCountByAnalysis.get("regression") + 1);
-            }
-        }
+        Map<String, Integer> trainedModelsCountByAnalysis = Map.of("classification", 1, "regression", 1, "ner", 1);
+
+        givenDeploymentStats(
+            new GetDeploymentStatsAction.Response(
+                List.of(),
+                List.of(),
+                List.of(
+                    new GetDeploymentStatsAction.Response.AllocationStats(
+                        "model_3",
+                        ByteSizeValue.ofMb(100),
+                        null,
+                        null,
+                        null,
+                        Instant.now(),
+                        List.of()
+                    ).setState(AllocationState.STOPPING),
+                    new GetDeploymentStatsAction.Response.AllocationStats(
+                        "model_4",
+                        ByteSizeValue.ofMb(200),
+                        2,
+                        2,
+                        1000,
+                        Instant.now(),
+                        List.of(
+                            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                                new DiscoveryNode("foo", new TransportAddress(TransportAddress.META_ADDRESS, 2), Version.CURRENT),
+                                5,
+                                42.0,
+                                0,
+                                Instant.now(),
+                                Instant.now()
+                            ),
+                            GetDeploymentStatsAction.Response.AllocationStats.NodeStats.forStartedState(
+                                new DiscoveryNode("bar", new TransportAddress(TransportAddress.META_ADDRESS, 3), Version.CURRENT),
+                                4,
+                                50.0,
+                                0,
+                                Instant.now(),
+                                Instant.now()
+                            )
+                        )
+                    ).setState(AllocationState.STARTED).setAllocationStatus(new AllocationStatus(2, 2))
+                ),
+                2
+            )
+        );
 
         var usageAction = newUsageAction(settings.build());
         PlainActionFuture<XPackUsageFeatureResponse> future = new PlainActionFuture<>();
@@ -415,13 +471,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("inference.trained_models.estimated_operations.total"), equalTo(1200.0));
             assertThat(source.getValue("inference.trained_models.estimated_operations.avg"), equalTo(400.0));
             assertThat(source.getValue("inference.trained_models.count.total"), equalTo(4));
-            assertThat(
-                source.getValue("inference.trained_models.count.classification"),
-                equalTo(trainedModelsCountByAnalysis.get("classification"))
-            );
-            assertThat(
-                source.getValue("inference.trained_models.count.regression"),
-                equalTo(trainedModelsCountByAnalysis.get("regression"))
+            trainedModelsCountByAnalysis.forEach(
+                (name, count) -> assertThat(source.getValue("inference.trained_models.count." + name), equalTo(count))
             );
             assertThat(source.getValue("inference.trained_models.count.prepackaged"), equalTo(1));
             assertThat(source.getValue("inference.trained_models.count.other"), equalTo(1));
@@ -436,6 +487,17 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("inference.ingest_processors._all.num_failures.sum"), equalTo(1));
             assertThat(source.getValue("inference.ingest_processors._all.num_failures.min"), equalTo(0));
             assertThat(source.getValue("inference.ingest_processors._all.num_failures.max"), equalTo(1));
+            assertThat(source.getValue("inference.deployments.count"), equalTo(2));
+            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(9.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.min"), equalTo(4.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(9.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.max"), equalTo(5.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.avg"), equalTo(4.5));
+            assertThat(source.getValue("inference.deployments.model_sizes_bytes.total"), equalTo(3.145728E8));
+            assertThat(source.getValue("inference.deployments.model_sizes_bytes.min"), equalTo(1.048576E8));
+            assertThat(source.getValue("inference.deployments.model_sizes_bytes.max"), equalTo(2.097152E8));
+            assertThat(source.getValue("inference.deployments.model_sizes_bytes.avg"), equalTo(1.572864E8));
+            assertThat(source.getValue("inference.deployments.time_ms.avg"), closeTo(45.55555555555556, 1e-10));
         }
     }
 
@@ -692,6 +754,16 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
         }).when(client).execute(same(GetTrainedModelsAction.INSTANCE), any(), any());
     }
 
+    private void givenDeploymentStats(GetDeploymentStatsAction.Response deploymentStats) {
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<GetDeploymentStatsAction.Response> listener = (ActionListener<
+                GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(deploymentStats);
+            return Void.TYPE;
+        }).when(client).execute(same(GetDeploymentStatsAction.INSTANCE), any(), any());
+    }
+
     private static Detector buildMinDetector(String fieldName) {
         Detector.Builder detectorBuilder = new Detector.Builder();
         detectorBuilder.setFunction("min");