浏览代码

[ML] Wait for model process to stop in stop deployment (#83644)

David Kyle 3 年之前
父节点
当前提交
2200fa783f

+ 5 - 0
docs/changelog/83644.yaml

@@ -0,0 +1,5 @@
+pr: 83644
+summary: Wait for model process to be stop in stop deployment
+area: Machine Learning
+type: bug
+issues: []

+ 6 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -23,6 +23,7 @@ import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
@@ -88,6 +89,11 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
             }, listener::onFailure));
             return;
         }
+        if (allocation.getAllocationState() == AllocationState.STOPPING) {
+            String message = "Trained model [" + deploymentId + "] is STOPPING";
+            listener.onFailure(ExceptionsHelper.conflictStatusException(message));
+            return;
+        }
         String[] randomRunningNode = allocation.getStartedNodes();
         if (randomRunningNode.length == 0) {
             String message = "Trained model [" + deploymentId + "] is not allocated to any nodes";

+ 23 - 47
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java

@@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocati
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
-import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
 import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
 
 import java.util.Collections;
@@ -66,7 +65,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 
     private final Client client;
     private final IngestService ingestService;
-    private final TrainedModelAllocationService trainedModelAllocationService;
     private final TrainedModelAllocationClusterService trainedModelAllocationClusterService;
 
     @Inject
@@ -76,7 +74,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
         ActionFilters actionFilters,
         Client client,
         IngestService ingestService,
-        TrainedModelAllocationService trainedModelAllocationService,
         TrainedModelAllocationClusterService trainedModelAllocationClusterService
     ) {
         super(
@@ -91,7 +88,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
         );
         this.client = new OriginSettingClient(client, ML_ORIGIN);
         this.ingestService = ingestService;
-        this.trainedModelAllocationService = trainedModelAllocationService;
         this.trainedModelAllocationClusterService = trainedModelAllocationClusterService;
     }
 
@@ -150,6 +146,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
             }
 
             // NOTE, should only run on Master node
+            assert clusterService.localNode().isMasterNode();
             trainedModelAllocationClusterService.setModelAllocationToStopping(
                 modelId,
                 ActionListener.wrap(
@@ -196,30 +193,25 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
     ) {
         request.setNodes(modelAllocation.getNodeRoutingTable().keySet().toArray(String[]::new));
         ActionListener<StopTrainedModelDeploymentAction.Response> finalListener = ActionListener.wrap(r -> {
-            waitForTaskRemoved(modelId, modelAllocation, request, r, ActionListener.wrap(waited -> {
-                trainedModelAllocationService.deleteModelAllocation(
-                    modelId,
-                    ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
-                        logger.error(
-                            () -> new ParameterizedMessage(
-                                "[{}] failed to delete model allocation after nodes unallocated the deployment",
-                                modelId
-                            ),
+            assert clusterService.localNode().isMasterNode();
+            trainedModelAllocationClusterService.removeModelAllocation(
+                modelId,
+                ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
+                    logger.error(
+                        () -> new ParameterizedMessage(
+                            "[{}] failed to delete model allocation after nodes unallocated the deployment",
+                            modelId
+                        ),
+                        deletionFailed
+                    );
+                    listener.onFailure(
+                        ExceptionsHelper.serverError(
+                            "failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
                             deletionFailed
-                        );
-                        listener.onFailure(
-                            ExceptionsHelper.serverError(
-                                "failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
-                                deletionFailed
-                            )
-                        );
-                    })
-                );
-            },
-                // TODO should we attempt to delete the deployment here?
-                listener::onFailure
-            ));
-
+                        )
+                    );
+                })
+            );
         }, e -> {
             if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) {
                 // A node has dropped out of the cluster since we started executing the requests.
@@ -235,24 +227,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
         super.doExecute(task, request, finalListener);
     }
 
-    void waitForTaskRemoved(
-        String modelId,
-        TrainedModelAllocation trainedModelAllocation,
-        StopTrainedModelDeploymentAction.Request request,
-        StopTrainedModelDeploymentAction.Response response,
-        ActionListener<StopTrainedModelDeploymentAction.Response> listener
-    ) {
-        final Set<String> nodesOfConcern = trainedModelAllocation.getNodeRoutingTable().keySet();
-        client.admin()
-            .cluster()
-            .prepareListTasks(nodesOfConcern.toArray(String[]::new))
-            .setDetailed(true)
-            .setWaitForCompletion(true)
-            .setActions(modelId)
-            .setTimeout(request.getTimeout())
-            .execute(ActionListener.wrap(complete -> listener.onResponse(response), listener::onFailure));
-    }
-
     @Override
     protected StopTrainedModelDeploymentAction.Response newResponse(
         StopTrainedModelDeploymentAction.Request request,
@@ -275,7 +249,9 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
         TrainedModelDeploymentTask task,
         ActionListener<StopTrainedModelDeploymentAction.Response> listener
     ) {
-        task.stop("undeploy_trained_model (api)");
-        listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
+        task.stop(
+            "undeploy_trained_model (api)",
+            ActionListener.wrap(r -> listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)), listener::onFailure)
+        );
     }
 }

+ 6 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

@@ -135,7 +135,8 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
         if (stopped) {
             return;
         }
-        task.stopWithoutNotification(reason);
+        task.markAsStopped(reason);
+
         threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
             try {
                 deploymentManager.stopDeployment(task);
@@ -204,20 +205,12 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
         loadingModels.addAll(loadingToRetry);
     }
 
-    public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason) {
+    public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
         ActionListener<Void> notifyDeploymentOfStopped = ActionListener.wrap(
-            _void -> updateStoredState(
-                task.getModelId(),
-                new RoutingStateAndReason(RoutingState.STOPPED, reason),
-                ActionListener.wrap(s -> {}, failure -> {})
-            ),
+            _void -> updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener),
             failed -> { // if we failed to stop the process, something strange is going on, but we should still notify of stop
                 logger.warn(() -> new ParameterizedMessage("[{}] failed to stop due to error", task.getModelId()), failed);
-                updateStoredState(
-                    task.getModelId(),
-                    new RoutingStateAndReason(RoutingState.STOPPED, reason),
-                    ActionListener.wrap(s -> {}, failure -> {})
-                );
+                updateStoredState(task.getModelId(), new RoutingStateAndReason(RoutingState.STOPPED, reason), listener);
             }
         );
         updateStoredState(
@@ -309,7 +302,7 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
                     && isResetMode == false) {
                     prepareModelToLoad(trainedModelAllocation.getTaskParams());
                 }
-                // This mode is not routed to the current node at all
+                // This model is not routed to the current node at all
                 if (routingStateAndReason == null) {
                     TrainedModelDeploymentTask task = modelIdToTask.remove(trainedModelAllocation.getTaskParams().getModelId());
                     if (task != null) {

+ 13 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

@@ -9,9 +9,11 @@ package org.elasticsearch.xpack.ml.inference.deployment;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.license.LicensedFeature;
 import org.elasticsearch.license.XPackLicenseState;
@@ -80,15 +82,11 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
         return params;
     }
 
-    public void stop(String reason) {
-        logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
-        licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
-        stopped = true;
-        stoppedReasonHolder.trySet(reason);
-        trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason);
+    public void stop(String reason, ActionListener<AcknowledgedResponse> listener) {
+        trainedModelAllocationNodeService.stopDeploymentAndNotify(this, reason, listener);
     }
 
-    public void stopWithoutNotification(String reason) {
+    public void markAsStopped(String reason) {
         licensedFeature.stopTracking(licenseState, "model-" + params.getModelId());
         logger.debug("[{}] Stopping due to reason [{}]", getModelId(), reason);
         stoppedReasonHolder.trySet(reason);
@@ -106,7 +104,14 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
     @Override
     protected void onCancelled() {
         String reason = getReasonCancelled();
-        stop(reason);
+        logger.info("[{}] task cancelled due to reason [{}]", getModelId(), reason);
+        stop(
+            reason,
+            ActionListener.wrap(
+                acknowledgedResponse -> {},
+                e -> logger.error(new ParameterizedMessage("[{}] error stopping the model after task cancellation", getModelId()), e)
+            )
+        );
     }
 
     public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java

@@ -196,7 +196,7 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
         // Only one model should be loaded, the other should be stopped
         trainedModelAllocationNodeService.prepareModelToLoad(newParams(modelToLoad));
         trainedModelAllocationNodeService.prepareModelToLoad(newParams(stoppedModelToLoad));
-        trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing");
+        trainedModelAllocationNodeService.getTask(stoppedModelToLoad).stop("testing", ActionListener.wrap(r -> {}, e -> {}));
         trainedModelAllocationNodeService.loadQueuedModels();
 
         assertBusy(() -> {

+ 17 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.ml.inference.deployment;
 
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.license.LicensedFeature;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.tasks.TaskId;
@@ -14,12 +15,15 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationNodeService;
+import org.mockito.ArgumentCaptor;
 
 import java.util.Map;
 import java.util.function.Consumer;
 
 import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_ACTION;
 import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -29,6 +33,15 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
     void assertTrackingComplete(Consumer<TrainedModelDeploymentTask> method, String modelId) {
         XPackLicenseState licenseState = mock(XPackLicenseState.class);
         LicensedFeature.Persistent feature = mock(LicensedFeature.Persistent.class);
+        TrainedModelAllocationNodeService nodeService = mock(TrainedModelAllocationNodeService.class);
+
+        ArgumentCaptor<TrainedModelDeploymentTask> taskCaptor = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class);
+        ArgumentCaptor<String> reasonCaptur = ArgumentCaptor.forClass(String.class);
+        doAnswer(invocation -> {
+            taskCaptor.getValue().markAsStopped(reasonCaptur.getValue());
+            return null;
+        }).when(nodeService).stopDeploymentAndNotify(taskCaptor.capture(), reasonCaptur.capture(), any());
+
         TrainedModelDeploymentTask task = new TrainedModelDeploymentTask(
             0,
             TRAINED_MODEL_ALLOCATION_TASK_TYPE,
@@ -42,7 +55,7 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
                 randomInt(5),
                 randomInt(5)
             ),
-            mock(TrainedModelAllocationNodeService.class),
+            nodeService,
             licenseState,
             feature
         );
@@ -53,12 +66,12 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
         verify(feature, times(1)).stopTracking(licenseState, "model-" + modelId);
     }
 
-    public void testOnStopWithoutNotification() {
-        assertTrackingComplete(t -> t.stopWithoutNotification("foo"), randomAlphaOfLength(10));
+    public void testMarkAsStopped() {
+        assertTrackingComplete(t -> t.markAsStopped("foo"), randomAlphaOfLength(10));
     }
 
     public void testOnStop() {
-        assertTrackingComplete(t -> t.stop("foo"), randomAlphaOfLength(10));
+        assertTrackingComplete(t -> t.stop("foo", ActionListener.wrap(r -> {}, e -> {})), randomAlphaOfLength(10));
     }
 
     public void testCancelled() {