Browse source code

[ML] Fix for inference modelId trained model deployment id collision (#105303)

* Fix for inference modelId trained model deployment id collision

* Add check for model already downloaded before put trained model
Max Hniebergall 1 year ago
parent
commit
89bf949555

+ 11 - 0
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -142,6 +142,17 @@ public interface InferenceService extends Closeable {
         listener.onResponse(true);
     }
 
+    /**
+     * Checks if the modelId has been downloaded to the local Elasticsearch cluster using the trained models API
+     * The default action does nothing except acknowledge the request (false).
+     * Any internal services should Override this method.
+     * @param model
+     * @param listener The listener
+     */
+    default void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
+        listener.onResponse(false);
+    };
+
     /**
      * Optionally test the new model configuration in the inference service.
      * This function should be called when the model is first created, the

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java

@@ -30,6 +30,8 @@ public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Resp
     public static final String DEFER_DEFINITION_DECOMPRESSION = "defer_definition_decompression";
     public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
     public static final String NAME = "cluster:admin/xpack/ml/inference/put";
+    public static final String MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT =
+        "the model id is the same as the deployment id of a current model deployment";
 
     private PutTrainedModelAction() {
         super(NAME);

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java

@@ -59,7 +59,8 @@ public class ExceptionsHelper {
 
     public static ResourceNotFoundException missingTrainedModel(String modelId, Exception cause) {
         return new ResourceNotFoundException(
-            "No known trained model with model_id [{}], you may need to create it or load it into the cluster with eland",
+            "Failure due to [{}]. No known trained model with model_id [{}], "
+                + "you may need to create it or load it into the cluster with eland",
             cause,
             modelId
         );

+ 25 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -124,6 +124,31 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return entityAsMap(response);
     }
 
+    protected Map<String, Object> putE5TrainedModels() throws IOException {
+        var request = new Request("PUT", "_ml/trained_models/.multilingual-e5-small?wait_for_completion=true");
+
+        String body = """
+                {
+                    "input": {
+                    "field_names": ["text_field"]
+                    }
+                }
+            """;
+
+        request.setJsonEntity(body);
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+        return entityAsMap(response);
+    }
+
+    protected Map<String, Object> deployE5TrainedModels() throws IOException {
+        var request = new Request("POST", "_ml/trained_models/.multilingual-e5-small/deployment/_start?wait_for=fully_allocated");
+
+        var response = client().performRequest(request);
+        assertOkOrCreated(response);
+        return entityAsMap(response);
+    }
+
     protected Map<String, Object> getModel(String modelId) throws IOException {
         var endpoint = Strings.format("_inference/%s", modelId);
         return getAllModelInternal(endpoint);

+ 15 - 20
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

@@ -24,29 +24,18 @@ import static org.hamcrest.Matchers.containsString;
 public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
 
     public void testPutE5Small_withNoModelVariant() throws IOException {
-        // Model downloaded automatically & test infer with no model variant
         {
             String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
-            putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, noModelIdVariantJsonEntity());
-            var models = getTrainedModel("_all");
-            assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
-
-            Map<String, Object> results = inferOnMockService(
-                inferenceEntityId,
-                TaskType.TEXT_EMBEDDING,
-                List.of("hello world", "this is the second document")
+            expectThrows(
+                org.elasticsearch.client.ResponseException.class,
+                () -> putTextEmbeddingModel(inferenceEntityId, noModelIdVariantJsonEntity())
             );
-            assertTrue(((List) ((Map) ((List) results.get("text_embedding")).get(0)).get("embedding")).size() > 1);
-            // there exists embeddings
-            assertTrue(((List) results.get("text_embedding")).size() == 2);
-            // there are two sets of embeddings
-            deleteTextEmbeddingModel(inferenceEntityId);
         }
     }
 
     public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
         String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
-        putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, platformAgnosticModelVariantJsonEntity());
+        putTextEmbeddingModel(inferenceEntityId, platformAgnosticModelVariantJsonEntity());
         var models = getTrainedModel("_all");
         assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
 
@@ -65,7 +54,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
     public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
         String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
         if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
-            putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, platformSpecificModelVariantJsonEntity());
+            putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity());
             var models = getTrainedModel("_all");
             assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
 
@@ -82,7 +71,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
         } else {
             expectThrows(
                 org.elasticsearch.client.ResponseException.class,
-                () -> putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, platformSpecificModelVariantJsonEntity())
+                () -> putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity())
             );
         }
     }
@@ -91,9 +80,15 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
         String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
         expectThrows(
             org.elasticsearch.client.ResponseException.class,
-            () -> putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, fakeModelVariantJsonEntity())
+            () -> putTextEmbeddingModel(inferenceEntityId, fakeModelVariantJsonEntity())
         );
+    }
 
+    public void testPutE5WithTrainedModelAndInference() throws IOException {
+        putE5TrainedModels();
+        deployE5TrainedModels();
+        putTextEmbeddingModel("an-e5-deployment", platformAgnosticModelVariantJsonEntity());
+        getTrainedModel("an-e5-deployment");
     }
 
     private Map<String, Object> deleteTextEmbeddingModel(String inferenceEntityId) throws IOException {
@@ -104,8 +99,8 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
         return entityAsMap(response);
     }
 
-    private Map<String, Object> putTextEmbeddingModel(String inferenceEntityId, TaskType taskType, String jsonEntity) throws IOException {
-        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceEntityId);
+    private Map<String, Object> putTextEmbeddingModel(String inferenceEntityId, String jsonEntity) throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s", TaskType.TEXT_EMBEDDING, inferenceEntityId);
         var request = new Request("PUT", endpoint);
 
         request.setJsonEntity(jsonEntity);

+ 23 - 15
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

@@ -210,23 +210,31 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
     }
 
     private void putAndStartModel(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> finalListener) {
-        SubscribableListener.<Boolean>newForked((listener1) -> { service.putModel(model, listener1); }).<
-            PutInferenceModelAction.Response>andThen((listener2, modelDidPut) -> {
-                if (modelDidPut) {
-                    if (skipValidationAndStart) {
-                        listener2.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
-                    } else {
-                        service.start(
-                            model,
-                            listener2.delegateFailureAndWrap(
-                                (l3, ok) -> l3.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()))
-                            )
-                        );
-                    }
+        SubscribableListener.<Boolean>newForked(listener -> {
+            var errorCatchingListener = ActionListener.<Boolean>wrap(listener::onResponse, e -> { listener.onResponse(false); });
+            service.isModelDownloaded(model, errorCatchingListener);
+        }).<Boolean>andThen((listener, isDownloaded) -> {
+            if (isDownloaded == false) {
+                service.putModel(model, listener);
+            } else {
+                listener.onResponse(true);
+            }
+        }).<PutInferenceModelAction.Response>andThen((listener, modelDidPut) -> {
+            if (modelDidPut) {
+                if (skipValidationAndStart) {
+                    listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
                 } else {
-                    logger.warn("Failed to put model [{}]", model.getInferenceEntityId());
+                    service.start(
+                        model,
+                        listener.delegateFailureAndWrap(
+                            (l3, ok) -> l3.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()))
+                        )
+                    );
                 }
-            }).addListener(finalListener);
+            } else {
+                logger.warn("Failed to put model [{}]", model.getInferenceEntityId());
+            }
+        }).addListener(finalListener);
     }
 
     private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {

+ 26 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -347,6 +348,31 @@ public class ElserInternalService implements InferenceService {
         }
     }
 
+    @Override
+    public void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
+        ActionListener<GetTrainedModelsAction.Response> getModelsResponseListener = listener.delegateFailure((delegate, response) -> {
+            if (response.getResources().count() < 1) {
+                delegate.onResponse(Boolean.FALSE);
+            } else {
+                delegate.onResponse(Boolean.TRUE);
+            }
+        });
+
+        if (model instanceof ElserInternalModel elserModel) {
+            String modelId = elserModel.getServiceSettings().getModelId();
+            GetTrainedModelsAction.Request getRequest = new GetTrainedModelsAction.Request(modelId);
+            executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener);
+        } else {
+            listener.onFailure(
+                new IllegalArgumentException(
+                    "Can not download model automatically for ["
+                        + model.getConfigurations().getInferenceEntityId()
+                        + "] you may need to download it through the trained models API or with eland."
+                )
+            );
+        }
+    }
+
     private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Map<String, Object> config) {
         if (taskType != TaskType.SPARSE_EMBEDDING) {
             throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);

+ 34 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/textembedding/TextEmbeddingInternalService.java

@@ -314,8 +314,13 @@ public class TextEmbeddingInternalService implements InferenceService {
                 INFERENCE_ORIGIN,
                 PutTrainedModelAction.INSTANCE,
                 putRequest,
-                listener.delegateFailure((l, r) -> {
-                    l.onResponse(Boolean.TRUE);
+                ActionListener.wrap(response -> listener.onResponse(Boolean.TRUE), e -> {
+                    if (e instanceof ElasticsearchStatusException esException
+                        && esException.getMessage().contains(PutTrainedModelAction.MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT)) {
+                        listener.onResponse(Boolean.TRUE);
+                    } else {
+                        listener.onFailure(e);
+                    }
                 })
             );
         } else if (model instanceof CustomElandModel elandModel) {
@@ -333,6 +338,33 @@ public class TextEmbeddingInternalService implements InferenceService {
         }
     }
 
+    @Override
+    public void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
+        ActionListener<GetTrainedModelsAction.Response> getModelsResponseListener = listener.delegateFailure((delegate, response) -> {
+            if (response.getResources().count() < 1) {
+                delegate.onResponse(Boolean.FALSE);
+            } else {
+                delegate.onResponse(Boolean.TRUE);
+            }
+        });
+
+        if (model instanceof TextEmbeddingModel == false) {
+            listener.onFailure(notTextEmbeddingModelException(model));
+        } else if (model.getServiceSettings() instanceof InternalServiceSettings internalServiceSettings) {
+            String modelId = internalServiceSettings.getModelId();
+            GetTrainedModelsAction.Request getRequest = new GetTrainedModelsAction.Request(modelId);
+            executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener);
+        } else {
+            listener.onFailure(
+                new IllegalArgumentException(
+                    "Unable to determine supported model for ["
+                        + model.getConfigurations().getInferenceEntityId()
+                        + "] please verify the request and submit a bug report if necessary."
+                )
+            );
+        }
+    }
+
     private static IllegalStateException notTextEmbeddingModelException(Model model) {
         return new IllegalStateException(
             "Error starting model, [" + model.getConfigurations().getInferenceEntityId() + "] is not a text embedding model"

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

@@ -81,6 +81,7 @@ import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+import static org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT;
 
 public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Request, Response> {
 
@@ -230,7 +231,7 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
         if (TrainedModelAssignmentMetadata.fromState(state).hasDeployment(trainedModelConfig.getModelId())) {
             finalResponseListener.onFailure(
                 ExceptionsHelper.badRequestException(
-                    "Cannot create model [{}] the id is the same as an current model deployment",
+                    "Cannot create model [{}] " + MODEL_ALREADY_EXISTS_ERROR_MESSAGE_FRAGMENT,
                     config.getModelId()
                 )
             );

+ 1 - 1
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

@@ -688,7 +688,7 @@ setup:
   - match: { assignment.assignment_state: started }
 
   - do:
-      catch: /Cannot create model \[test_model_deployment\] the id is the same as an current model deployment/
+      catch: /Cannot create model \[test_model_deployment\] the model id is the same as the deployment id of a current model deployment/
       ml.put_trained_model:
         model_id: test_model_deployment
         body: >