
[ML] Undeploy elser when inference model deleted (#104230)

* Added stop to InferenceService interface and Elser

* New integration tests

* Undeploy ELSER deployment when _inference ELSER model deleted

* Update docs/changelog/104230.yaml

* Added check for platform architecture in integration test

* improvements from PR comments
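
With this change, deleting an ELSER-backed _inference model also stops the underlying trained model deployment. A minimal REST-level sketch of the flow, for illustration only (the model id "my-elser-model" and the client variable are placeholders, not part of this change; the config body mirrors the one used in the integration test below):

    // Illustrative only: identifiers below are not part of this change.
    var put = new Request("PUT", "_inference/sparse_embedding/my-elser-model");
    put.setJsonEntity("""
        {
          "service": "elser",
          "service_settings": { "num_allocations": 1, "num_threads": 1 },
          "task_settings": {}
        }
        """);
    client.performRequest(put);   // creating the model starts an ELSER deployment

    // Previously, DELETE only removed the inference model configuration;
    // it now also stops (undeploys) the associated ELSER deployment.
    client.performRequest(new Request("DELETE", "_inference/sparse_embedding/my-elser-model"));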
Max Hniebergall 1 year ago
commit 31e89890de

+ 5 - 0
docs/changelog/104230.yaml

@@ -0,0 +1,5 @@
+pr: 104230
+summary: Undeploy elser when inference model deleted
+area: Machine Learning
+type: bug
+issues: []

+ 10 - 0
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -87,6 +87,16 @@ public interface InferenceService extends Closeable {
      */
     void start(Model model, ActionListener<Boolean> listener);
 
+    /**
+     * Stop the model deployment.
+     * The default action does nothing except acknowledge the request (true).
+     * @param modelId The ID of the model to be stopped
+     * @param listener The listener
+     */
+    default void stop(String modelId, ActionListener<Boolean> listener) {
+        listener.onResponse(true);
+    }
+
     /**
      * Optionally test the new model configuration in the inference service.
      * This function should be called when the model is first created, the

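For context, a minimal caller-side sketch of the new hook (service, modelId and logger are placeholders, not code from this change). Services that do not manage a native deployment inherit the default above and simply acknowledge with true:

    service.stop(modelId, ActionListener.wrap(
        acknowledged -> logger.debug("stop acknowledged [{}] for model [{}]", acknowledged, modelId),
        e -> logger.warn("failed to stop deployment for model [" + modelId + "]", e)
    ));
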
+ 36 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.plugins.Platforms;
 import org.elasticsearch.test.cluster.ElasticsearchCluster;
 import org.elasticsearch.test.cluster.local.distribution.DistributionType;
 import org.elasticsearch.test.rest.ESRestTestCase;
@@ -64,6 +65,25 @@ public class InferenceBaseRestTest extends ESRestTestCase {
             """;
     }
 
+    protected Map<String, Object> downloadElserBlocking() throws IOException {
+        String endpoint = "_ml/trained_models/.elser_model_2?wait_for_completion=true";
+        if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
+            endpoint = "_ml/trained_models/.elser_model_2_linux-x86_64?wait_for_completion=true";
+        }
+        String body = """
+            {
+                "input": {
+                "field_names": ["text_field"]
+                }
+            }
+            """;
+        var request = new Request("PUT", endpoint);
+        request.setJsonEntity(body);
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+        return entityAsMap(response);
+    }
+
     protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
         String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
         var request = new Request("PUT", endpoint);
@@ -73,6 +93,14 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return entityAsMap(response);
     }
 
+    protected Map<String, Object> deleteModel(String modelId, TaskType taskType) throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
+        var request = new Request("DELETE", endpoint);
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+        return entityAsMap(response);
+    }
+
     protected Map<String, Object> getModels(String modelId, TaskType taskType) throws IOException {
         var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
         var request = new Request("GET", endpoint);
@@ -89,6 +117,14 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return entityAsMap(response);
     }
 
+    protected Map<String, Object> getTrainedModel(String modelId) throws IOException {
+        var endpoint = Strings.format("_ml/trained_models/%s/_stats", modelId);
+        var request = new Request("GET", endpoint);
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+        return entityAsMap(response);
+    }
+
     protected Map<String, Object> inferOnMockService(String modelId, TaskType taskType, List<String> input) throws IOException {
         var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
         var request = new Request("POST", endpoint);

+ 37 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -3,6 +3,8 @@
  * or more contributor license agreements. Licensed under the Elastic License
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
+ *
+ * this file has been contributed to by a Generative AI
  */
 
 package org.elasticsearch.xpack.inference;
@@ -16,9 +18,44 @@ import java.util.Map;
 
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
 
 public class InferenceCrudIT extends InferenceBaseRestTest {
 
+    public void testElserCrud() throws IOException {
+
+        String elserConfig = """
+            {
+              "service": "elser",
+              "service_settings": {
+                "num_allocations": 1,
+                "num_threads": 1
+              },
+              "task_settings": {}
+            }
+            """;
+
+        // ELSER not downloaded case
+        {
+            String modelId = randomAlphaOfLength(10).toLowerCase();
+            expectThrows(ResponseException.class, () -> putModel(modelId, elserConfig, TaskType.SPARSE_EMBEDDING));
+        }
+
+        downloadElserBlocking();
+
+        // Happy case
+        {
+            String modelId = randomAlphaOfLength(10).toLowerCase();
+            putModel(modelId, elserConfig, TaskType.SPARSE_EMBEDDING);
+            var models = getModels(modelId, TaskType.SPARSE_EMBEDDING);
+            assertThat(models.get("models").toString(), containsString("model_id=" + modelId));
+            deleteModel(modelId, TaskType.SPARSE_EMBEDDING);
+            expectThrows(ResponseException.class, () -> getModels(modelId, TaskType.SPARSE_EMBEDDING));
+            models = getTrainedModel("_all");
+            assertThat(models.toString(), not(containsString("deployment_id=" + modelId)));
+        }
+    }
+
     @SuppressWarnings("unchecked")
     public void testGet() throws IOException {
         for (int i = 0; i < 5; i++) {

+ 31 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceModelAction.java

@@ -7,8 +7,12 @@
 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
 import org.elasticsearch.cluster.ClusterState;
@@ -18,6 +22,8 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
@@ -26,7 +32,10 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 
 public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction<DeleteInferenceModelAction.Request> {
 
+    private static final Logger logger = LogManager.getLogger(TransportDeleteInferenceModelAction.class);
+
     private final ModelRegistry modelRegistry;
+    private final InferenceServiceRegistry serviceRegistry;
 
     @Inject
     public TransportDeleteInferenceModelAction(
@@ -35,7 +44,8 @@ public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMa
         ThreadPool threadPool,
         ActionFilters actionFilters,
         IndexNameExpressionResolver indexNameExpressionResolver,
-        ModelRegistry modelRegistry
+        ModelRegistry modelRegistry,
+        InferenceServiceRegistry serviceRegistry
     ) {
         super(
             DeleteInferenceModelAction.NAME,
@@ -48,6 +58,7 @@ public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMa
             EsExecutors.DIRECT_EXECUTOR_SERVICE
         );
         this.modelRegistry = modelRegistry;
+        this.serviceRegistry = serviceRegistry;
     }
 
     @Override
@@ -57,11 +68,29 @@ public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMa
         ClusterState state,
         ActionListener<AcknowledgedResponse> listener
     ) {
-        modelRegistry.deleteModel(request.getModelId(), listener.delegateFailureAndWrap((l, r) -> l.onResponse(AcknowledgedResponse.TRUE)));
+        SubscribableListener.<ModelRegistry.UnparsedModel>newForked(modelConfigListener -> {
+            modelRegistry.getModel(request.getModelId(), modelConfigListener);
+        }).<Boolean>andThen((l1, unparsedModel) -> {
+            var service = serviceRegistry.getService(unparsedModel.service());
+            if (service.isPresent()) {
+                service.get().stop(request.getModelId(), l1);
+            } else {
+                l1.onFailure(new ElasticsearchStatusException("No service found for model " + request.getModelId(), RestStatus.NOT_FOUND));
+            }
+        }).<Boolean>andThen((l2, didStop) -> {
+            if (didStop) {
+                modelRegistry.deleteModel(request.getModelId(), l2);
+            } else {
+                l2.onFailure(
+                    new ElasticsearchStatusException("Failed to stop model " + request.getModelId(), RestStatus.INTERNAL_SERVER_ERROR)
+                );
+            }
+        }).addListener(listener.delegateFailure((l3, didDeleteModel) -> listener.onResponse(AcknowledgedResponse.of(didDeleteModel))));
     }
 
     @Override
     protected ClusterBlockException checkBlock(DeleteInferenceModelAction.Request request, ClusterState state) {
         return state.blocks().globalBlockedException(ClusterBlockLevel.WRITE);
     }
+
 }

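The delete flow above chains three asynchronous steps with SubscribableListener. A condensed sketch of the same pattern (identifiers simplified; the real code also handles the missing-service and failed-stop branches explicitly, and a failure in any step propagates to the outer listener):

    SubscribableListener
        .<ModelRegistry.UnparsedModel>newForked(l -> modelRegistry.getModel(modelId, l)) // 1. look up the stored model config
        .<Boolean>andThen((l, unparsedModel) -> service.stop(modelId, l))                // 2. stop the trained model deployment
        .<Boolean>andThen((l, didStop) -> modelRegistry.deleteModel(modelId, l))         // 3. delete the stored config
        .addListener(listener.map(AcknowledgedResponse::of));                            // 4. acknowledge the original request
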
+ 12 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java

@@ -3,6 +3,8 @@
  * or more contributor license agreements. Licensed under the Elastic License
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
+ *
+ * this file was contributed to by a generative AI
  */
 
 package org.elasticsearch.xpack.inference.services.elser;
@@ -24,6 +26,7 @@ import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
 
 import java.io.IOException;
@@ -161,6 +164,15 @@ public class ElserMlNodeService implements InferenceService {
         );
     }
 
+    @Override
+    public void stop(String modelId, ActionListener<Boolean> listener) {
+        client.execute(
+            StopTrainedModelDeploymentAction.INSTANCE,
+            new StopTrainedModelDeploymentAction.Request(modelId),
+            listener.delegateFailureAndWrap((delegatedResponseListener, response) -> delegatedResponseListener.onResponse(Boolean.TRUE))
+        );
+    }
+
     @Override
     public void infer(Model model, List<String> input, Map<String, Object> taskSettings, ActionListener<InferenceServiceResults> listener) {
         // No task settings to override with requestTaskSettings