
[ML] allow model_aliases to be used with Pytorch trained models (#91296)

This adds `model_alias` support for native PyTorch models.

Model aliases can be used in `_infer` calls or within the inference processor. This way the alias can be switched atomically to another deployed model without downtime (see the example sketch below).

Restrictions:
 - Model alias changes must be made between two models of the same kind (e.g. pytorch -> pytorch).
 - A model alias cannot be moved from a deployed model to a model that is not deployed.
 - A model alias cannot be moved from a model that is deployed AND allocated to a model that is deployed but NOT allocated (not assigned to any nodes).
 - A deployment cannot be stopped (without supplying the `force` parameter) when the model has a model alias that is referenced by an ingest pipeline.


closes: https://github.com/elastic/elasticsearch/issues/90960
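To make the workflow concrete, here is a minimal request sketch under assumed names (the model IDs, the alias `my_ner_model`, and the `text_field` input field are hypothetical; the endpoints are the existing trained-model alias and `_infer` APIs):

    PUT _ml/trained_models/my-ner-model-v1/model_aliases/my_ner_model

    POST _ml/trained_models/my_ner_model/_infer
    {
      "docs": [{ "text_field": "my words" }]
    }

    PUT _ml/trained_models/my-ner-model-v2/model_aliases/my_ner_model?reassign=true

After the final call, `_infer` requests and ingest pipelines that reference `my_ner_model` resolve to `my-ner-model-v2` without any pipeline changes, provided `my-ner-model-v2` satisfies the restrictions above.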
Benjamin Trent 3 years ago
parent commit 2e8bf33b0a
15 changed files with 227 additions and 61 deletions
  1. docs/changelog/91296.yaml (+5, -0)
  2. docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc (+1, -1)
  3. docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc (+6, -6)
  4. docs/reference/ml/trained-models/apis/put-trained-models-aliases.asciidoc (+12, -5)
  5. docs/reference/ml/trained-models/apis/stop-trained-model-deployment.asciidoc (+3, -3)
  6. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java (+5, -1)
  7. x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java (+68, -0)
  8. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java (+1, -1)
  9. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java (+15, -7)
  10. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java (+10, -4)
  11. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java (+69, -4)
  12. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java (+27, -9)
  13. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ModelAliasMetadata.java (+5, -0)
  14. x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml (+0, -0)
  15. x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml (+0, -20)

+ 5 - 0
docs/changelog/91296.yaml

@@ -0,0 +1,5 @@
+pr: 91296
+summary: Allow `model_aliases` to be used with Pytorch trained models
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 1
docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc

@@ -31,7 +31,7 @@ deprecated::[8.3.0,Replaced by <<infer-trained-model>>.]
 
 `<model_id>`::
 (Required, string)
-include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]
 
 [[infer-trained-model-deployment-query-params]]
 == {api-query-parms-title}

+ 6 - 6
docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc

@@ -31,7 +31,7 @@ beta::[]
 
 `<model_id>`::
 (Required, string)
-include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]
 
 [[infer-trained-model-query-params]]
 == {api-query-parms-title}
@@ -629,7 +629,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenizati
 
 The response depends on the kind of model.
 
-For example, for {lang-ident} the response is the predicted language and the 
+For example, for {lang-ident} the response is the predicted language and the
 score:
 
 [source,console]
@@ -658,7 +658,7 @@ Here are the results predicting english with a high probability.
 // NOTCONSOLE
 
 
-When it is a text classification model, the response is the score and predicted 
+When it is a text classification model, the response is the score and predicted
 classification.
 
 For example:
@@ -822,8 +822,8 @@ The API returns a response similar to the following:
 ----
 // NOTCONSOLE
 
-Text similarity models require at least two sequences of text to compare. It's 
-possible to provide multiple strings of text to compare to another text 
+Text similarity models require at least two sequences of text to compare. It's
+possible to provide multiple strings of text to compare to another text
 sequence:
 
 [source,console]
@@ -840,7 +840,7 @@ POST _ml/trained_models/cross-encoder__ms-marco-tinybert-l-2-v2/_infer
 --------------------------------------------------
 // TEST[skip:TBD]
 
-The response contains the prediction for every string that is compared to the 
+The response contains the prediction for every string that is compared to the
 text provided in the `text_similarity`.`text` field:
 
 [source,console-result]

+ 12 - 5
docs/reference/ml/trained-models/apis/put-trained-models-aliases.asciidoc

@@ -20,7 +20,7 @@ A trained model alias is a logical name used to reference a single trained model
 [[ml-put-trained-models-aliases-prereq]]
 == {api-prereq-title}
 
-Requires the `manage_ml` cluster privilege. This privilege is included in the 
+Requires the `manage_ml` cluster privilege. This privilege is included in the
 `machine_learning_admin` built-in role.
 
 
@@ -34,11 +34,18 @@ and processors.
 An alias must be unique and refer to only a single trained model. However,
 you can have multiple aliases for each trained model.
 
-If you use this API to update an alias such that it references a different
-trained model ID and the model uses a different type of {dfanalytics}, an error
-occurs. For example, this situation occurs if you have a trained model for
+API Restrictions:
++
+--
+* You are not allowed to update an alias such that it references a different
+trained model ID and the model uses a different type of {dfanalytics}. For example,
+this situation occurs if you have a trained model for
 {reganalysis} and a trained model for {classanalysis}; you cannot reassign an
 alias from one type of trained model to another.
+* You cannot update an alias between a `pytorch` model and a {dfanalytics} model.
+* You cannot update the alias from a deployed `pytorch` model to one
+not currently deployed.
+--
 
 If you use this API to update an alias and there are very few input fields in
 common between the old and new trained models for the model alias, the API
@@ -62,7 +69,7 @@ The identifier for the trained model that the alias refers to.
 (Optional, boolean)
 Specifies whether the alias gets reassigned to the specified trained model if it
 is already assigned to a different model. If the alias is already assigned and
-this parameter is `false`, the API returns an error. Defaults to `false`. 
+this parameter is `false`, the API returns an error. Defaults to `false`.
 
 [[ml-put-trained-models-aliases-example]]
 == {api-examples-title}
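For reference, the call governed by the restrictions documented above is the existing alias endpoint; a minimal sketch with a hypothetical model ID and alias:

    PUT _ml/trained_models/my-pytorch-model-v2/model_aliases/my_model_alias?reassign=true

With this change, the request is rejected with a 400 if it would move the alias across model types, or from a deployed `pytorch` model to one that is not deployed.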

+ 3 - 3
docs/reference/ml/trained-models/apis/stop-trained-model-deployment.asciidoc

@@ -18,7 +18,7 @@ beta::[]
 [[stop-trained-model-deployment-prereq]]
 == {api-prereq-title}
 
-Requires the `manage_ml` cluster privilege. This privilege is included in the 
+Requires the `manage_ml` cluster privilege. This privilege is included in the
 `machine_learning_admin` built-in role.
 
 [[stop-trained-model-deployment-desc]]
@@ -42,8 +42,8 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-deployments]
 
 `force`::
-(Optional, Boolean) If true, the deployment is stopped even if it is referenced
-by ingest pipelines. You can't use these pipelines until you restart the model
+(Optional, Boolean) If true, the deployment is stopped even if it or one of its model aliases
+is referenced by ingest pipelines. You can't use these pipelines until you restart the model
 deployment.
 
 ////
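A minimal sketch of the force-stop path described above, using a hypothetical model ID (the endpoint is the existing stop-deployment API):

    POST _ml/trained_models/my-pytorch-model/deployment/_stop?force=true

Without `force=true`, the call now returns a 409 conflict when the deployment's model ID, or any model alias pointing at it, is still referenced by an ingest pipeline.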

+ 5 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

@@ -86,7 +86,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return builder;
         }
 
-        private final String modelId;
+        private String modelId;
         private final List<Map<String, Object>> docs;
         private final InferenceConfigUpdate update;
         private final TimeValue inferenceTimeout;
@@ -165,6 +165,10 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return inferenceTimeout == null ? DEFAULT_TIMEOUT : inferenceTimeout;
         }
 
+        public void setModelId(String modelId) {
+            this.modelId = modelId;
+        }
+
         /**
          * This is always null as we want the inference call to handle the timeout, not the tasks framework
          * @return null

+ 68 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -585,6 +585,74 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         stopDeployment(modelId, true);
     }
 
+    public void testStopWithModelAliasUsedDeploymentByIngestProcessor() throws IOException {
+        String modelId = "test_stop_model_alias_used_deployment_by_ingest_processor";
+        String modelAlias = "used_model_alias";
+        createPassThroughModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+        client().performRequest(new Request("PUT", formatted("_ml/trained_models/%s/model_aliases/%s", modelId, modelAlias)));
+
+        client().performRequest(putPipeline("my_pipeline", formatted("""
+            {
+              "processors": [
+                {
+                  "inference": {
+                    "model_id": "%s"
+                  }
+                }
+              ]
+            }""", modelAlias)));
+        ResponseException ex = expectThrows(ResponseException.class, () -> stopDeployment(modelId));
+        assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(409));
+        assertThat(
+            EntityUtils.toString(ex.getResponse().getEntity()),
+            containsString(
+                "Cannot stop deployment for model [test_stop_model_alias_used_deployment_by_ingest_processor] as it has a "
+                    + "model_alias [used_model_alias] that is still referenced"
+                    + " by ingest processors; use force to stop the deployment"
+            )
+        );
+        stopDeployment(modelId, true);
+    }
+
+    public void testInferenceProcessorWithModelAlias() throws IOException {
+        String modelId = "test_model_alias_infer";
+        String modelAlias = "pytorch_model_alias";
+        createPassThroughModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+        client().performRequest(new Request("PUT", formatted("_ml/trained_models/%s/model_aliases/%s", modelId, modelAlias)));
+
+        String source = formatted("""
+            {
+              "pipeline": {
+                "processors": [
+                  {
+                    "inference": {
+                      "model_id": "%s"
+                    }
+                  }
+                ]
+              },
+              "docs": [
+                {"_source": {"input": "my words"}}]
+            }
+            """, modelAlias);
+
+        String response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
+        assertThat(
+            response,
+            allOf(
+                containsString("\"ml\":{\"inference\":{\"predicted_value\":[[1.0,1.0]]"),
+                containsString(modelId),
+                not(containsString("warning"))
+            )
+        );
+    }
+
     public void testPipelineWithBadProcessor() throws IOException {
         String model = "deployed";
         createPassThroughModel(model);

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java

@@ -186,7 +186,7 @@ public class TransportDeleteTrainedModelAction extends AcknowledgedTransportMast
         return allReferencedModelKeys;
     }
 
-    private static List<String> getModelAliases(ClusterState clusterState, String modelId) {
+    static List<String> getModelAliases(ClusterState clusterState, String modelId) {
         final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterState);
         final List<String> modelAliases = new ArrayList<>();
         for (Map.Entry<String, ModelAliasMetadata.ModelAliasEntry> modelAliasEntry : currentMetadata.modelAliases().entrySet()) {

+ 15 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -29,12 +29,14 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
 import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
 import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
 import java.util.List;
+import java.util.Optional;
 
 import static org.elasticsearch.core.Strings.format;
 
@@ -75,14 +77,18 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
         ActionListener<InferTrainedModelDeploymentAction.Response> listener
     ) {
         TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId());
-        final String modelId = request.getModelId();
+        // Update the requests model ID if it's an alias
+        Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId()))
+            .ifPresent(request::setModelId);
         // We need to check whether there is at least an assigned task here, otherwise we cannot redirect to the
         // node running the job task.
-        TrainedModelAssignment assignment = TrainedModelAssignmentMetadata.assignmentForModelId(clusterService.state(), modelId)
-            .orElse(null);
+        TrainedModelAssignment assignment = TrainedModelAssignmentMetadata.assignmentForModelId(
+            clusterService.state(),
+            request.getModelId()
+        ).orElse(null);
         if (assignment == null) {
             // If there is no assignment, verify the model even exists so that we can provide a nicer error message
-            provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), taskId, ActionListener.wrap(config -> {
+            provider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), taskId, ActionListener.wrap(config -> {
                 if (config.getModelType() != TrainedModelType.PYTORCH) {
                     listener.onFailure(
                         ExceptionsHelper.badRequestException(
@@ -93,13 +99,13 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
                     );
                     return;
                 }
-                String message = "Trained model [" + modelId + "] is not deployed";
+                String message = "Trained model [" + request.getModelId() + "] is not deployed";
                 listener.onFailure(ExceptionsHelper.conflictStatusException(message));
             }, listener::onFailure));
             return;
         }
         if (assignment.getAssignmentState() == AssignmentState.STOPPING) {
-            String message = "Trained model [" + modelId + "] is STOPPING";
+            String message = "Trained model [" + request.getModelId() + "] is STOPPING";
             listener.onFailure(ExceptionsHelper.conflictStatusException(message));
             return;
         }
@@ -114,7 +120,9 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
             }, listener::onFailure));
         }, () -> {
             logger.trace(() -> format("[%s] model not allocated to any node [%s]", assignment.getModelId()));
-            listener.onFailure(ExceptionsHelper.conflictStatusException("Trained model [" + modelId + "] is not allocated to any nodes"));
+            listener.onFailure(
+                ExceptionsHelper.conflictStatusException("Trained model [" + request.getModelId() + "] is not allocated to any nodes")
+            );
         });
     }
 

+ 10 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
@@ -39,6 +40,7 @@ import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
 
 import java.util.Collections;
 import java.util.Map;
+import java.util.Optional;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -128,8 +130,11 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         TaskId parentTaskId,
         ActionListener<Response> listener
     ) {
-        if (isAllocatedModel(request.getModelId())) {
-            inferAgainstAllocatedModel(request, responseBuilder, parentTaskId, listener);
+        String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId()))
+            .orElse(request.getModelId());
+        if (isAllocatedModel(concreteModelId)) {
+            // It is important to use the resolved model ID here as the alias could change between transport calls.
+            inferAgainstAllocatedModel(request, concreteModelId, responseBuilder, parentTaskId, listener);
         } else {
             getModelAndInfer(request, responseBuilder, parentTaskId, (CancellableTask) task, listener);
         }
@@ -176,6 +181,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
 
     private void inferAgainstAllocatedModel(
         Request request,
+        String concreteModelId,
         Response.Builder responseBuilder,
         TaskId parentTaskId,
         ActionListener<Response> listener
@@ -191,7 +197,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
             .forEach(
                 stringObjectMap -> typedChainTaskExecutor.add(
                     chainedTask -> inferSingleDocAgainstAllocatedModel(
-                        request.getModelId(),
+                        concreteModelId,
                         request.getTimeout(),
                         request.getUpdate(),
                         stringObjectMap,
@@ -204,7 +210,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         typedChainTaskExecutor.execute(
             ActionListener.wrap(
                 inferenceResults -> listener.onResponse(
-                    responseBuilder.setInferenceResults(inferenceResults).setModelId(request.getModelId()).build()
+                    responseBuilder.setInferenceResults(inferenceResults).setModelId(concreteModelId).build()
                 ),
                 listener::onFailure
             )

+ 69 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java

@@ -37,15 +37,20 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
+import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
+import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.function.Predicate;
 
@@ -141,10 +146,6 @@ public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMa
                 listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
                 return;
             }
-            if (newModel.getModelType() == TrainedModelType.PYTORCH) {
-                listener.onFailure(ExceptionsHelper.badRequestException("model_alias is not supported on pytorch models"));
-                return;
-            }
             // if old model is null, none of these validations matter
             // we should still allow reassignment even if the old model was some how deleted and the alias still refers to it
             if (oldModel != null) {
@@ -166,6 +167,70 @@ public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMa
                     }
                 }
 
+                if (Objects.equals(newModel.getModelType(), oldModel.getModelType()) == false) {
+                    listener.onFailure(
+                        ExceptionsHelper.badRequestException(
+                            "cannot reassign model_alias [{}] to model [{}] with type [{}] from model [{}] with type [{}]",
+                            request.getModelAlias(),
+                            newModel.getModelId(),
+                            Optional.ofNullable(newModel.getModelType()).orElse(TrainedModelType.TREE_ENSEMBLE).toString(),
+                            oldModel.getModelId(),
+                            Optional.ofNullable(oldModel.getModelType()).orElse(TrainedModelType.TREE_ENSEMBLE).toString()
+                        )
+                    );
+                    return;
+                }
+
+                // If we are reassigning Pytorch models, we need to validate assignments are acceptable.
+                if (newModel.getModelType() == TrainedModelType.PYTORCH) {
+                    Optional<TrainedModelAssignment> oldAssignment = TrainedModelAssignmentMetadata.assignmentForModelId(state, oldModelId);
+                    Optional<TrainedModelAssignment> newAssignment = TrainedModelAssignmentMetadata.assignmentForModelId(
+                        state,
+                        newModel.getModelId()
+                    );
+                    // Old model is currently deployed
+                    if (oldAssignment.isPresent()) {
+                        // disallow changing the model alias from a deployed model to an undeployed model
+                        if (newAssignment.isEmpty()) {
+                            listener.onFailure(
+                                ExceptionsHelper.badRequestException(
+                                    "cannot reassign model_alias [{}] to model [{}] from model [{}] as it is not yet deployed",
+                                    request.getModelAlias(),
+                                    newModel.getModelId(),
+                                    oldModel.getModelId()
+                                )
+                            );
+                            return;
+                        } else {
+                            Optional<AllocationStatus> oldAllocationStatus = oldAssignment.map(
+                                TrainedModelAssignment::calculateAllocationStatus
+                            ).get();
+                            // Old model is deployed and its allocation status is NOT "stopping" or "starting"
+                            if (oldAllocationStatus.isPresent()
+                                && oldAllocationStatus.get()
+                                    .calculateState()
+                                    .isAnyOf(AllocationStatus.State.FULLY_ALLOCATED, AllocationStatus.State.STARTED)) {
+                                Optional<AllocationStatus> newAllocationStatus = newAssignment.map(
+                                    TrainedModelAssignment::calculateAllocationStatus
+                                ).get();
+                                if (newAllocationStatus.isEmpty()
+                                    || newAllocationStatus.get().calculateState().equals(AllocationStatus.State.STARTING)) {
+                                    listener.onFailure(
+                                        ExceptionsHelper.badRequestException(
+                                            "cannot reassign model_alias [{}] to model [{}] "
+                                                + " from model [{}] as it is not yet allocated to any nodes",
+                                            request.getModelAlias(),
+                                            newModel.getModelId(),
+                                            oldModel.getModelId()
+                                        )
+                                    );
+                                    return;
+                                }
+                            }
+                        }
+                    }
+                }
+
                 Set<String> oldInputFields = new HashSet<>(oldModel.getInput().getFieldNames());
                 Set<String> newInputFields = new HashSet<>(newModel.getInput().getFieldNames());
                 // TODO should we fail in this case???
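From a caller's perspective, the validation added above turns an invalid reassignment into a 400 response on the alias endpoint. A sketch of the rejected case, with hypothetical model IDs, where `my_model_alias` currently points at a deployed PyTorch model:

    PUT _ml/trained_models/my-undeployed-model/model_aliases/my_model_alias?reassign=true

would fail with a message of the form: cannot reassign model_alias [my_model_alias] to model [my-undeployed-model] from model [my-deployed-model] as it is not yet deployed.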

+ 27 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java

@@ -50,6 +50,7 @@ import java.util.Set;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getModelAliases;
 import static org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction.getReferencedModelKeys;
 
 /**
@@ -138,15 +139,32 @@ public class TransportStopTrainedModelDeploymentAction extends TransportTasksAct
             IngestMetadata currentIngestMetadata = state.metadata().custom(IngestMetadata.TYPE);
             Set<String> referencedModels = getReferencedModelKeys(currentIngestMetadata, ingestService);
 
-            if (request.isForce() == false && referencedModels.contains(modelId)) {
-                listener.onFailure(
-                    new ElasticsearchStatusException(
-                        "Cannot stop deployment for model [{}] as it is referenced by ingest processors; use force to stop the deployment",
-                        RestStatus.CONFLICT,
-                        modelId
-                    )
-                );
-                return;
+            if (request.isForce() == false) {
+                if (referencedModels.contains(modelId)) {
+                    listener.onFailure(
+                        new ElasticsearchStatusException(
+                            "Cannot stop deployment for model [{}] as it is referenced by ingest processors; "
+                                + "use force to stop the deployment",
+                            RestStatus.CONFLICT,
+                            modelId
+                        )
+                    );
+                    return;
+                }
+                List<String> modelAliases = getModelAliases(state, modelId);
+                Optional<String> referencedModelAlias = modelAliases.stream().filter(referencedModels::contains).findFirst();
+                if (referencedModelAlias.isPresent()) {
+                    listener.onFailure(
+                        new ElasticsearchStatusException(
+                            "Cannot stop deployment for model [{}] as it has a model_alias [{}] that is still referenced"
+                                + " by ingest processors; use force to stop the deployment",
+                            RestStatus.CONFLICT,
+                            modelId,
+                            referencedModelAlias.get()
+                        )
+                    );
+                    return;
+                }
             }
 
             // NOTE, should only run on Master node

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ModelAliasMetadata.java

@@ -227,5 +227,10 @@ public class ModelAliasMetadata implements Metadata.Custom {
         public int hashCode() {
             return Objects.hash(modelId);
         }
+
+        @Override
+        public String toString() {
+            return "ModelAliasEntry{modelId='" + modelId + "'}";
+        }
     }
 }

+ 0 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
(File diff suppressed because it is too large)


+ 0 - 20
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -998,26 +998,6 @@ setup:
   - match: { trained_model_configs.0.metadata.model_aliases.0: "regression-model" }
   - match: { trained_model_configs.0.metadata.model_aliases.1: "regression-model-again" }
 ---
-"Test put model model aliases with nlp model":
-
-  - do:
-      ml.put_trained_model:
-        model_id: my-nlp-model
-        body: >
-          {
-            "description": "distilbert-base-uncased-finetuned-sst-2-english.pt",
-            "model_type": "pytorch",
-            "inference_config": {
-              "ner": {
-              }
-            }
-          }
-  - do:
-      catch: /model_alias is not supported on pytorch models/
-      ml.put_trained_model_alias:
-        model_alias: "nlp-model"
-        model_id: "my-nlp-model"
----
 "Test update model alias with model id referring to missing model":
   - do:
       catch: missing

Some files were not shown because too many files changed in this diff