Quellcode durchsuchen

[ML] Gracefully shutdown model deployment when node is removed from assignment routing (#134673) (#135437)

* Initial fix for graceful shutdown for unreferenced nodes

* Update docs/changelog/134673.yaml

* [CI] Auto commit changes from spotless

* Fixing cluster change test and flaky tests

* [CI] Auto commit changes from spotless

* Addressing feedback

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Jonathan Buttner vor 1 Woche
Ursprung
Commit
a5d4b21601

+ 6 - 0
docs/changelog/134673.yaml

@@ -0,0 +1,6 @@
+pr: 134673
+summary: Gracefully shutdown model deployment when node is removed from assignment
+  routing
+area: Machine Learning
+type: bug
+issues: []

+ 16 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

@@ -523,7 +523,7 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
         if (task == null) {
             logger.debug(
                 () -> format(
-                    "[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exit",
+                    "[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exist",
                     deploymentId,
                     currentNode
                 )
@@ -547,7 +547,7 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
             routingStateListener
         );
 
-        stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped);
+        stopDeploymentAfterCompletingPendingWorkAsync(task, NODE_IS_SHUTTING_DOWN, notifyDeploymentOfStopped);
     }
 
     private ActionListener<Void> updateRoutingStateToStoppedListener(
@@ -573,11 +573,18 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
         // This model is not routed to the current node at all
         TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId);
         if (task == null) {
+            logger.debug(
+                () -> format(
+                    "[%s] Unable to stop unreferenced deployment for node %s because task does not exist",
+                    deploymentId,
+                    currentNode
+                )
+            );
             return;
         }
 
         logger.debug(() -> format("[%s] Stopping unreferenced deployment for node %s", deploymentId, currentNode));
-        stopDeploymentAsync(
+        stopDeploymentAfterCompletingPendingWorkAsync(
             task,
             NODE_NO_LONGER_REFERENCED,
             ActionListener.wrap(
@@ -614,8 +621,12 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
         });
     }
 
-    private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener<Void> listener) {
-        stopDeploymentHelper(task, NODE_IS_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener);
+    private void stopDeploymentAfterCompletingPendingWorkAsync(
+        TrainedModelDeploymentTask task,
+        String reason,
+        ActionListener<Void> listener
+    ) {
+        stopDeploymentHelper(task, reason, deploymentManager::stopAfterCompletingPendingWork, listener);
     }
 
     private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) {

+ 52 - 11
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java

@@ -380,8 +380,7 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
         verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
     }
 
-    public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork()
-        throws InterruptedException {
+    public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_CallsStopAfterCompletingPendingWork() throws Exception {
         final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
         final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
         String modelOne = "model-1";
@@ -430,9 +429,11 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
             fail("Failed waiting for the stop process call to complete");
         }
 
-        verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
-        assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
-        assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
+        assertBusy(() -> {
+            verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stopParamsCapture.capture());
+            assertThat(stopParamsCapture.getValue().getModelId(), equalTo(modelOne));
+            assertThat(stopParamsCapture.getValue().getDeploymentId(), equalTo(deploymentOne));
+        });
         verify(trainedModelAssignmentService, times(1)).updateModelAssignmentState(
             any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
             any()
@@ -440,7 +441,7 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
         verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
     }
 
-    public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
+    public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
         final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
         String node2 = "test-node-2";
         final DiscoveryNodes nodes = DiscoveryNodes.builder()
@@ -488,7 +489,7 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
         verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
     }
 
-    public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
+    public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
         final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
         final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
         String modelOne = "model-1";
@@ -529,7 +530,7 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
         verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
     }
 
-    public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
+    public void testClusterChanged_WhenAssignmentIsRoutedToShuttingDownNodeWithStartingState_DoesNotStopTheDeployment() {
         final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
         final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
         String modelOne = "model-1";
@@ -571,7 +572,46 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
         verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
     }
 
-    public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
+    public void testClusterChanged_WhenNodeDoesNotExistInAssignmentRoutingTable_DoesGracefullyStopTheDeployment() throws Exception {
+        final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
+        final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
+        String modelOne = "model-1";
+        String deploymentOne = "deployment-1";
+
+        var taskParams = newParams(deploymentOne, modelOne);
+
+        ClusterChangedEvent event = new ClusterChangedEvent(
+            "testClusterChanged",
+            ClusterState.builder(new ClusterName("testClusterChanged"))
+                .nodes(nodes)
+                .metadata(
+                    Metadata.builder()
+                        .putCustom(
+                            TrainedModelAssignmentMetadata.NAME,
+                            TrainedModelAssignmentMetadata.Builder.empty()
+                                .addNewAssignment(deploymentOne, TrainedModelAssignment.Builder.empty(taskParams, null))
+                                .build()
+                        )
+                        .build()
+                )
+                .build(),
+            ClusterState.EMPTY_STATE
+        );
+
+        trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
+        trainedModelAssignmentNodeService.clusterChanged(event);
+
+        assertBusy(() -> verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(any()));
+        // This still shouldn't trigger a cluster state update because the routing entry wasn't in the table so we won't add a new routing
+        // entry for stopping
+        verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
+            any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
+            any()
+        );
+        verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
+    }
+
+    public void testClusterChanged_WhenAssignmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException {
         final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
         final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
         String modelOne = "model-1";
@@ -603,7 +643,6 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
             ClusterState.EMPTY_STATE
         );
 
-        // trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
         trainedModelAssignmentNodeService.clusterChanged(event);
         loadQueuedModels(trainedModelAssignmentNodeService);
 
@@ -724,7 +763,9 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
 
         assertBusy(() -> {
             ArgumentCaptor<TrainedModelDeploymentTask> stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
-            verify(deploymentManager, times(1)).stopDeployment(stoppedTaskCapture.capture());
+            // deployment-2 was originally started on node NODE_ID but in the latest cluster event it is no longer on that node so we will
+            // gracefully stop it
+            verify(deploymentManager, times(1)).stopAfterCompletingPendingWork(stoppedTaskCapture.capture());
             assertThat(stoppedTaskCapture.getAllValues().get(0).getDeploymentId(), equalTo(deploymentTwo));
         });
         ArgumentCaptor<TrainedModelDeploymentTask> startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);