Browse Source

Add stats by model to the machine learning usage stats. (#101915)

* Add inference counts by NLP model to the machine learning usage stats.

* Update docs/changelog/101915.yaml

* Add inference_counts_by_model to yamlRestTest.

* Strip leading dot from internal model IDs.

* Add last access and task type to the stats by model.

* Change stats_by_model for map to list

* Simplify code.

* Fix style
Jan Kuipers 1 year ago
parent
commit
2e95b992b2

+ 5 - 0
docs/changelog/101915.yaml

@@ -0,0 +1,5 @@
+pr: 101915
+summary: Add inference counts by model to the machine learning usage stats
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 0
docs/reference/rest-api/usage.asciidoc

@@ -183,6 +183,7 @@ GET /_xpack/usage
           "avg": 0.0,
           "max": 0.0
         },
+        "stats_by_model": [],
         "model_sizes_bytes": {
           "total": 0.0,
           "min": 0.0,

+ 52 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java

@@ -55,12 +55,14 @@ import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.job.JobManagerHolder;
 
+import java.time.Instant;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.TreeMap;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -68,6 +70,37 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 
 public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransportAction {
 
+    private static class ModelStats {
+
+        private final String modelId;
+        private final String taskType;
+        private final StatsAccumulator inferenceCounts = new StatsAccumulator();
+        private Instant lastAccess;
+
+        ModelStats(String modelId, String taskType) {
+            this.modelId = modelId;
+            this.taskType = taskType;
+        }
+
+        void update(AssignmentStats.NodeStats stats) {
+            inferenceCounts.add(stats.getInferenceCount().orElse(0L));
+            if (stats.getLastAccess() != null && (lastAccess == null || stats.getLastAccess().isAfter(lastAccess))) {
+                lastAccess = stats.getLastAccess();
+            }
+        }
+
+        Map<String, Object> asMap() {
+            Map<String, Object> result = new HashMap<>();
+            result.put("model_id", modelId);
+            result.put("task_type", taskType);
+            result.put("inference_counts", inferenceCounts.asMap());
+            if (lastAccess != null) {
+                result.put("last_access", lastAccess.toString());
+            }
+            return result;
+        }
+    }
+
     private static final Logger logger = LogManager.getLogger(MachineLearningUsageTransportAction.class);
 
     private final Client client;
@@ -399,7 +432,7 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
                     Map<String, Object> inferenceUsage = new LinkedHashMap<>();
                     addInferenceIngestUsage(getStatsResponse, inferenceUsage);
                     addTrainedModelStats(getModelsResponse, getStatsResponse, inferenceUsage);
-                    addDeploymentStats(getStatsResponse, inferenceUsage);
+                    addDeploymentStats(getModelsResponse, getStatsResponse, inferenceUsage);
                     listener.onResponse(inferenceUsage);
                 }, listener::onFailure));
             }, listener::onFailure));
@@ -408,11 +441,20 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
         }
     }
 
-    private static void addDeploymentStats(GetTrainedModelsStatsAction.Response statsResponse, Map<String, Object> inferenceUsage) {
+    private static void addDeploymentStats(
+        GetTrainedModelsAction.Response modelsResponse,
+        GetTrainedModelsStatsAction.Response statsResponse,
+        Map<String, Object> inferenceUsage
+    ) {
+        Map<String, String> taskTypes = modelsResponse.getResources()
+            .results()
+            .stream()
+            .collect(Collectors.toMap(TrainedModelConfig::getModelId, cfg -> cfg.getInferenceConfig().getName()));
         StatsAccumulator modelSizes = new StatsAccumulator();
         int deploymentsCount = 0;
         double avgTimeSum = 0.0;
         StatsAccumulator nodeDistribution = new StatsAccumulator();
+        Map<String, ModelStats> statsByModel = new TreeMap<>();
         for (var stats : statsResponse.getResources().results()) {
             AssignmentStats deploymentStats = stats.getDeploymentStats();
             if (deploymentStats == null) {
@@ -423,10 +465,15 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
             if (modelSizeStats != null) {
                 modelSizes.add(modelSizeStats.getModelSizeBytes());
             }
+            String modelId = deploymentStats.getModelId();
+            String taskType = taskTypes.get(deploymentStats.getModelId());
+            String mapKey = modelId + ":" + taskType;
+            ModelStats modelStats = statsByModel.computeIfAbsent(mapKey, key -> new ModelStats(modelId, taskType));
             for (var nodeStats : deploymentStats.getNodeStats()) {
                 long nodeInferenceCount = nodeStats.getInferenceCount().orElse(0L);
                 avgTimeSum += nodeStats.getAvgInferenceTime().orElse(0.0) * nodeInferenceCount;
                 nodeDistribution.add(nodeInferenceCount);
+                modelStats.update(nodeStats);
             }
         }
 
@@ -440,7 +487,9 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
                 "model_sizes_bytes",
                 modelSizes.asMap(),
                 "inference_counts",
-                nodeDistribution.asMap()
+                nodeDistribution.asMap(),
+                "stats_by_model",
+                statsByModel.values().stream().map(ModelStats::asMap).collect(Collectors.toList())
             )
         );
     }

+ 67 - 15
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -61,6 +61,7 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -336,16 +337,29 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("inference.ingest_processors._all.num_failures.min"), equalTo(100));
             assertThat(source.getValue("inference.ingest_processors._all.num_failures.max"), equalTo(500));
             assertThat(source.getValue("inference.deployments.count"), equalTo(2));
-            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(9.0));
-            assertThat(source.getValue("inference.deployments.inference_counts.min"), equalTo(4.0));
-            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(9.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(12.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.min"), equalTo(3.0));
             assertThat(source.getValue("inference.deployments.inference_counts.max"), equalTo(5.0));
-            assertThat(source.getValue("inference.deployments.inference_counts.avg"), equalTo(4.5));
+            assertThat(source.getValue("inference.deployments.inference_counts.avg"), equalTo(4.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.model_id"), equalTo("model_3"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.task_type"), equalTo("ner"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.last_access"), equalTo(lastAccess(3).toString()));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.total"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.min"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.max"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.avg"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.model_id"), equalTo("model_4"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.task_type"), equalTo("text_expansion"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.last_access"), equalTo(lastAccess(44).toString()));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.total"), equalTo(9.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.min"), equalTo(4.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.max"), equalTo(5.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.avg"), equalTo(4.5));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.total"), equalTo(1300.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.min"), equalTo(300.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.max"), equalTo(1000.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.avg"), equalTo(650.0));
-            assertThat(source.getValue("inference.deployments.time_ms.avg"), closeTo(45.55555555555556, 1e-10));
+            assertThat(source.getValue("inference.deployments.time_ms.avg"), closeTo(44.0, 1e-10));
         }
     }
 
@@ -421,16 +435,29 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("inference.ingest_processors._all.num_failures.min"), equalTo(100));
             assertThat(source.getValue("inference.ingest_processors._all.num_failures.max"), equalTo(500));
             assertThat(source.getValue("inference.deployments.count"), equalTo(2));
-            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(9.0));
-            assertThat(source.getValue("inference.deployments.inference_counts.min"), equalTo(4.0));
-            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(9.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.total"), equalTo(12.0));
+            assertThat(source.getValue("inference.deployments.inference_counts.min"), equalTo(3.0));
             assertThat(source.getValue("inference.deployments.inference_counts.max"), equalTo(5.0));
-            assertThat(source.getValue("inference.deployments.inference_counts.avg"), equalTo(4.5));
+            assertThat(source.getValue("inference.deployments.inference_counts.avg"), equalTo(4.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.model_id"), equalTo("model_3"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.task_type"), equalTo("ner"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.last_access"), equalTo(lastAccess(3).toString()));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.total"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.min"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.max"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.0.inference_counts.avg"), equalTo(3.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.model_id"), equalTo("model_4"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.task_type"), equalTo("text_expansion"));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.last_access"), equalTo(lastAccess(44).toString()));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.total"), equalTo(9.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.min"), equalTo(4.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.max"), equalTo(5.0));
+            assertThat(source.getValue("inference.deployments.stats_by_model.1.inference_counts.avg"), equalTo(4.5));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.total"), equalTo(1300.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.min"), equalTo(300.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.max"), equalTo(1000.0));
             assertThat(source.getValue("inference.deployments.model_sizes_bytes.avg"), equalTo(650.0));
-            assertThat(source.getValue("inference.deployments.time_ms.avg"), closeTo(45.55555555555556, 1e-10));
+            assertThat(source.getValue("inference.deployments.time_ms.avg"), closeTo(44.0, 1e-10));
         }
     }
 
@@ -898,6 +925,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             .setTags(Collections.singletonList("prepackaged"))
             .setModelSize(1000)
             .setEstimatedOperations(2000)
+            .setInferenceConfig(new TextExpansionConfig(null, null, null))
             .build();
         givenTrainedModels(Arrays.asList(trainedModel1, trainedModel2, trainedModel3, trainedModel4));
 
@@ -981,7 +1009,27 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                 null,
                                 null,
                                 Instant.now(),
-                                List.of(),
+                                List.of(
+                                    AssignmentStats.NodeStats.forStartedState(
+                                        DiscoveryNodeUtils.create("foo", new TransportAddress(TransportAddress.META_ADDRESS, 2)),
+                                        3,
+                                        41.0,
+                                        41.0,
+                                        0,
+                                        1,
+                                        3L,
+                                        2,
+                                        3,
+                                        lastAccess(3),
+                                        Instant.now(),
+                                        randomIntBetween(1, 16),
+                                        randomIntBetween(1, 16),
+                                        1L,
+                                        2L,
+                                        33.0,
+                                        1L
+                                    )
+                                ),
                                 Priority.NORMAL
                             ).setState(AssignmentState.STOPPING)
                         ),
@@ -1016,14 +1064,14 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                     AssignmentStats.NodeStats.forStartedState(
                                         DiscoveryNodeUtils.create("foo", new TransportAddress(TransportAddress.META_ADDRESS, 2)),
                                         5,
-                                        42.0,
-                                        42.0,
+                                        41.0,
+                                        41.0,
                                         0,
                                         1,
                                         3L,
                                         2,
                                         3,
-                                        Instant.now(),
+                                        lastAccess(4),
                                         Instant.now(),
                                         randomIntBetween(1, 16),
                                         randomIntBetween(1, 16),
@@ -1042,7 +1090,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         1L,
                                         2,
                                         3,
-                                        Instant.now(),
+                                        lastAccess(44),
                                         Instant.now(),
                                         randomIntBetween(1, 16),
                                         randomIntBetween(1, 16),
@@ -1063,4 +1111,8 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
         );
         return expectedDfaCountByAnalysis;
     }
+
+    private static Instant lastAccess(int i) {
+        return Instant.ofEpochSecond(1_000_000_000 + i);
+    }
 }