Browse Source

[ML] Add deployment threading details and memory usage to telemetry (#113099) (#113516)

Adds deployment threading options and a new memory section reporting
the memory usage for each of the ML features.
# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
David Kyle 1 year ago
parent
commit
cc3caa228d

+ 7 - 1
docs/reference/rest-api/usage.asciidoc

@@ -195,7 +195,13 @@ GET /_xpack/usage
         }
       }
     },
-    "node_count" : 1
+    "node_count" : 1,
+    "memory": {
+      "anomaly_detectors_memory_bytes": 0,
+      "data_frame_analytics_memory_bytes": 0,
+      "pytorch_inference_memory_bytes": 0,
+      "total_used_memory_bytes": 0
+    }
   },
   "inference": {
     "available" : true,

+ 54 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java

@@ -31,11 +31,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
     public static final String NODE_COUNT = "node_count";
     public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
     public static final String INFERENCE_FIELD = "inference";
+    public static final String MEMORY_FIELD = "memory";
 
     private final Map<String, Object> jobsUsage;
     private final Map<String, Object> datafeedsUsage;
     private final Map<String, Object> analyticsUsage;
     private final Map<String, Object> inferenceUsage;
+    private final Map<String, Object> memoryUsage;
     private final int nodeCount;
 
     public MachineLearningFeatureSetUsage(
@@ -45,6 +47,7 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
         Map<String, Object> datafeedsUsage,
         Map<String, Object> analyticsUsage,
         Map<String, Object> inferenceUsage,
+        Map<String, Object> memoryUsage,
         int nodeCount
     ) {
         super(XPackField.MACHINE_LEARNING, available, enabled);
@@ -52,6 +55,7 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
         this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
         this.analyticsUsage = Objects.requireNonNull(analyticsUsage);
         this.inferenceUsage = Objects.requireNonNull(inferenceUsage);
+        this.memoryUsage = Objects.requireNonNull(memoryUsage);
         this.nodeCount = nodeCount;
     }
 
@@ -62,6 +66,11 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
         this.analyticsUsage = in.readGenericMap();
         this.inferenceUsage = in.readGenericMap();
         this.nodeCount = in.readInt();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) {
+            this.memoryUsage = in.readGenericMap();
+        } else {
+            this.memoryUsage = Map.of();
+        }
     }
 
     @Override
@@ -77,6 +86,9 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
         out.writeGenericMap(analyticsUsage);
         out.writeGenericMap(inferenceUsage);
         out.writeInt(nodeCount);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) {
+            out.writeGenericMap(memoryUsage);
+        }
     }
 
     @Override
@@ -86,9 +98,51 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
         builder.field(DATAFEEDS_FIELD, datafeedsUsage);
         builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
         builder.field(INFERENCE_FIELD, inferenceUsage);
+        builder.field(MEMORY_FIELD, memoryUsage);
         if (nodeCount >= 0) {
             builder.field(NODE_COUNT, nodeCount);
         }
     }
 
+    /**
+     * Returns the jobs usage map held by this instance. The map itself is returned, not a copy.
+     */
+    public Map<String, Object> getJobsUsage() {
+        return jobsUsage;
+    }
+
+    /**
+     * Returns the datafeeds usage map held by this instance. The map itself is returned, not a copy.
+     */
+    public Map<String, Object> getDatafeedsUsage() {
+        return datafeedsUsage;
+    }
+
+    /**
+     * Returns the data frame analytics usage map held by this instance. The map itself is returned, not a copy.
+     */
+    public Map<String, Object> getAnalyticsUsage() {
+        return analyticsUsage;
+    }
+
+    /**
+     * Returns the inference usage map held by this instance. The map itself is returned, not a copy.
+     */
+    public Map<String, Object> getInferenceUsage() {
+        return inferenceUsage;
+    }
+
+    /**
+     * Returns the memory usage map held by this instance. The map itself is returned, not a copy.
+     * It is empty when the instance was read from a stream older than the version that added it.
+     */
+    public Map<String, Object> getMemoryUsage() {
+        return memoryUsage;
+    }
+
+    /**
+     * Returns the node count. Negative values are not rendered in the response.
+     */
+    public int getNodeCount() {
+        return nodeCount;
+    }
+
+    /**
+     * Compares this usage with another object. Two usages are equal when they are of the same class,
+     * report the same available and enabled flags and the same node count, and have equal jobs,
+     * datafeeds, analytics, inference and memory usage maps.
+     *
+     * @param o the object to compare with, may be {@code null}
+     * @return {@code true} if {@code o} is an equal usage
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        MachineLearningFeatureSetUsage that = (MachineLearningFeatureSetUsage) o;
+        // available and enabled are passed to the superclass constructor and are part of this object's
+        // state, so they must take part in equality alongside the fields declared here
+        return available() == that.available()
+            && enabled() == that.enabled()
+            && nodeCount == that.nodeCount
+            && Objects.equals(jobsUsage, that.jobsUsage)
+            && Objects.equals(datafeedsUsage, that.datafeedsUsage)
+            && Objects.equals(analyticsUsage, that.analyticsUsage)
+            && Objects.equals(inferenceUsage, that.inferenceUsage)
+            && Objects.equals(memoryUsage, that.memoryUsage);
+    }
+
+    /**
+     * Hashes exactly the values that {@link #equals(Object)} compares.
+     */
+    @Override
+    public int hashCode() {
+        return Objects.hash(available(), enabled(), jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, memoryUsage, nodeCount);
+    }
 }

+ 75 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java

@@ -0,0 +1,75 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Tuple;
+
+import java.io.IOException;
+import java.util.Collections;
+
+/**
+ * Wire serialization tests for {@link MachineLearningFeatureSetUsage}, including reading from and
+ * writing to transport versions that predate the memory usage section.
+ */
+public class MachineLearningFeatureSetUsageTests extends AbstractBWCWireSerializationTestCase<MachineLearningFeatureSetUsage> {
+    @Override
+    protected Writeable.Reader<MachineLearningFeatureSetUsage> instanceReader() {
+        return MachineLearningFeatureSetUsage::new;
+    }
+
+    /**
+     * Creates either a disabled usage with empty sections and a zero node count, or an enabled usage
+     * with small random maps for every section and a node count between 1 and 10.
+     */
+    @Override
+    protected MachineLearningFeatureSetUsage createTestInstance() {
+        boolean enabled = randomBoolean();
+
+        if (enabled == false) {
+            return new MachineLearningFeatureSetUsage(
+                randomBoolean(),
+                enabled,
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                Collections.emptyMap(),
+                0
+            );
+        } else {
+            return new MachineLearningFeatureSetUsage(
+                randomBoolean(),
+                enabled,
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomMap(0, 4, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
+                randomIntBetween(1, 10)
+            );
+        }
+    }
+
+    /**
+     * Returns a copy of {@code instance} that differs only in its node count, so that equals and
+     * hashCode are exercised against a real, unequal mutation instead of {@code null}.
+     */
+    @Override
+    protected MachineLearningFeatureSetUsage mutateInstance(MachineLearningFeatureSetUsage instance) throws IOException {
+        // Bumping the node count alone guarantees the result is not equal to the original
+        return new MachineLearningFeatureSetUsage(
+            instance.available(),
+            instance.enabled(),
+            instance.getJobsUsage(),
+            instance.getDatafeedsUsage(),
+            instance.getAnalyticsUsage(),
+            instance.getInferenceUsage(),
+            instance.getMemoryUsage(),
+            instance.getNodeCount() + 1
+        );
+    }
+
+    /**
+     * Versions before {@code ML_TELEMETRY_MEMORY_ADDED} do not carry the memory usage map, so the
+     * expected instance for those versions has an empty memory section.
+     */
+    @Override
+    protected MachineLearningFeatureSetUsage mutateInstanceForVersion(MachineLearningFeatureSetUsage instance, TransportVersion version) {
+        if (version.before(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) {
+            return new MachineLearningFeatureSetUsage(
+                instance.available(),
+                instance.enabled(),
+                instance.getJobsUsage(),
+                instance.getDatafeedsUsage(),
+                instance.getAnalyticsUsage(),
+                instance.getInferenceUsage(),
+                Collections.emptyMap(),
+                instance.getNodeCount()
+            );
+        }
+
+        return instance;
+    }
+}

+ 22 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -1120,6 +1120,28 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         }
     }
 
+    /**
+     * Starts a deployment and checks that the ML usage response reports at least one thread and one
+     * allocation for every entry in {@code ml.inference.deployments.stats_by_model}.
+     */
+    @SuppressWarnings("unchecked")
+    public void testDeploymentThreadsIncludedInUsage() throws IOException {
+        String modelId = "deployment_threads_in_usage";
+        createPassThroughModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+
+        Request request = new Request("GET", "/_xpack/usage");
+        var usage = entityAsMap(client().performRequest(request).getEntity());
+
+        var ml = (Map<String, Object>) usage.get("ml");
+        assertNotNull(usage.toString(), ml);
+        var inference = (Map<String, Object>) ml.get("inference");
+        assertNotNull(ml.toString(), inference);
+        var deployments = (Map<String, Object>) inference.get("deployments");
+        assertNotNull(inference.toString(), deployments);
+        var deploymentStats = (List<Map<String, Object>>) deployments.get("stats_by_model");
+        assertNotNull(deployments.toString(), deploymentStats);
+        // Without this check the loop below would pass vacuously if the deployment were not reported
+        assertFalse(deployments.toString(), deploymentStats.isEmpty());
+        for (var stat : deploymentStats) {
+            assertThat(stat.toString(), (Integer) stat.get("num_threads"), greaterThanOrEqualTo(1));
+            assertThat(stat.toString(), (Integer) stat.get("num_allocations"), greaterThanOrEqualTo(1));
+        }
+    }
+
     private void putModelDefinition(String modelId) throws IOException {
         putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
     }

+ 35 - 0
x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlUsageIT.java

@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.test.rest.ESRestTestCase;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+
+/**
+ * Tests the phone home/telemetry data reported for ML by the usage API.
+ */
+public class MlUsageIT extends ESRestTestCase {
+
+    /**
+     * Checks that {@code GET /_xpack/usage} returns an {@code ml.memory} section in which every
+     * expected byte count is present and non-negative.
+     */
+    @SuppressWarnings("unchecked")
+    public void testMLUsage() throws IOException {
+        Request request = new Request("GET", "/_xpack/usage");
+        var usage = entityAsMap(client().performRequest(request).getEntity());
+
+        var ml = (Map<String, Object>) usage.get("ml");
+        assertNotNull(usage.toString(), ml);
+        var memoryUsage = (Map<String, Object>) ml.get("memory");
+        assertNotNull(ml.toString(), memoryUsage);
+        assertNonNegativeBytes(memoryUsage, "anomaly_detectors_memory_bytes");
+        assertNonNegativeBytes(memoryUsage, "data_frame_analytics_memory_bytes");
+        assertNonNegativeBytes(memoryUsage, "pytorch_inference_memory_bytes");
+        assertNonNegativeBytes(memoryUsage, "total_used_memory_bytes");
+    }
+
+    /**
+     * Asserts that {@code field} is present in {@code memoryUsage} and holds a non-negative number.
+     *
+     * @param memoryUsage the parsed {@code ml.memory} section of the usage response
+     * @param field       the name of the byte count to check
+     */
+    private static void assertNonNegativeBytes(Map<String, Object> memoryUsage, String field) {
+        Object value = memoryUsage.get(field);
+        assertNotNull("[" + field + "] is missing from " + memoryUsage, value);
+        // The parsed value is not guaranteed to be an Integer (large byte counts are expected to be
+        // parsed as Long - TODO confirm), so compare as a long rather than casting to Integer
+        assertThat(memoryUsage.toString(), ((Number) value).longValue(), greaterThanOrEqualTo(0L));
+    }
+}

+ 71 - 26
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java

@@ -39,6 +39,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.MlMemoryAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
@@ -65,6 +66,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.TreeMap;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -72,16 +74,20 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 
 public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransportAction {
 
-    private static class ModelStats {
+    private static class DeploymentStats {
 
         private final String modelId;
         private final String taskType;
         private final StatsAccumulator inferenceCounts = new StatsAccumulator();
         private Instant lastAccess;
+        private final int numThreads;
+        private final int numAllocations;
 
-        ModelStats(String modelId, String taskType) {
+        DeploymentStats(String modelId, String taskType, int numThreads, int numAllocations) {
             this.modelId = modelId;
             this.taskType = taskType;
+            this.numThreads = numThreads;
+            this.numAllocations = numAllocations;
         }
 
         void update(AssignmentStats.NodeStats stats) {
@@ -95,6 +101,8 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
             Map<String, Object> result = new HashMap<>();
             result.put("model_id", modelId);
             result.put("task_type", taskType);
+            result.put("num_allocations", numAllocations);
+            result.put("num_threads", numThreads);
             result.put("inference_counts", inferenceCounts.asMap());
             if (lastAccess != null) {
                 result.put("last_access", lastAccess.toString());
@@ -158,6 +166,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
                 Collections.emptyMap(),
                 Collections.emptyMap(),
                 Collections.emptyMap(),
+                Collections.emptyMap(),
                 0
             );
             listener.onResponse(new XPackUsageFeatureResponse(usage));
@@ -167,11 +176,14 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
         Map<String, Object> jobsUsage = new LinkedHashMap<>();
         Map<String, Object> datafeedsUsage = new LinkedHashMap<>();
         Map<String, Object> analyticsUsage = new LinkedHashMap<>();
+        AtomicReference<Map<String, Object>> inferenceUsage = new AtomicReference<>(Map.of());
+
         int nodeCount = mlNodeCount(state);
 
-        // Step 5. return final ML usage
-        ActionListener<Map<String, Object>> inferenceUsageListener = ActionListener.wrap(
-            inferenceUsage -> listener.onResponse(
+        // Step 6. return final ML usage
+        ActionListener<MlMemoryAction.Response> memoryUsageListener = ActionListener.wrap(memoryResponse -> {
+            var memoryUsage = extractMemoryUsage(memoryResponse);
+            listener.onResponse(
                 new XPackUsageFeatureResponse(
                     new MachineLearningFeatureSetUsage(
                         MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
@@ -179,28 +191,38 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
                         jobsUsage,
                         datafeedsUsage,
                         analyticsUsage,
-                        inferenceUsage,
+                        inferenceUsage.get(),
+                        memoryUsage,
                         nodeCount
                     )
                 )
-            ),
-            e -> {
-                logger.warn("Failed to get inference usage to include in ML usage", e);
-                listener.onResponse(
-                    new XPackUsageFeatureResponse(
-                        new MachineLearningFeatureSetUsage(
-                            MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
-                            enabled,
-                            jobsUsage,
-                            datafeedsUsage,
-                            analyticsUsage,
-                            Collections.emptyMap(),
-                            nodeCount
-                        )
+            );
+        }, e -> {
+            logger.warn("Failed to get memory usage to include in ML usage", e);
+            listener.onResponse(
+                new XPackUsageFeatureResponse(
+                    new MachineLearningFeatureSetUsage(
+                        MachineLearningField.ML_API_FEATURE.checkWithoutTracking(licenseState),
+                        enabled,
+                        jobsUsage,
+                        datafeedsUsage,
+                        analyticsUsage,
+                        inferenceUsage.get(),
+                        Collections.emptyMap(),
+                        nodeCount
                     )
-                );
-            }
-        );
+                )
+            );
+        });
+
+        // Step 5. Record the inference usage (left empty if it could not be fetched) and then get the memory usage
+        ActionListener<Map<String, Object>> inferenceUsageListener = ActionListener.wrap(inference -> {
+            inferenceUsage.set(inference);
+            client.execute(MlMemoryAction.INSTANCE, new MlMemoryAction.Request("_all"), memoryUsageListener);
+        }, e -> {
+            logger.warn("Failed to get inference usage to include in ML usage", e);
+            client.execute(MlMemoryAction.INSTANCE, new MlMemoryAction.Request("_all"), memoryUsageListener);
+        });
 
         // Step 4. Extract usage from data frame analytics configs and then get inference usage
         ActionListener<GetDataFrameAnalyticsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(response -> {
@@ -464,7 +486,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
         int deploymentsCount = 0;
         double avgTimeSum = 0.0;
         StatsAccumulator nodeDistribution = new StatsAccumulator();
-        Map<String, ModelStats> statsByModel = new TreeMap<>();
+        Map<String, DeploymentStats> statsByModel = new TreeMap<>();
         for (var stats : statsResponse.getResources().results()) {
             AssignmentStats deploymentStats = stats.getDeploymentStats();
             if (deploymentStats == null) {
@@ -478,7 +500,15 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
             String modelId = deploymentStats.getModelId();
             String taskType = taskTypes.get(deploymentStats.getModelId());
             String mapKey = modelId + ":" + taskType;
-            ModelStats modelStats = statsByModel.computeIfAbsent(mapKey, key -> new ModelStats(modelId, taskType));
+            DeploymentStats modelStats = statsByModel.computeIfAbsent(
+                mapKey,
+                key -> new DeploymentStats(
+                    modelId,
+                    taskType,
+                    deploymentStats.getThreadsPerAllocation(),
+                    deploymentStats.getNumberOfAllocations()
+                )
+            );
             for (var nodeStats : deploymentStats.getNodeStats()) {
                 long nodeInferenceCount = nodeStats.getInferenceCount().orElse(0L);
                 avgTimeSum += nodeStats.getAvgInferenceTime().orElse(0.0) * nodeInferenceCount;
@@ -499,7 +529,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
                 "inference_counts",
                 nodeDistribution.asMap(),
                 "stats_by_model",
-                statsByModel.values().stream().map(ModelStats::asMap).collect(Collectors.toList())
+                statsByModel.values().stream().map(DeploymentStats::asMap).collect(Collectors.toList())
             )
         );
     }
@@ -590,6 +620,21 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
         inferenceUsage.put("ingest_processors", Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, ingestUsage));
     }
 
+    /**
+     * Sums the per-node ML memory figures of the response into totals across all reported nodes.
+     *
+     * @param memoryResponse the ML memory response to aggregate
+     * @return an insertion-ordered map with the anomaly detector, data frame analytics and PyTorch
+     *         inference byte counts, followed by the total used memory
+     */
+    private static Map<String, Object> extractMemoryUsage(MlMemoryAction.Response memoryResponse) {
+        long anomalyDetectorsBytes = 0;
+        long dataFrameAnalyticsBytes = 0;
+        long nativeInferenceBytes = 0;
+        long nativeCodeOverheadBytes = 0;
+        for (var nodeMemory : memoryResponse.getNodes()) {
+            anomalyDetectorsBytes += nodeMemory.getMlAnomalyDetectors().getBytes();
+            dataFrameAnalyticsBytes += nodeMemory.getMlDataFrameAnalytics().getBytes();
+            nativeInferenceBytes += nodeMemory.getMlNativeInference().getBytes();
+            nativeCodeOverheadBytes += nodeMemory.getMlNativeCodeOverhead().getBytes();
+        }
+
+        Map<String, Object> result = new LinkedHashMap<>();
+        result.put("anomaly_detectors_memory_bytes", anomalyDetectorsBytes);
+        result.put("data_frame_analytics_memory_bytes", dataFrameAnalyticsBytes);
+        result.put("pytorch_inference_memory_bytes", nativeInferenceBytes);
+        // The total also counts the native code overhead, which has no entry of its own
+        result.put(
+            "total_used_memory_bytes",
+            anomalyDetectorsBytes + dataFrameAnalyticsBytes + nativeInferenceBytes + nativeCodeOverheadBytes
+        );
+        return result;
+    }
+
     private static Map<String, Object> getMinMaxSumAsLongsFromStats(StatsAccumulator stats) {
         Map<String, Object> asMap = Maps.newMapWithExpectedSize(3);
         asMap.put("sum", Double.valueOf(stats.getTotal()).longValue());

+ 74 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -10,9 +10,11 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -46,6 +48,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.MlMemoryAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -134,6 +137,27 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                 new QueryPage<>(Collections.emptyList(), 0, GetTrainedModelsStatsAction.Response.RESULTS_FIELD)
             )
         );
+        givenMlMemory(
+            new MlMemoryAction.Response(
+                new ClusterName("cluster_foo"),
+                List.of(
+                    new MlMemoryAction.Response.MlMemoryStats(
+                        mock(DiscoveryNode.class),
+                        ByteSizeValue.ofBytes(100L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(20L),
+                        ByteSizeValue.ofBytes(30L),
+                        ByteSizeValue.ofBytes(40L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L)
+                    )
+                ),
+                List.of()
+            )
+        );
     }
 
     @After
@@ -343,6 +367,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("inference.deployments.inference_counts.avg"), equalTo(4.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.model_id"), equalTo("model_3"));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.task_type"), equalTo("ner"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.num_allocations"), equalTo(8));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.num_threads"), equalTo(1));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.last_access"), equalTo(lastAccess(3).toString()));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.total"), equalTo(3.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.min"), equalTo(3.0));
@@ -350,6 +376,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.avg"), equalTo(3.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.model_id"), equalTo("model_4"));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.task_type"), equalTo("text_expansion"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.num_allocations"), equalTo(2));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.num_threads"), equalTo(2));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.last_access"), equalTo(lastAccess(44).toString()));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.total"), equalTo(9.0));
             assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.min"), equalTo(4.0));
@@ -360,6 +388,11 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.max"), equalTo(1000.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.avg"), equalTo(650.0));
             assertThat(source.getValue("inference.deployments.time_ms.avg"), closeTo(44.0, 1e-10));
+
+            assertThat(source.getValue("memory.anomaly_detectors_memory_bytes"), equalTo(20));
+            assertThat(source.getValue("memory.data_frame_analytics_memory_bytes"), equalTo(30));
+            assertThat(source.getValue("memory.pytorch_inference_memory_bytes"), equalTo(40));
+            assertThat(source.getValue("memory.total_used_memory_bytes"), equalTo(91));
         }
     }
 
@@ -566,6 +599,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
         Job closed1 = buildJob("closed1", Arrays.asList(buildMinDetector("foo"), buildMinDetector("bar"), buildMinDetector("foobar")));
         GetJobsStatsAction.Response.JobStats closed1JobStats = buildJobStats("closed1", JobState.CLOSED, 300L, 0);
         givenJobs(Arrays.asList(opened1, closed1), Arrays.asList(opened1JobStats, opened2JobStats, closed1JobStats));
+        MlMemoryAction.Response memory = new MlMemoryAction.Response(new ClusterName("foo"), List.of(), List.of());
+        givenMlMemory(memory);
 
         var usageAction = newUsageAction(settings.build(), true, true, true);
         PlainActionFuture<XPackUsageFeatureResponse> future = new PlainActionFuture<>();
@@ -590,6 +625,11 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
         assertThat(source.getValue("jobs._all.model_size.avg"), equalTo(200.0));
         assertThat(source.getValue("jobs._all.created_by.a_cool_module"), equalTo(1));
         assertThat(source.getValue("jobs._all.created_by.unknown"), equalTo(1));
+
+        assertThat(source.getValue("memory.anomaly_detectors_memory_bytes"), equalTo(0));
+        assertThat(source.getValue("memory.data_frame_analytics_memory_bytes"), equalTo(0));
+        assertThat(source.getValue("memory.pytorch_inference_memory_bytes"), equalTo(0));
+        assertThat(source.getValue("memory.total_used_memory_bytes"), equalTo(0));
     }
 
     public void testUsageDisabledML() throws Exception {
@@ -802,6 +842,15 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
         }).when(client).execute(same(GetTrainedModelsStatsAction.INSTANCE), any(), any());
     }
 
+    /**
+     * Stubs the test client so that executing {@link MlMemoryAction} immediately completes the
+     * supplied listener with the given response.
+     *
+     * @param memoryUsage the response handed to the listener
+     */
+    private void givenMlMemory(MlMemoryAction.Response memoryUsage) {
+        doAnswer(invocation -> {
+            // The listener is the third argument of client.execute(action, request, listener)
+            ActionListener<MlMemoryAction.Response> responseListener = invocation.getArgument(2);
+            responseListener.onResponse(memoryUsage);
+            return Void.TYPE;
+        }).when(client).execute(same(MlMemoryAction.INSTANCE), any(), any());
+    }
+
     private static Detector buildMinDetector(String fieldName) {
         Detector.Builder detectorBuilder = new Detector.Builder();
         detectorBuilder.setFunction("min");
@@ -1004,8 +1053,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                             new AssignmentStats(
                                 "deployment_3",
                                 "model_3",
-                                null,
-                                null,
+                                1,
+                                8,
                                 null,
                                 null,
                                 null,
@@ -1111,6 +1160,29 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                 )
             )
         );
+
+        givenMlMemory(
+            new MlMemoryAction.Response(
+                new ClusterName("cluster_foo"),
+                List.of(
+                    new MlMemoryAction.Response.MlMemoryStats(
+                        mock(DiscoveryNode.class),
+                        ByteSizeValue.ofBytes(100L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(20L),
+                        ByteSizeValue.ofBytes(30L),
+                        ByteSizeValue.ofBytes(40L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L),
+                        ByteSizeValue.ofBytes(1L)
+                    )
+                ),
+                List.of()
+            )
+        );
+
         return expectedDfaCountByAnalysis;
     }