浏览代码

[Inference API] Semantic text delete inference (#110487)

* Prevent inference endpoints from being deleted if they are referenced by a semantic text field

* Update docs/changelog/110399.yaml

* fix tests

* remove erroneous logging

* Apply suggestions from code review

Co-authored-by: David Kyle <david.kyle@elastic.co>
Co-authored-by: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com>

* Fix serialization problem

* Update error messages

* Update Delete response to include new fields

* Refactor Delete Transport Action to return the error message on dry run

* Fix tests including disabling failing yaml tests

* Fix YAML tests

* move work off of transport thread onto utility threadpool

* clean up semantic text indexes after IT

* improvements from review

---------

Co-authored-by: David Kyle <david.kyle@elastic.co>
Co-authored-by: Carlos Delgado <6339205+carlosdelest@users.noreply.github.com>
Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Max Hniebergall 1 年之前
父节点
当前提交
320b88ae37

+ 6 - 0
docs/changelog/110399.yaml

@@ -0,0 +1,6 @@
+pr: 110399
+summary: "[Inference API] Prevent inference endpoints from being deleted if they are\
+  \ referenced by semantic text"
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -209,6 +209,7 @@ public class TransportVersions {
     public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0);
     public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0);
     public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0);
+    public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_703_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 27 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java

@@ -11,8 +11,10 @@ import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.AcknowledgedRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.XContentBuilder;
 
@@ -105,10 +107,16 @@ public class DeleteInferenceEndpointAction extends ActionType<DeleteInferenceEnd
 
         private final String PIPELINE_IDS = "pipelines";
         Set<String> pipelineIds;
+        private final String REFERENCED_INDEXES = "indexes";
+        Set<String> indexes;
+        private final String DRY_RUN_MESSAGE = "error_message"; // error message only returned in response for dry_run
+        String dryRunMessage;
 
-        public Response(boolean acknowledged, Set<String> pipelineIds) {
+        public Response(boolean acknowledged, Set<String> pipelineIds, Set<String> semanticTextIndexes, @Nullable String dryRunMessage) {
             super(acknowledged);
             this.pipelineIds = pipelineIds;
+            this.indexes = semanticTextIndexes;
+            this.dryRunMessage = dryRunMessage;
         }
 
         public Response(StreamInput in) throws IOException {
@@ -118,6 +126,15 @@ public class DeleteInferenceEndpointAction extends ActionType<DeleteInferenceEnd
             } else {
                 pipelineIds = Set.of();
             }
+
+            if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
+                indexes = in.readCollectionAsSet(StreamInput::readString);
+                dryRunMessage = in.readOptionalString();
+            } else {
+                indexes = Set.of();
+                dryRunMessage = null;
+            }
+
         }
 
         @Override
@@ -126,23 +143,25 @@ public class DeleteInferenceEndpointAction extends ActionType<DeleteInferenceEnd
             if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ENHANCE_DELETE_ENDPOINT)) {
                 out.writeCollection(pipelineIds, StreamOutput::writeString);
             }
+            if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
+                out.writeCollection(indexes, StreamOutput::writeString);
+                out.writeOptionalString(dryRunMessage);
+            }
         }
 
         @Override
         protected void addCustomFields(XContentBuilder builder, Params params) throws IOException {
             super.addCustomFields(builder, params);
             builder.field(PIPELINE_IDS, pipelineIds);
+            builder.field(REFERENCED_INDEXES, indexes);
+            if (dryRunMessage != null) {
+                builder.field(DRY_RUN_MESSAGE, dryRunMessage);
+            }
         }
 
         @Override
         public String toString() {
-            StringBuilder returnable = new StringBuilder();
-            returnable.append("acknowledged: ").append(this.acknowledged);
-            returnable.append(", pipelineIdsByEndpoint: ");
-            for (String entry : pipelineIds) {
-                returnable.append(entry).append(", ");
-            }
-            return returnable.toString();
+            return Strings.toString(this);
         }
     }
 }

+ 50 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ *
+ * this file was contributed to by a Generative AI
+ */
+
+package org.elasticsearch.xpack.core.ml.utils;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.transport.Transports;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+public class SemanticTextInfoExtractor {
+    private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class);
+
+    public static Set<String> extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set<String> endpointIds) {
+        assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
+        assert endpointIds.isEmpty() == false;
+        assert metadata != null;
+
+        Set<String> referenceIndices = new HashSet<>();
+
+        Map<String, IndexMetadata> indices = metadata.indices();
+
+        indices.forEach((indexName, indexMetadata) -> {
+            if (indexMetadata.getInferenceFields() != null) {
+                Map<String, InferenceFieldMetadata> inferenceFields = indexMetadata.getInferenceFields();
+                if (inferenceFields.entrySet()
+                    .stream()
+                    .anyMatch(
+                        entry -> entry.getValue().getInferenceId() != null && endpointIds.contains(entry.getValue().getInferenceId())
+                    )) {
+                    referenceIndices.add(indexName);
+                }
+            }
+        });
+
+        return referenceIndices;
+    }
+}

+ 19 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -126,6 +126,25 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         assertOkOrCreated(response);
     }
 
+    protected void putSemanticText(String endpointId, String indexName) throws IOException {
+        var request = new Request("PUT", Strings.format("%s", indexName));
+        String body = Strings.format("""
+            {
+                "mappings": {
+                "properties": {
+                    "inference_field": {
+                        "type": "semantic_text",
+                            "inference_id": "%s"
+                    }
+                }
+                }
+            }
+            """, endpointId);
+        request.setJsonEntity(body);
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+    }
+
     protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
         String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
         return putRequest(endpoint, modelConfig);

+ 83 - 7
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskType;
 
 import java.io.IOException;
 import java.util.List;
+import java.util.Set;
 
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
@@ -124,14 +125,15 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         putPipeline(pipelineId, endpointId);
 
         {
+            var errorString = new StringBuilder().append("Inference endpoint ")
+                .append(endpointId)
+                .append(" is referenced by pipelines: ")
+                .append(Set.of(pipelineId))
+                .append(". ")
+                .append("Ensure that no pipelines are using this inference endpoint, ")
+                .append("or use force to ignore this warning and delete the inference endpoint.");
             var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
-            assertThat(
-                e.getMessage(),
-                containsString(
-                    "Inference endpoint endpoint_referenced_by_pipeline is referenced by pipelines and cannot be deleted. "
-                        + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it."
-                )
-            );
+            assertThat(e.getMessage(), containsString(errorString.toString()));
         }
         {
             var response = deleteModel(endpointId, "dry_run=true");
@@ -146,4 +148,78 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         }
         deletePipeline(pipelineId);
     }
+
+    public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException {
+        String endpointId = "endpoint_referenced_by_semantic_text";
+        putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        String indexName = randomAlphaOfLength(10).toLowerCase();
+        putSemanticText(endpointId, indexName);
+        {
+
+            var errorString = new StringBuilder().append(" Inference endpoint ")
+                .append(endpointId)
+                .append(" is being used in the mapping for indexes: ")
+                .append(Set.of(indexName))
+                .append(". ")
+                .append("Ensure that no index mappings are using this inference endpoint, ")
+                .append("or use force to ignore this warning and delete the inference endpoint.");
+            var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
+            assertThat(e.getMessage(), containsString(errorString.toString()));
+        }
+        {
+            var response = deleteModel(endpointId, "dry_run=true");
+            var entityString = EntityUtils.toString(response.getEntity());
+            assertThat(entityString, containsString("\"acknowledged\":false"));
+            assertThat(entityString, containsString(indexName));
+        }
+        {
+            var response = deleteModel(endpointId, "force=true");
+            var entityString = EntityUtils.toString(response.getEntity());
+            assertThat(entityString, containsString("\"acknowledged\":true"));
+        }
+        deleteIndex(indexName);
+    }
+
+    public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws IOException {
+        String endpointId = "endpoint_referenced_by_semantic_text";
+        putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        String indexName = randomAlphaOfLength(10).toLowerCase();
+        putSemanticText(endpointId, indexName);
+        var pipelineId = "pipeline_referencing_model";
+        putPipeline(pipelineId, endpointId);
+        {
+
+            var errorString = new StringBuilder().append("Inference endpoint ")
+                .append(endpointId)
+                .append(" is referenced by pipelines: ")
+                .append(Set.of(pipelineId))
+                .append(". ")
+                .append("Ensure that no pipelines are using this inference endpoint, ")
+                .append("or use force to ignore this warning and delete the inference endpoint.")
+                .append(" Inference endpoint ")
+                .append(endpointId)
+                .append(" is being used in the mapping for indexes: ")
+                .append(Set.of(indexName))
+                .append(". ")
+                .append("Ensure that no index mappings are using this inference endpoint, ")
+                .append("or use force to ignore this warning and delete the inference endpoint.");
+
+            var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
+            assertThat(e.getMessage(), containsString(errorString.toString()));
+        }
+        {
+            var response = deleteModel(endpointId, "dry_run=true");
+            var entityString = EntityUtils.toString(response.getEntity());
+            assertThat(entityString, containsString("\"acknowledged\":false"));
+            assertThat(entityString, containsString(indexName));
+            assertThat(entityString, containsString(pipelineId));
+        }
+        {
+            var response = deleteModel(endpointId, "force=true");
+            var entityString = EntityUtils.toString(response.getEntity());
+            assertThat(entityString, containsString("\"acknowledged\":true"));
+        }
+        deletePipeline(pipelineId);
+        deleteIndex(indexName);
+    }
 }

+ 94 - 44
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

@@ -3,6 +3,8 @@
  * or more contributor license agreements. Licensed under the Elastic License
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
+ *
+ * this file was contributed to by a Generative AI
  */
 
 package org.elasticsearch.xpack.inference.action;
@@ -11,6 +13,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRunnable;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.master.TransportMasterNodeAction;
@@ -18,12 +21,10 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
-import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.inference.InferenceServiceRegistry;
-import org.elasticsearch.ingest.IngestMetadata;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -34,6 +35,10 @@ import org.elasticsearch.xpack.inference.common.InferenceExceptions;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 
 import java.util.Set;
+import java.util.concurrent.Executor;
+
+import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints;
+import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
 
 public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction<
     DeleteInferenceEndpointAction.Request,
@@ -42,6 +47,7 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private static final Logger logger = LogManager.getLogger(TransportDeleteInferenceEndpointAction.class);
+    private final Executor executor;
 
     @Inject
     public TransportDeleteInferenceEndpointAction(
@@ -66,6 +72,7 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
         );
         this.modelRegistry = modelRegistry;
         this.serviceRegistry = serviceRegistry;
+        this.executor = threadPool.executor(UTILITY_THREAD_POOL_NAME);
     }
 
     @Override
@@ -74,6 +81,15 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
         DeleteInferenceEndpointAction.Request request,
         ClusterState state,
         ActionListener<DeleteInferenceEndpointAction.Response> masterListener
+    ) {
+        // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
+        executor.execute(ActionRunnable.wrap(masterListener, l -> doExecuteForked(request, state, l)));
+    }
+
+    private void doExecuteForked(
+        DeleteInferenceEndpointAction.Request request,
+        ClusterState state,
+        ActionListener<DeleteInferenceEndpointAction.Response> masterListener
     ) {
         SubscribableListener.<ModelRegistry.UnparsedModel>newForked(modelConfigListener -> {
             // Get the model from the registry
@@ -89,17 +105,15 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
             }
 
             if (request.isDryRun()) {
-                masterListener.onResponse(
-                    new DeleteInferenceEndpointAction.Response(
-                        false,
-                        InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId()))
-                    )
-                );
+                handleDryRun(request, state, masterListener);
                 return;
-            } else if (request.isForceDelete() == false
-                && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), listener)) {
+            } else if (request.isForceDelete() == false) {
+                var errorString = endpointIsReferencedInPipelinesOrIndexes(state, request.getInferenceEndpointId());
+                if (errorString != null) {
+                    listener.onFailure(new ElasticsearchStatusException(errorString, RestStatus.CONFLICT));
                     return;
                 }
+            }
 
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isPresent()) {
@@ -126,47 +140,83 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
         })
             .addListener(
                 masterListener.delegateFailure(
-                    (l3, didDeleteModel) -> masterListener.onResponse(new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of()))
+                    (l3, didDeleteModel) -> masterListener.onResponse(
+                        new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of(), Set.of(), null)
+                    )
                 )
             );
     }
 
-    private static boolean endpointIsReferencedInPipelines(
-        final ClusterState state,
-        final String inferenceEndpointId,
-        ActionListener<Boolean> listener
+    private static void handleDryRun(
+        DeleteInferenceEndpointAction.Request request,
+        ClusterState state,
+        ActionListener<DeleteInferenceEndpointAction.Response> masterListener
     ) {
-        Metadata metadata = state.getMetadata();
-        if (metadata == null) {
-            listener.onFailure(
-                new ElasticsearchStatusException(
-                    " Could not determine if the endpoint is referenced in a pipeline as cluster state metadata was unexpectedly null. "
-                        + "Use `force` to delete it anyway",
-                    RestStatus.INTERNAL_SERVER_ERROR
-                )
-            );
-            // Unsure why the ClusterState metadata would ever be null, but in this case it seems safer to assume the endpoint is referenced
-            return true;
+        Set<String> pipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId()));
+
+        Set<String> indexesReferencedBySemanticText = extractIndexesReferencingInferenceEndpoints(
+            state.getMetadata(),
+            Set.of(request.getInferenceEndpointId())
+        );
+
+        masterListener.onResponse(
+            new DeleteInferenceEndpointAction.Response(
+                false,
+                pipelines,
+                indexesReferencedBySemanticText,
+                buildErrorString(request.getInferenceEndpointId(), pipelines, indexesReferencedBySemanticText)
+            )
+        );
+    }
+
+    private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterState state, final String inferenceEndpointId) {
+
+        var pipelines = endpointIsReferencedInPipelines(state, inferenceEndpointId);
+        var indexes = endpointIsReferencedInIndex(state, inferenceEndpointId);
+
+        if (pipelines.isEmpty() == false || indexes.isEmpty() == false) {
+            return buildErrorString(inferenceEndpointId, pipelines, indexes);
         }
-        IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE);
-        if (ingestMetadata == null) {
-            logger.debug("No ingest metadata found in cluster state while attempting to delete inference endpoint");
-        } else {
-            Set<String> modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.getModelIdsFromInferenceProcessors(ingestMetadata);
-            if (modelIdsReferencedByPipelines.contains(inferenceEndpointId)) {
-                listener.onFailure(
-                    new ElasticsearchStatusException(
-                        "Inference endpoint "
-                            + inferenceEndpointId
-                            + " is referenced by pipelines and cannot be deleted. "
-                            + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it.",
-                        RestStatus.CONFLICT
-                    )
-                );
-                return true;
-            }
+        return null;
+    }
+
+    private static String buildErrorString(String inferenceEndpointId, Set<String> pipelines, Set<String> indexes) {
+        StringBuilder errorString = new StringBuilder();
+
+        if (pipelines.isEmpty() == false) {
+            errorString.append("Inference endpoint ")
+                .append(inferenceEndpointId)
+                .append(" is referenced by pipelines: ")
+                .append(pipelines)
+                .append(". ")
+                .append("Ensure that no pipelines are using this inference endpoint, ")
+                .append("or use force to ignore this warning and delete the inference endpoint.");
         }
-        return false;
+
+        if (indexes.isEmpty() == false) {
+            errorString.append(" Inference endpoint ")
+                .append(inferenceEndpointId)
+                .append(" is being used in the mapping for indexes: ")
+                .append(indexes)
+                .append(". ")
+                .append("Ensure that no index mappings are using this inference endpoint, ")
+                .append("or use force to ignore this warning and delete the inference endpoint.");
+        }
+
+        return errorString.toString();
+    }
+
+    private static Set<String> endpointIsReferencedInIndex(final ClusterState state, final String inferenceEndpointId) {
+        Set<String> indexes = extractIndexesReferencingInferenceEndpoints(state.getMetadata(), Set.of(inferenceEndpointId));
+        return indexes;
+    }
+
+    private static Set<String> endpointIsReferencedInPipelines(final ClusterState state, final String inferenceEndpointId) {
+        Set<String> modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(
+            state,
+            Set.of(inferenceEndpointId)
+        );
+        return modelIdsReferencedByPipelines;
     }
 
     @Override

+ 3 - 0
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml

@@ -81,6 +81,7 @@ setup:
   - do:
       inference.delete:
         inference_id: sparse-inference-id
+        force: true
 
   - do:
       inference.put:
@@ -119,6 +120,7 @@ setup:
   - do:
       inference.delete:
         inference_id: dense-inference-id
+        force: true
 
   - do:
       inference.put:
@@ -155,6 +157,7 @@ setup:
   - do:
       inference.delete:
         inference_id: dense-inference-id
+        force: true
 
   - do:
       inference.put: