Browse Source

[ML] Check for model deployment in inference endpoints before stopping (#129325)

David Kyle 3 months ago
parent
commit
816caf70fc

+ 6 - 0
docs/changelog/129325.yaml

@@ -0,0 +1,6 @@
+pr: 129325
+summary: Check for model deployment in inference endpoints before stopping
+area: Machine Learning
+type: bug
+issues:
+ - 128549

+ 31 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

@@ -75,7 +75,7 @@ public class CreateFromDeploymentIT extends InferenceBaseRestTest {
         var deploymentStats = stats.get(0).get("deployment_stats");
         assertNotNull(stats.toString(), deploymentStats);
 
-        stopMlNodeDeployment(deploymentId);
+        forceStopMlNodeDeployment(deploymentId);
     }
 
     public void testAttachWithModelId() throws IOException {
@@ -146,7 +146,7 @@ public class CreateFromDeploymentIT extends InferenceBaseRestTest {
             )
         );
 
-        stopMlNodeDeployment(deploymentId);
+        forceStopMlNodeDeployment(deploymentId);
     }
 
     public void testModelIdDoesNotMatch() throws IOException {
@@ -229,6 +229,29 @@ public class CreateFromDeploymentIT extends InferenceBaseRestTest {
         );
     }
 
+    public void testStoppingDeploymentAttachedToInferenceEndpoint() throws IOException {
+        var modelId = "try_stop_attach_to_deployment";
+        var deploymentId = "test_stop_attach_to_deployment";
+
+        CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
+        var response = startMlNodeDeploymemnt(modelId, deploymentId);
+        assertStatusOkOrCreated(response);
+
+        var inferenceId = "test_stop_inference_on_existing_deployment";
+        putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
+
+        var stopShouldNotSucceed = expectThrows(ResponseException.class, () -> stopMlNodeDeployment(deploymentId));
+        assertThat(
+            stopShouldNotSucceed.getMessage(),
+            containsString(
+                Strings.format("Cannot stop deployment [%s] as it is used by inference endpoint [%s]", deploymentId, inferenceId)
+            )
+        );
+
+        // Force stop will stop the deployment
+        forceStopMlNodeDeployment(deploymentId);
+    }
+
     private String endpointConfig(String deploymentId) {
         return Strings.format("""
             {
@@ -292,6 +315,12 @@ public class CreateFromDeploymentIT extends InferenceBaseRestTest {
     }
 
     protected void stopMlNodeDeployment(String deploymentId) throws IOException {
+        String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
+        Request request = new Request("POST", endpoint);
+        client().performRequest(request);
+    }
+
+    protected void forceStopMlNodeDeployment(String deploymentId) throws IOException {
         String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
         Request request = new Request("POST", endpoint);
         request.addParameter("force", "true");

+ 44 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference;
 
 import org.elasticsearch.client.Request;
+import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.client.RestClient;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.inference.TaskType;
@@ -18,6 +19,8 @@ import java.util.Base64;
 import java.util.List;
 import java.util.stream.Collectors;
 
+import static org.hamcrest.Matchers.containsString;
+
 public class CustomElandModelIT extends InferenceBaseRestTest {
 
     // The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT
@@ -92,6 +95,47 @@ public class CustomElandModelIT extends InferenceBaseRestTest {
         assertNotNull(results.get("sparse_embedding"));
     }
 
+    public void testCannotStopDeployment() throws IOException {
+        String modelId = "custom-model-that-cannot-be-stopped";
+
+        createTextExpansionModel(modelId, client());
+        putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client());
+        putVocabulary(
+            List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
+            modelId,
+            client()
+        );
+
+        var inferenceConfig = """
+            {
+              "service": "elasticsearch",
+              "service_settings": {
+                "model_id": "custom-model-that-cannot-be-stopped",
+                "num_allocations": 1,
+                "num_threads": 1
+              }
+            }
+            """;
+
+        var inferenceId = "sparse-inf";
+        putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
+        infer(inferenceId, List.of("washing", "machine"));
+
+        // Stopping the deployment using the ML trained models API should fail
+        // because the deployment was created by the inference endpoint API
+        String stopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?error_trace", inferenceId);
+        Request stopRequest = new Request("POST", stopEndpoint);
+        var e = expectThrows(ResponseException.class, () -> client().performRequest(stopRequest));
+        assertThat(
+            e.getMessage(),
+            containsString("Cannot stop deployment [sparse-inf] as it was created by inference endpoint [sparse-inf]")
+        );
+
+        // Force stop works
+        String forceStopEndpoint = org.elasticsearch.common.Strings.format("_ml/trained_models/%s/deployment/_stop?force", inferenceId);
+        assertStatusOkOrCreated(client().performRequest(new Request("POST", forceStopEndpoint)));
+    }
+
     static void createTextExpansionModel(String modelId, RestClient client) throws IOException {
         // with_special_tokens: false for this test with limited vocab
         Request request = new Request("PUT", "/_ml/trained_models/" + modelId);

+ 82 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java

@@ -17,20 +17,25 @@ import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.tasks.TransportTasksAction;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.client.internal.OriginSettingClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.discovery.MasterNotDiscoveredException;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.ingest.IngestMetadata;
-import org.elasticsearch.ingest.IngestService;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportResponseHandler;
 import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
 import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
@@ -47,6 +52,7 @@ import java.util.Optional;
 import java.util.Set;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getModelAliases;
 
 /**
@@ -63,7 +69,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 
     private static final Logger logger = LogManager.getLogger(TransportStopTrainedModelDeploymentAction.class);
 
-    private final IngestService ingestService;
+    private final OriginSettingClient client;
     private final TrainedModelAssignmentClusterService trainedModelAssignmentClusterService;
     private final InferenceAuditor auditor;
 
@@ -72,7 +78,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
         ClusterService clusterService,
         TransportService transportService,
         ActionFilters actionFilters,
-        IngestService ingestService,
+        Client client,
         TrainedModelAssignmentClusterService trainedModelAssignmentClusterService,
         InferenceAuditor auditor
     ) {
@@ -85,7 +91,7 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
             StopTrainedModelDeploymentAction.Response::new,
             EsExecutors.DIRECT_EXECUTOR_SERVICE
         );
-        this.ingestService = ingestService;
+        this.client = new OriginSettingClient(client, ML_ORIGIN);
         this.trainedModelAssignmentClusterService = trainedModelAssignmentClusterService;
         this.auditor = Objects.requireNonNull(auditor);
     }
@@ -154,21 +160,84 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
 
         // NOTE, should only run on Master node
         assert clusterService.localNode().isMasterNode();
+
+        if (request.isForce() == false) {
+            checkIfUsedByInferenceEndpoint(
+                request.getId(),
+                ActionListener.wrap(canStop -> stopDeployment(task, request, maybeAssignment.get(), listener), listener::onFailure)
+            );
+        } else {
+            stopDeployment(task, request, maybeAssignment.get(), listener);
+        }
+    }
+
+    private void stopDeployment(
+        Task task,
+        StopTrainedModelDeploymentAction.Request request,
+        TrainedModelAssignment assignment,
+        ActionListener<StopTrainedModelDeploymentAction.Response> listener
+    ) {
         trainedModelAssignmentClusterService.setModelAssignmentToStopping(
             request.getId(),
-            ActionListener.wrap(
-                setToStopping -> normalUndeploy(task, request.getId(), maybeAssignment.get(), request, listener),
-                failure -> {
-                    if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
-                        listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
-                        return;
-                    }
-                    listener.onFailure(failure);
+            ActionListener.wrap(setToStopping -> normalUndeploy(task, request.getId(), assignment, request, listener), failure -> {
+                if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
+                    listener.onResponse(new StopTrainedModelDeploymentAction.Response(true));
+                    return;
                 }
-            )
+                listener.onFailure(failure);
+            })
         );
     }
 
+    private void checkIfUsedByInferenceEndpoint(String deploymentId, ActionListener<Boolean> listener) {
+
+        GetInferenceModelAction.Request getAllEndpoints = new GetInferenceModelAction.Request("*", TaskType.ANY);
+        client.execute(GetInferenceModelAction.INSTANCE, getAllEndpoints, listener.delegateFailureAndWrap((l, response) -> {
+            // filter by the ml node services
+            var mlNodeEndpoints = response.getEndpoints()
+                .stream()
+                .filter(model -> model.getService().equals("elasticsearch") || model.getService().equals("elser"))
+                .toList();
+
+            var endpointOwnsDeployment = mlNodeEndpoints.stream()
+                .filter(model -> model.getInferenceEntityId().equals(deploymentId))
+                .findFirst();
+            if (endpointOwnsDeployment.isPresent()) {
+                l.onFailure(
+                    new ElasticsearchStatusException(
+                        "Cannot stop deployment [{}] as it was created by inference endpoint [{}]",
+                        RestStatus.CONFLICT,
+                        deploymentId,
+                        endpointOwnsDeployment.get().getInferenceEntityId()
+                    )
+                );
+                return;
+            }
+
+            // The inference endpoint may have been created by attaching to an existing deployment.
+            for (var endpoint : mlNodeEndpoints) {
+                var serviceSettingsXContent = XContentHelper.toXContent(endpoint.getServiceSettings(), XContentType.JSON, false);
+                var settingsMap = XContentHelper.convertToMap(serviceSettingsXContent, false, XContentType.JSON).v2();
+                // Endpoints with the deployment_id setting are attached to an existing deployment.
+                var deploymentIdFromSettings = (String) settingsMap.get("deployment_id");
+                if (deploymentIdFromSettings != null && deploymentIdFromSettings.equals(deploymentId)) {
+                    // The endpoint was created to use this deployment
+                    l.onFailure(
+                        new ElasticsearchStatusException(
+                            "Cannot stop deployment [{}] as it is used by inference endpoint [{}]",
+                            RestStatus.CONFLICT,
+                            deploymentId,
+                            endpoint.getInferenceEntityId()
+                        )
+                    );
+                    return;
+                }
+            }
+
+            l.onResponse(true);
+        }));
+    }
+
     private void redirectToMasterNode(
         DiscoveryNode masterNode,
         StopTrainedModelDeploymentAction.Request request,