|
@@ -17,20 +17,25 @@ import org.elasticsearch.action.FailedNodeException;
|
|
|
import org.elasticsearch.action.TaskOperationFailure;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
import org.elasticsearch.action.support.tasks.TransportTasksAction;
|
|
|
+import org.elasticsearch.client.internal.Client;
|
|
|
+import org.elasticsearch.client.internal.OriginSettingClient;
|
|
|
import org.elasticsearch.cluster.ClusterState;
|
|
|
import org.elasticsearch.cluster.node.DiscoveryNode;
|
|
|
import org.elasticsearch.cluster.node.DiscoveryNodes;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
|
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
|
|
+import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.discovery.MasterNotDiscoveredException;
|
|
|
+import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.ingest.IngestMetadata;
|
|
|
-import org.elasticsearch.ingest.IngestService;
|
|
|
import org.elasticsearch.injection.guice.Inject;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.tasks.CancellableTask;
|
|
|
import org.elasticsearch.tasks.Task;
|
|
|
import org.elasticsearch.transport.TransportResponseHandler;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
+import org.elasticsearch.xcontent.XContentType;
|
|
|
+import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
|
|
@@ -47,6 +52,7 @@ import java.util.Optional;
|
|
|
import java.util.Set;
|
|
|
|
|
|
import static org.elasticsearch.core.Strings.format;
|
|
|
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
|
|
import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getModelAliases;
|
|
|
|
|
|
/**
|
|
@@ -63,7 +69,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
|
|
|
|
|
|
private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class);
|
|
|
|
|
|
- private final IngestService ingestService;
|
|
|
+ private final OriginSettingClient client;
|
|
|
private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService;
|
|
|
private final InferenceAuditor auditor;
|
|
|
|
|
@@ -72,7 +78,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
|
|
|
ClusterService clusterService,
|
|
|
TransportService transportService,
|
|
|
ActionFilters actionFilters,
|
|
|
- IngestService ingestService,
|
|
|
+ Client client,
|
|
|
TrainedModelAssignmentClusterService trainedModelAssignmentClusterService,
|
|
|
InferenceAuditor auditor
|
|
|
) {
|
|
@@ -85,7 +91,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
|
|
|
StopTrainedModelDeploymentAction.Response::new,
|
|
|
EsExecutors.DIRECT_EXECUTOR_SERVICE
|
|
|
);
|
|
|
- this.ingestService = ingestService;
|
|
|
+ this.client = new OriginSettingClient(client, ML_ORIGIN);
|
|
|
this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService;
|
|
|
this.auditor = Objects.requireNonNull(auditor);
|
|
|
}
|
|
@@ -154,21 +160,84 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
|
|
|
|
|
|
// NOTE, should only run on Master node
|
|
|
assert clusterService.localNode().isMasterNode();
|
|
|
+
|
|
|
+ if (request.isForce() == false) {
|
|
|
+ checkIfUsedByInferenceEndpoint(
|
|
|
+ request.getId(),
|
|
|
+ ActionListener.wrap(canStop -> stopDeployment(task, request, maybeAssignment.get(), listener), listener::onFailure)
|
|
|
+ );
|
|
|
+ } else {
|
|
|
+ stopDeployment(task, request, maybeAssignment.get(), listener);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ /**
+  * Marks the assignment identified by {@code request.getId()} as stopping and, once that call
+  * succeeds, continues with {@code normalUndeploy}.
+  * <p>
+  * If marking the assignment fails and the unwrapped cause is a {@code ResourceNotFoundException},
+  * the listener is answered with {@code new StopTrainedModelDeploymentAction.Response(true)} instead
+  * of an error; every other failure is forwarded to {@code listener.onFailure}.
+  *
+  * @param task       passed through unchanged to {@code normalUndeploy}
+  * @param request    the stop request; its id selects the assignment to mark as stopping
+  * @param assignment passed through unchanged to {@code normalUndeploy}
+  * @param listener   receives the final response or the failure
+  */
|
|
|
+ private void stopDeployment(
|
|
|
+ Task task,
|
|
|
+ StopTrainedModelDeploymentAction.Request request,
|
|
|
+ TrainedModelAssignment assignment,
|
|
|
+ ActionListener<StopTrainedModelDeploymentAction.Response> listener
|
|
|
+ ) {
|
|
|
trainedModelAssignmentClusterService.setModelAssignmentToStopping(
|
|
|
request.getId(),
|
|
|
- ActionListener.wrap(
|
|
|
- setToStopping -> normalUndeploy(task, request.getId(), maybeAssignment.get(), request, listener),
|
|
|
- failure -> {
|
|
|
- if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
|
|
|
- listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
|
|
|
- return;
|
|
|
- }
|
|
|
- listener.onFailure(failure);
|
|
|
+ ActionListener.wrap(setToStopping -> normalUndeploy(task, request.getId(), assignment, request, listener), failure -> {
|
|
|
+ if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) { // test the unwrapped cause, not the outer exception
|
|
|
+ listener.onResponse(new StopTrainedModelDeploymentAction.Response(true)); // NOTE(review): not-found is reported as a successful stop - presumably the assignment is already gone; confirm
|
|
|
+ return;
|
|
|
}
|
|
|
- )
|
|
|
+ listener.onFailure(failure);
|
|
|
+ })
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ private void checkIfUsedByInferenceEndpoint(String deploymentId, ActionListener<Boolean> listener) {
|
|
|
+ /*
+  * Looks up inference endpoints and fails the listener with a RestStatus.CONFLICT
+  * ElasticsearchStatusException when one of them refers to deploymentId; otherwise the listener
+  * receives true.
+  *
+  * Only endpoints whose service is "elasticsearch" or "elser" are considered. Two cases are rejected:
+  *   1. an endpoint whose inference entity id equals deploymentId;
+  *   2. an endpoint whose serialized service settings contain a "deployment_id" entry equal to
+  *      deploymentId.
+  *
+  * The callback is registered through listener.delegateFailureAndWrap, so a failure of the
+  * GetInferenceModelAction request is handled by the listener rather than by this method.
+  */
|
|
|
+ GetInferenceModelAction.Request getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY); // NOTE(review): "*" with TaskType.ANY presumably matches every endpoint - confirm
|
|
|
+ client.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> {
|
|
|
+ // filter by the ml node services
|
|
|
+ var mlNodeEndpoints = response.getEndpoints()
|
|
|
+ .stream()
|
|
|
+ .filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser"))
|
|
|
+ .toList();
|
|
|
+
|
|
|
+ var endpointOwnsDeployment = mlNodeEndpoints.stream()
|
|
|
+ .filter(model -> model.getInferenceEntityId().equals(deploymentId))
|
|
|
+ .findFirst();
|
|
|
+ if (endpointOwnsDeployment.isPresent()) {
|
|
|
+ l.onFailure(
|
|
|
+ new ElasticsearchStatusException(
|
|
|
+ "Cannot stop deployment [{}] as it was created by inference endpoint [{}]",
|
|
|
+ RestStatus.CONFLICT,
|
|
|
+ deploymentId,
|
|
|
+ endpointOwnsDeployment.get().getInferenceEntityId()
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ // The inference endpoint may have been created by attaching to an existing deployment.
|
|
|
+ for (var endpoint : mlNodeEndpoints) {
|
|
|
+ var serviceSettingsXContent = XContentHelper.toXContent(endpoint.getServiceSettings(), XContentType.JSON, false); // serialize the settings to JSON ...
|
|
|
+ var settingsMap = XContentHelper.convertToMap(serviceSettingsXContent, false, XContentType.JSON).v2(); // ... and parse them back so they can be read by key
|
|
|
+ // Endpoints with the deployment_id setting are attached to an existing deployment.
|
|
|
+ var deploymentIdFromSettings = (String) settingsMap.get("deployment_id"); // NOTE(review): cast assumes deployment_id is always serialized as a string - confirm
|
|
|
+ if (deploymentIdFromSettings != null && deploymentIdFromSettings.equals(deploymentId)) {
|
|
|
+ // The endpoint was created to use this deployment
|
|
|
+ l.onFailure(
|
|
|
+ new ElasticsearchStatusException(
|
|
|
+ "Cannot stop deployment [{}] as it is used by inference endpoint [{}]",
|
|
|
+ RestStatus.CONFLICT,
|
|
|
+ deploymentId,
|
|
|
+ endpoint.getInferenceEntityId()
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ l.onResponse(true);
|
|
|
+ }));
|
|
|
+ }
|
|
|
+
|
|
|
private void redirectToMasterNode(
|
|
|
DiscoveryNode masterNode,
|
|
|
StopTrainedModelDeploymentAction.Request request,
|