|  | @@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocati
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationClusterService;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.util.Collections;
 | 
	
	
		
			
				|  | @@ -66,7 +65,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private final Client client;
 | 
	
		
			
				|  |  |      private final IngestService ingestService;
 | 
	
		
			
				|  |  | -    private final TrainedModelAllocationService trainedModelAllocationService;
 | 
	
		
			
				|  |  |      private final TrainedModelAllocationClusterService trainedModelAllocationClusterService;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Inject
 | 
	
	
		
			
				|  | @@ -76,7 +74,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 | 
	
		
			
				|  |  |          ActionFilters actionFilters,
 | 
	
		
			
				|  |  |          Client client,
 | 
	
		
			
				|  |  |          IngestService ingestService,
 | 
	
		
			
				|  |  | -        TrainedModelAllocationService trainedModelAllocationService,
 | 
	
		
			
				|  |  |          TrainedModelAllocationClusterService trainedModelAllocationClusterService
 | 
	
		
			
				|  |  |      ) {
 | 
	
		
			
				|  |  |          super(
 | 
	
	
		
			
				|  | @@ -91,7 +88,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |          this.client = new OriginSettingClient(client, ML_ORIGIN);
 | 
	
		
			
				|  |  |          this.ingestService = ingestService;
 | 
	
		
			
				|  |  | -        this.trainedModelAllocationService = trainedModelAllocationService;
 | 
	
		
			
				|  |  |          this.trainedModelAllocationClusterService = trainedModelAllocationClusterService;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -150,6 +146,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              // NOTE, should only run on Master node
 | 
	
		
			
				|  |  | +            assert clusterService.localNode().isMasterNode();
 | 
	
		
			
				|  |  |              trainedModelAllocationClusterService.setModelAllocationToStopping(
 | 
	
		
			
				|  |  |                  modelId,
 | 
	
		
			
				|  |  |                  ActionListener.wrap(
 | 
	
	
		
			
				|  | @@ -196,30 +193,25 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 | 
	
		
			
				|  |  |      ) {
 | 
	
		
			
				|  |  |          request.setNodes(modelAllocation.getNodeRoutingTable().keySet().toArray(String[]::new));
 | 
	
		
			
				|  |  |          ActionListener<StopTrainedModelDeploymentAction.Response> finalListener = ActionListener.wrap(r -> {
 | 
	
		
			
				|  |  | -            waitForTaskRemoved(modelId, modelAllocation, request, r, ActionListener.wrap(waited -> {
 | 
	
		
			
				|  |  | -                trainedModelAllocationService.deleteModelAllocation(
 | 
	
		
			
				|  |  | -                    modelId,
 | 
	
		
			
				|  |  | -                    ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
 | 
	
		
			
				|  |  | -                        logger.error(
 | 
	
		
			
				|  |  | -                            () -> new ParameterizedMessage(
 | 
	
		
			
				|  |  | -                                "[{}] failed to delete model allocation after nodes unallocated the deployment",
 | 
	
		
			
				|  |  | -                                modelId
 | 
	
		
			
				|  |  | -                            ),
 | 
	
		
			
				|  |  | +            assert clusterService.localNode().isMasterNode();
 | 
	
		
			
				|  |  | +            trainedModelAllocationClusterService.removeModelAllocation(
 | 
	
		
			
				|  |  | +                modelId,
 | 
	
		
			
				|  |  | +                ActionListener.wrap(deleted -> listener.onResponse(r), deletionFailed -> {
 | 
	
		
			
				|  |  | +                    logger.error(
 | 
	
		
			
				|  |  | +                        () -> new ParameterizedMessage(
 | 
	
		
			
				|  |  | +                            "[{}] failed to delete model allocation after nodes unallocated the deployment",
 | 
	
		
			
				|  |  | +                            modelId
 | 
	
		
			
				|  |  | +                        ),
 | 
	
		
			
				|  |  | +                        deletionFailed
 | 
	
		
			
				|  |  | +                    );
 | 
	
		
			
				|  |  | +                    listener.onFailure(
 | 
	
		
			
				|  |  | +                        ExceptionsHelper.serverError(
 | 
	
		
			
				|  |  | +                            "failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
 | 
	
		
			
				|  |  |                              deletionFailed
 | 
	
		
			
				|  |  | -                        );
 | 
	
		
			
				|  |  | -                        listener.onFailure(
 | 
	
		
			
				|  |  | -                            ExceptionsHelper.serverError(
 | 
	
		
			
				|  |  | -                                "failed to delete model allocation after nodes unallocated the deployment. Attempt to stop again",
 | 
	
		
			
				|  |  | -                                deletionFailed
 | 
	
		
			
				|  |  | -                            )
 | 
	
		
			
				|  |  | -                        );
 | 
	
		
			
				|  |  | -                    })
 | 
	
		
			
				|  |  | -                );
 | 
	
		
			
				|  |  | -            },
 | 
	
		
			
				|  |  | -                // TODO should we attempt to delete the deployment here?
 | 
	
		
			
				|  |  | -                listener::onFailure
 | 
	
		
			
				|  |  | -            ));
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +                        )
 | 
	
		
			
				|  |  | +                    );
 | 
	
		
			
				|  |  | +                })
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  |          }, e -> {
 | 
	
		
			
				|  |  |              if (ExceptionsHelper.unwrapCause(e) instanceof FailedNodeException) {
 | 
	
		
			
				|  |  |                  // A node has dropped out of the cluster since we started executing the requests.
 | 
	
	
		
			
				|  | @@ -235,24 +227,6 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 | 
	
		
			
				|  |  |          super.doExecute(task, request, finalListener);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    void waitForTaskRemoved(
 | 
	
		
			
				|  |  | -        String modelId,
 | 
	
		
			
				|  |  | -        TrainedModelAllocation trainedModelAllocation,
 | 
	
		
			
				|  |  | -        StopTrainedModelDeploymentAction.Request request,
 | 
	
		
			
				|  |  | -        StopTrainedModelDeploymentAction.Response response,
 | 
	
		
			
				|  |  | -        ActionListener<StopTrainedModelDeploymentAction.Response> listener
 | 
	
		
			
				|  |  | -    ) {
 | 
	
		
			
				|  |  | -        final Set<String> nodesOfConcern = trainedModelAllocation.getNodeRoutingTable().keySet();
 | 
	
		
			
				|  |  | -        client.admin()
 | 
	
		
			
				|  |  | -            .cluster()
 | 
	
		
			
				|  |  | -            .prepareListTasks(nodesOfConcern.toArray(String[]::new))
 | 
	
		
			
				|  |  | -            .setDetailed(true)
 | 
	
		
			
				|  |  | -            .setWaitForCompletion(true)
 | 
	
		
			
				|  |  | -            .setActions(modelId)
 | 
	
		
			
				|  |  | -            .setTimeout(request.getTimeout())
 | 
	
		
			
				|  |  | -            .execute(ActionListener.wrap(complete -> listener.onResponse(response), listener::onFailure));
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  |      protected StopTrainedModelDeploymentAction.Response newResponse(
 | 
	
		
			
				|  |  |          StopTrainedModelDeploymentAction.Request request,
 | 
	
	
		
			
				|  | @@ -275,7 +249,9 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 | 
	
		
			
				|  |  |          TrainedModelDeploymentTask task,
 | 
	
		
			
				|  |  |          ActionListener<StopTrainedModelDeploymentAction.Response> listener
 | 
	
		
			
				|  |  |      ) {
 | 
	
		
			
				|  |  | -        task.stop("undeploy_trained_model (api)");
 | 
	
		
			
				|  |  | -        listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
 | 
	
		
			
				|  |  | +        task.stop(
 | 
	
		
			
				|  |  | +            "undeploy_trained_model (api)",
 | 
	
		
			
				|  |  | +            ActionListener.wrap(r -> listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)), listener::onFailure)
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  }
 |