Browse Source

[ML] Include start params in _stats for non-started model deployments (#89091)

Adds the missing start parameters to the _stats API response
for non-started deployments.
Dimitris Athanasiou 3 years ago
parent
commit
77aa8c03e1

+ 5 - 0
docs/changelog/89091.yaml

@@ -0,0 +1,5 @@
+pr: 89091
+summary: Include start params in `_stats` for non-started model deployments
+area: Machine Learning
+type: bug
+issues: []

+ 65 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -813,6 +813,71 @@ public class PyTorchModelIT extends ESRestTestCase {
         assertThat(EntityUtils.toString(response.getEntity()), not(containsString("deployment_stats")));
     }
 
+    @SuppressWarnings("unchecked")
+    public void testStartDeployment_GivenNoProcessorsLeft_AndLazyStartEnabled() throws Exception {
+        // We start 2 models. The first needs so many allocations it won't possibly
+        // get them all. This would leave no space to allocate the second model at all.
+
+        // Enable lazy starting so that the deployments start even if they cannot get fully allocated.
+        // The setting is cleared in the cleanup method of these tests.
+        Request loggingSettings = new Request("PUT", "_cluster/settings");
+        loggingSettings.setJsonEntity("""
+            {"persistent" : {
+                    "xpack.ml.max_lazy_ml_nodes": 5
+                }}""");
+        client().performRequest(loggingSettings);
+
+        String modelId1 = "model_1";
+        createTrainedModel(modelId1);
+        putModelDefinition(modelId1);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId1);
+
+        String modelId2 = "model_2";
+        createTrainedModel(modelId2);
+        putModelDefinition(modelId2);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId2);
+
+        startDeployment(modelId1, AllocationStatus.State.STARTED.toString(), 100, 1);
+
+        {
+            Request request = new Request(
+                "POST",
+                "/_ml/trained_models/"
+                    + modelId2
+                    + "/deployment/_start?timeout=40s&wait_for=starting&"
+                    + "number_of_allocations=4&threads_per_allocation=2&queue_capacity=500&cache_size=100Kb"
+            );
+            client().performRequest(request);
+        }
+
+        // Check second model did not get any allocations
+        assertAllocationCount(modelId2, 0);
+
+        // Verify stats shows model is starting and deployment settings are present
+        {
+            Response statsResponse = getTrainedModelStats(modelId2);
+            var responseMap = entityAsMap(statsResponse);
+            List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
+            assertThat(stats, hasSize(1));
+            String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
+            assertThat(statusState, equalTo("starting"));
+            int numberOfAllocations = (int) XContentMapValues.extractValue("deployment_stats.number_of_allocations", stats.get(0));
+            assertThat(numberOfAllocations, equalTo(4));
+            int threadsPerAllocation = (int) XContentMapValues.extractValue("deployment_stats.threads_per_allocation", stats.get(0));
+            assertThat(threadsPerAllocation, equalTo(2));
+            int queueCapacity = (int) XContentMapValues.extractValue("deployment_stats.queue_capacity", stats.get(0));
+            assertThat(queueCapacity, equalTo(500));
+            ByteSizeValue cacheSize = ByteSizeValue.parseBytesSizeValue(
+                (String) XContentMapValues.extractValue("deployment_stats.cache_size", stats.get(0)),
+                "cache_size)"
+            );
+            assertThat(cacheSize, equalTo(ByteSizeValue.ofKb(100)));
+        }
+
+        stopDeployment(modelId1);
+        stopDeployment(modelId2);
+    }
+
     @SuppressWarnings("unchecked")
     private void assertAllocationCount(String modelId, int expectedAllocationCount) throws IOException {
         Response response = getTrainedModelStats(modelId);

+ 11 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -269,7 +269,17 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
 
                 nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
 
-                updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, null, assignment.getStartTime(), nodeStats));
+                updatedAssignmentStats.add(
+                    new AssignmentStats(
+                        modelId,
+                        assignment.getTaskParams().getThreadsPerAllocation(),
+                        assignment.getTaskParams().getNumberOfAllocations(),
+                        assignment.getTaskParams().getQueueCapacity(),
+                        assignment.getTaskParams().getCacheSize().orElse(null),
+                        assignment.getStartTime(),
+                        nodeStats
+                    )
+                );
             }
         }