Browse Source

[ML] adds new trained model alias API to simplify trained model updates and deployments (#68922)

A `model_alias` allows trained models to be referred to by a user-defined moniker.

This not only improves the readability and simplicity of numerous API calls, but it allows for simpler deployment and upgrade procedures for trained models. 

Previously, if you referenced a model ID directly within an ingest pipeline and later had a new model that performed better than the referenced one, you had to update the pipeline itself. If the model was used in numerous pipelines, ALL of those pipelines had to be updated.

When using a `model_alias` in an ingest pipeline, only that `model_alias` needs to be updated. Then, the underlying referenced model will change in place for all ingest pipelines automatically. 

An additional benefit is that the referenced model is not changed until the new model is fully loaded into cache, so throughput is not hampered by changing models.
Benjamin Trent 4 years ago
parent
commit
26eef892df
40 changed files with 1962 additions and 219 deletions
  1. 1 1
      docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc
  2. 1 1
      docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc
  3. 1 0
      docs/reference/ml/df-analytics/apis/index.asciidoc
  4. 2 1
      docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc
  5. 89 0
      docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc
  6. 4 0
      docs/reference/ml/ml-shared.asciidoc
  7. 6 0
      server/src/main/java/org/elasticsearch/common/util/set/Sets.java
  8. 7 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java
  9. 26 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java
  10. 119 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java
  11. 26 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
  12. 5 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
  13. 1 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java
  14. 68 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java
  15. 25 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java
  16. 4 0
      x-pack/plugin/ml/qa/ml-with-security/build.gradle
  17. 84 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
  18. 3 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java
  19. 204 56
      x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java
  20. 12 4
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java
  21. 17 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  22. 59 6
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java
  23. 11 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java
  24. 35 12
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java
  25. 3 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
  26. 8 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
  27. 209 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java
  28. 222 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ModelAliasMetadata.java
  29. 6 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  30. 232 53
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
  31. 77 13
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
  32. 57 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java
  33. 3 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java
  34. 44 6
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java
  35. 140 9
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java
  36. 1 0
      x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
  37. 1 1
      x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java
  38. 40 0
      x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json
  39. 23 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml
  40. 86 24
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

+ 1 - 1
docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc

@@ -48,7 +48,7 @@ request by using a comma-separated list of model IDs or a wildcard expression.
 
 `<model_id>`::
 (Optional, string)
-include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]
 
 
 [[ml-get-trained-models-stats-query-params]]

+ 1 - 1
docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc

@@ -50,7 +50,7 @@ using a comma-separated list of model IDs or a wildcard expression.
 
 `<model_id>`::
 (Optional, string)
-include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias]
 
 
 [[ml-get-trained-models-query-params]]

+ 1 - 0
docs/reference/ml/df-analytics/apis/index.asciidoc

@@ -2,6 +2,7 @@ include::ml-df-analytics-apis.asciidoc[leveloffset=+1]
 //CREATE
 include::put-dfanalytics.asciidoc[leveloffset=+2]
 include::put-trained-models.asciidoc[leveloffset=+2]
+include::put-trained-models-aliases.asciidoc[leveloffset=+2]
 //UPDATE
 include::update-dfanalytics.asciidoc[leveloffset=+2]
 //DELETE

+ 2 - 1
docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc

@@ -22,8 +22,9 @@ You can use the following APIs to perform {infer} operations.
 * <<get-trained-models>>
 * <<get-trained-models-stats>>
 * <<delete-trained-models>>
+* <<put-trained-models-aliases>>
 
-You can deploy a trained model to make predictions in an ingest pipeline or in 
+You can deploy a trained model to make predictions in an ingest pipeline or in
 an aggregation. Refer to the following documentation to learn more.
 
 * <<inference-processor,{infer-cap} processor>>

+ 89 - 0
docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc

@@ -0,0 +1,89 @@
+[role="xpack"]
+[testenv="platinum"]
+[[put-trained-models-aliases]]
+= Put Trained Models Aliases API
+[subs="attributes"]
+++++
+<titleabbrev>Put Trained Models Aliases</titleabbrev>
+++++
+
+Creates a trained model alias. A model alias can be used instead of the trained model ID
+when referencing the model in the stack. Model aliases must be unique. A trained model can have
+more than one model alias referring to it, but a model alias can only refer to a single trained model.
+
+beta::[]
+
+[[ml-put-trained-models-aliases-request]]
+== {api-request-title}
+
+`PUT _ml/trained_models/<model_id>/model_aliases/<model_alias>`
+
+
+[[ml-put-trained-models-aliases-prereq]]
+== {api-prereq-title}
+
+If the {es} {security-features} are enabled, you must have the following
+built-in roles and privileges:
+
+* `machine_learning_admin`
+
+For more information, see <<built-in-roles>>, <<security-privileges>>, and
+{ml-docs-setup-privileges}.
+
+[[ml-put-trained-models-aliases-desc]]
+== {api-description-title}
+
+This API creates a new model alias to refer to trained models, or updates an existing
+trained model's alias.
+
+When updating an existing model alias to a new model ID, this API returns an error if the models
+are of different inference types. For example, if you attempt to reassign the model alias
+`flights-delay-prediction` from a regression model to a classification model, the API returns an error.
+
+The API will return a warning if there are very few input fields in common between the old
+and new models for the model alias.
+
+[[ml-put-trained-models-aliases-path-params]]
+== {api-path-parms-title}
+
+`model_id`::
+(Required, string)
+The trained model ID to which the model alias should refer.
+
+`model_alias`::
+(Required, string)
+The model alias to create or update. The model_alias cannot end in numbers.
+
+[[ml-put-trained-models-aliases-query-params]]
+== {api-query-parms-title}
+
+`reassign`::
+(Optional, boolean)
+Specifies whether the `model_alias` is reassigned to the provided `model_id` if it is already
+assigned to a model. Defaults to `false`. The API returns an error if the `model_alias`
+is already assigned to a model and this parameter is `false`.
+
+[[ml-put-trained-models-aliases-example]]
+== {api-examples-title}
+
+[[ml-put-trained-models-aliases-example-new-alias]]
+=== Creating a new model alias
+
+The following example shows how to create a new model alias for a trained model ID.
+
+[source,console]
+--------------------------------------------------
+PUT _ml/trained_models/flight-delay-prediction-1574775339910/model_aliases/flight_delay_model
+--------------------------------------------------
+// TEST[skip:setup kibana sample data]
+
+[[ml-put-trained-models-aliases-example-put-alias]]
+=== Updating an existing model alias
+
+The following example shows how to reassign an existing model alias to a different trained model ID.
+
+[source,console]
+--------------------------------------------------
+PUT _ml/trained_models/flight-delay-prediction-1580004349800/model_aliases/flight_delay_model?reassign=true
+--------------------------------------------------
+// TEST[skip:setup kibana sample data]

+ 4 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -1149,6 +1149,10 @@ tag::model-id[]
 The unique identifier of the trained model.
 end::model-id[]
 
+tag::model-id-or-alias[]
+The unique identifier of the trained model or a model alias.
+end::model-id-or-alias[]
+
 tag::model-memory-limit[]
 The approximate maximum amount of memory resources that are required for
 analytical processing. Once this limit is approached, data pruning becomes

+ 6 - 0
server/src/main/java/org/elasticsearch/common/util/set/Sets.java

@@ -62,6 +62,12 @@ public final class Sets {
         return left.stream().noneMatch(right::contains);
     }
 
+    public static <T> boolean haveNonEmptyIntersection(Set<T> left, Set<T> right) {
+        Objects.requireNonNull(left);
+        Objects.requireNonNull(right);
+        return left.stream().anyMatch(right::contains);
+    }
+
     /**
      * The relative complement, or difference, of the specified left and right set. Namely, the resulting set contains all the elements that
      * are in the left set but not in the right set. Neither input is mutated by this operation, an entirely new set is returned.

+ 7 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

@@ -171,7 +171,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
         public static class Builder {
 
             private long totalModelCount;
-            private Set<String> expandedIds;
+            private Map<String, Set<String>> expandedIdsWithAliases;
             private Map<String, IngestStats> ingestStatsMap;
             private Map<String, InferenceStats> inferenceStatsMap;
 
@@ -180,13 +180,13 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 return this;
             }
 
-            public Builder setExpandedIds(Set<String> expandedIds) {
-                this.expandedIds = expandedIds;
+            public Builder setExpandedIdsWithAliases(Map<String, Set<String>> expandedIdsWithAliases) {
+                this.expandedIdsWithAliases = expandedIdsWithAliases;
                 return this;
             }
 
-            public Set<String> getExpandedIds() {
-                return this.expandedIds;
+            public Map<String, Set<String>> getExpandedIdsWithAliases() {
+                return this.expandedIdsWithAliases;
             }
 
             public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
@@ -200,8 +200,8 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             }
 
             public Response build() {
-                List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
-                expandedIds.forEach(id -> {
+                List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size());
+                expandedIdsWithAliases.keySet().forEach(id -> {
                     IngestStats ingestStats = ingestStatsMap.get(id);
                     InferenceStats inferenceStats = inferenceStatsMap.get(id);
                     trainedModelStats.add(new TrainedModelStats(

+ 26 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java

@@ -143,18 +143,25 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
     public static class Response extends ActionResponse {
 
         private final List<InferenceResults> inferenceResults;
+        private final String modelId;
         private final boolean isLicensed;
 
-        public Response(List<InferenceResults> inferenceResults, boolean isLicensed) {
+        public Response(List<InferenceResults> inferenceResults, String modelId, boolean isLicensed) {
             super();
             this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
             this.isLicensed = isLicensed;
+            this.modelId = modelId;
         }
 
         public Response(StreamInput in) throws IOException {
             super(in);
             this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
             this.isLicensed = in.readBoolean();
+            if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+                this.modelId = in.readOptionalString();
+            } else {
+                this.modelId = null;
+            }
         }
 
         public List<InferenceResults> getInferenceResults() {
@@ -165,10 +172,17 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
             return isLicensed;
         }
 
+        public String getModelId() {
+            return modelId;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeNamedWriteableList(inferenceResults);
             out.writeBoolean(isLicensed);
+            if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+                out.writeOptionalString(modelId);
+            }
         }
 
         @Override
@@ -176,12 +190,14 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
             InternalInferModelAction.Response that = (InternalInferModelAction.Response) o;
-            return isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults);
+            return isLicensed == that.isLicensed
+                && Objects.equals(inferenceResults, that.inferenceResults)
+                && Objects.equals(modelId, that.modelId);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(inferenceResults, isLicensed);
+            return Objects.hash(inferenceResults, isLicensed, modelId);
         }
 
         public static Builder builder() {
@@ -190,6 +206,7 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
 
         public static class Builder {
             private List<InferenceResults> inferenceResults;
+            private String modelId;
             private boolean isLicensed;
 
             public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
@@ -202,8 +219,13 @@ public class InternalInferModelAction extends ActionType<InternalInferModelActio
                 return this;
             }
 
+            public Builder setModelId(String modelId) {
+                this.modelId = modelId;
+                return this;
+            }
+
             public Response build() {
-                return new Response(inferenceResults, isLicensed);
+                return new Response(inferenceResults, modelId, isLicensed);
             }
         }
 

+ 119 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java

@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Objects;
+import java.util.regex.Pattern;
+
+import static org.elasticsearch.action.ValidateActions.addValidationError;
+import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INVALID_MODEL_ALIAS;
+
+public class PutTrainedModelAliasAction extends ActionType<AcknowledgedResponse> {
+
+    // NOTE this is similar to our valid ID check, except that a model_alias must end in a letter (never a digit), even when it
+    // is a single character. This is to protect our automatic model naming conventions from hitting weird model_alias conflicts
+    private static final Pattern VALID_MODEL_ALIAS_CHAR_PATTERN = Pattern.compile("(?:[a-z]|[a-z0-9][a-z0-9_\\-\\.]*[a-z])");
+
+    public static final PutTrainedModelAliasAction INSTANCE = new PutTrainedModelAliasAction();
+    public static final String NAME = "cluster:admin/xpack/ml/inference/model_aliases/put";
+
+    private PutTrainedModelAliasAction() {
+        super(NAME, AcknowledgedResponse::readFrom);
+    }
+
+    public static class Request extends AcknowledgedRequest<Request> {
+
+        public static final String MODEL_ALIAS = "model_alias";
+        public static final String REASSIGN = "reassign";
+
+        private final String modelAlias;
+        private final String modelId;
+        private final boolean reassign;
+
+        public Request(String modelAlias, String modelId, boolean reassign) {
+            this.modelAlias = ExceptionsHelper.requireNonNull(modelAlias, MODEL_ALIAS);
+            this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
+            this.reassign = reassign;
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.modelAlias = in.readString();
+            this.modelId = in.readString();
+            this.reassign = in.readBoolean();
+        }
+
+        public String getModelAlias() {
+            return modelAlias;
+        }
+
+        public String getModelId() {
+            return modelId;
+        }
+
+        public boolean isReassign() {
+            return reassign;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws  IOException {
+            super.writeTo(out);
+            out.writeString(modelAlias);
+            out.writeString(modelId);
+            out.writeBoolean(reassign);
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            ActionRequestValidationException validationException = null;
+            if (modelAlias.equals(modelId)) {
+                validationException = addValidationError(
+                    String.format(
+                        Locale.ROOT,
+                        "model_alias [%s] cannot equal model_id [%s]",
+                        modelAlias,
+                        modelId
+                    ),
+                    validationException
+                );
+            }
+            if (VALID_MODEL_ALIAS_CHAR_PATTERN.matcher(modelAlias).matches() == false) {
+                validationException = addValidationError(Messages.getMessage(INVALID_MODEL_ALIAS, modelAlias), validationException);
+            }
+            return validationException;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(modelAlias, request.modelAlias)
+                && Objects.equals(modelId, request.modelId)
+                && Objects.equals(reassign, request.reassign);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(modelAlias, modelId, reassign);
+        }
+
+    }
+}

+ 26 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -43,6 +43,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.action.ValidateActions.addValidationError;
@@ -58,6 +59,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
     public static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
     public static final String HYPERPARAMETERS = "hyperparameters";
+    public static final String MODEL_ALIASES = "model_aliases";
 
     private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
 
@@ -471,34 +473,41 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             if (totalFeatureImportance == null) {
                 return this;
             }
-            if (this.metadata == null) {
-                this.metadata = new HashMap<>();
-            }
-            this.metadata.put(TOTAL_FEATURE_IMPORTANCE,
-                totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()));
-            return this;
+            return addToMetadata(
+                TOTAL_FEATURE_IMPORTANCE,
+                totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList())
+            );
         }
 
         public Builder setBaselineFeatureImportance(FeatureImportanceBaseline featureImportanceBaseline) {
             if (featureImportanceBaseline == null) {
                 return this;
             }
-            if (this.metadata == null) {
-                this.metadata = new HashMap<>();
-            }
-            this.metadata.put(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap());
-            return this;
+            return addToMetadata(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap());
         }
 
         public Builder setHyperparameters(List<Hyperparameters> hyperparameters) {
             if (hyperparameters == null) {
                 return this;
             }
+            return addToMetadata(
+                HYPERPARAMETERS,
+                hyperparameters.stream().map(Hyperparameters::asMap).collect(Collectors.toList())
+            );
+        }
+
+        public Builder setModelAliases(Set<String> modelAliases) {
+            if (modelAliases == null || modelAliases.isEmpty()) {
+                return this;
+            }
+            return addToMetadata(MODEL_ALIASES, modelAliases.stream().sorted().collect(Collectors.toList()));
+        }
+
+        private Builder addToMetadata(String fieldName, Object value) {
             if (this.metadata == null) {
                 this.metadata = new HashMap<>();
             }
-            this.metadata.put(HYPERPARAMETERS,
-            hyperparameters.stream().map(Hyperparameters::asMap).collect(Collectors.toList()));
+            this.metadata.put(fieldName, value);
             return this;
         }
 
@@ -663,6 +672,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                         metadata.get(TOTAL_FEATURE_IMPORTANCE),
                         METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE,
                         validationException);
+                    validationException = checkIllegalSetting(
+                        metadata.get(MODEL_ALIASES),
+                        METADATA.getPreferredName() + "." + MODEL_ALIASES,
+                        validationException);
                 }
             }
 

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -120,6 +120,11 @@ public final class Messages {
     public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids.";
     public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags.";
 
+    public static final String INVALID_MODEL_ALIAS = "Invalid model_alias; ''{0}'' can contain lowercase alphanumeric (a-z and 0-9), " +
+        "hyphens or underscores; must start with alphanumeric and cannot end with numbers";
+    public static final String TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY =
+        "The input fields for new model [{0}] and for old model [{1}] differ significantly, model results may change drastically.";
+
     public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
     public static final String JOB_AUDIT_CREATED = "Job created";
     public static final String JOB_AUDIT_UPDATED = "Job updated: {0}";

+ 1 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java

@@ -29,6 +29,7 @@ public class InternalInferModelActionResponseTests extends AbstractWireSerializi
             Stream.generate(() -> randomInferenceResult(resultType))
             .limit(randomIntBetween(0, 10))
             .collect(Collectors.toList()),
+            randomAlphaOfLength(10),
             randomBoolean());
     }
 

+ 68 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java

@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction.Request;
+import org.junit.Before;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+
+public class PutTrainedModelAliasActionRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    private String modelAlias;
+
+    @Before
+    public void setupModelAlias() {
+        modelAlias = randomAlphaOfLength(10);
+    }
+
+    @Override
+    protected Request createTestInstance() {
+        return new Request(
+            modelAlias,
+            randomAlphaOfLength(10),
+            randomBoolean()
+        );
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return Request::new;
+    }
+
+    public void testCtor() {
+        expectThrows(Exception.class, () -> new Request(null, randomAlphaOfLength(10), randomBoolean()));
+        expectThrows(Exception.class, () -> new Request(randomAlphaOfLength(10), null, randomBoolean()));
+    }
+
+    public void testValidate() {
+
+        { // model_alias equal to  model Id
+            ActionRequestValidationException ex = new Request("foo", "foo", randomBoolean()).validate();
+            assertThat(ex, not(nullValue()));
+            assertThat(ex.getMessage(), containsString("model_alias [foo] cannot equal model_id [foo]"));
+        }
+        { // model_alias cannot end in numbers
+            String modelAlias = randomAlphaOfLength(10) + randomIntBetween(0, Integer.MAX_VALUE);
+            ActionRequestValidationException ex = new Request(modelAlias, "foo", randomBoolean()).validate();
+            assertThat(ex, not(nullValue()));
+            assertThat(
+                ex.getMessage(),
+                containsString(
+                    "can contain lowercase alphanumeric (a-z and 0-9), hyphens or underscores; "
+                        + "must start with alphanumeric and cannot end with numbers"
+                )
+            );
+        }
+    }
+
+}

+ 25 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java

@@ -14,11 +14,14 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.test.rest.ESRestTestCase;
 
 import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 public class MlRestTestStateCleaner {
 
+    private static final Set<String> NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1");
     private final Logger logger;
     private final RestClient adminClient;
 
@@ -28,12 +31,34 @@ public class MlRestTestStateCleaner {
     }
 
     public void clearMlMetadata() throws IOException {
+        deleteAllTrainedModels();
         deleteAllDatafeeds();
         deleteAllJobs();
         deleteAllDataFrameAnalytics();
         // indices will be deleted by the ESRestTestCase class
     }
 
+    @SuppressWarnings("unchecked")
+    private void deleteAllTrainedModels() throws IOException {
+        final Request getTrainedModels = new Request("GET", "/_ml/trained_models");
+        getTrainedModels.addParameter("size", "10000");
+        final Response trainedModelsResponse = adminClient.performRequest(getTrainedModels);
+        final List<Map<String, Object>> models = (List<Map<String, Object>>) XContentMapValues.extractValue(
+            "trained_model_configs",
+            ESRestTestCase.entityAsMap(trainedModelsResponse)
+        );
+        if (models == null || models.isEmpty()) {
+            return;
+        }
+        for (Map<String, Object> model : models) {
+            String modelId = (String) model.get("model_id");
+            if (NOT_DELETED_TRAINED_MODELS.contains(modelId)) {
+                continue;
+            }
+            adminClient.performRequest(new Request("DELETE", "/_ml/trained_models/" + modelId));
+        }
+    }
+
     @SuppressWarnings("unchecked")
     private void deleteAllDatafeeds() throws IOException {
         final Request datafeedsRequest = new Request("GET", "/_ml/datafeeds");

+ 4 - 0
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -145,6 +145,10 @@ tasks.named("yamlRestTest").configure {
     'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index',
     'ml/inference_crud/Test put model with empty input.field_names',
     'ml/inference_crud/Test PUT model where target type and inference config mismatch',
+    'ml/inference_crud/Test update model alias with model id referring to missing model',
+    'ml/inference_crud/Test update model alias with bad alias',
+    'ml/inference_crud/Test update model alias where alias exists but old model id is different inference type',
+    'ml/inference_crud/Test update model alias where alias exists but reassign is false',
     'ml/inference_processor/Test create processor with missing mandatory fields',
     'ml/inference_stats_crud/Test get stats given missing trained model',
     'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',

+ 84 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

@@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
 
 /**
  * This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems
@@ -201,6 +202,84 @@ public class InferenceIngestIT extends ESRestTestCase {
         }, 30, TimeUnit.SECONDS);
     }
 
+    public void testPipelineIngestWithModelAliases() throws Exception {
+        String regressionModelId = "test_regression_1";
+        putModel(regressionModelId, REGRESSION_CONFIG);
+        String regressionModelId2 = "test_regression_2";
+        putModel(regressionModelId2, REGRESSION_CONFIG);
+        String modelAlias = "test_regression";
+        putModelAlias(modelAlias, regressionModelId);
+
+        client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(modelAlias, "regression")));
+
+        for (int i = 0; i < 10; i++) {
+            client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
+        }
+        putModelAlias(modelAlias, regressionModelId2);
+        // Use assertBusy because loading the new model and then switching the model alias can take time
+        assertBusy(() -> {
+            String source = "{\n" +
+                "  \"docs\": [\n" +
+                "    {\"_source\": {\n" +
+                "      \"col1\": \"female\",\n" +
+                "      \"col2\": \"M\",\n" +
+                "      \"col3\": \"none\",\n" +
+                "      \"col4\": 10\n" +
+                "    }}]\n" +
+                "}";
+            Request request = new Request("POST", "_ingest/pipeline/simple_regression_pipeline/_simulate");
+            request.setJsonEntity(source);
+            Response response = client().performRequest(request);
+            String responseString = EntityUtils.toString(response.getEntity());
+            assertThat(responseString, containsString("\"model_id\":\"test_regression_2\""));
+        }, 30, TimeUnit.SECONDS);
+
+        for (int i = 0; i < 10; i++) {
+            client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
+        }
+
+        client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
+
+        client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));
+
+        Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
+            QueryBuilders.boolQuery()
+                .filter(
+                    QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
+        // Verify we have 20 documents that contain a predicted value for regression
+        assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
+
+
+        // Since this is a multi-node cluster, the model could be loaded and cached on one ingest node but not the other.
+        // Consequently, we should only verify that some of the documents refer to the first regression model
+        // and some refer to the second.
+        searchResponse = client().performRequest(searchRequest("index_for_inference_test",
+            QueryBuilders.boolQuery()
+                .filter(
+                    QueryBuilders.termQuery("ml.inference.regression.model_id.keyword", regressionModelId))));
+        assertThat(EntityUtils.toString(searchResponse.getEntity()), not(containsString("\"value\":0")));
+
+        searchResponse = client().performRequest(searchRequest("index_for_inference_test",
+            QueryBuilders.boolQuery()
+                .filter(
+                    QueryBuilders.termQuery("ml.inference.regression.model_id.keyword", regressionModelId2))));
+        assertThat(EntityUtils.toString(searchResponse.getEntity()), not(containsString("\"value\":0")));
+
+        assertBusy(() -> {
+            try (XContentParser parser = createParser(JsonXContent.jsonXContent, client().performRequest(new Request("GET",
+                "_ml/trained_models/" + modelAlias + "/_stats")).getEntity().getContent())) {
+                GetTrainedModelsStatsResponse response = GetTrainedModelsStatsResponse.fromXContent(parser);
+                assertThat(response.toString(), response.getTrainedModelStats(), hasSize(1));
+                TrainedModelStats trainedModelStats = response.getTrainedModelStats().get(0);
+                assertThat(trainedModelStats.getModelId(), equalTo(regressionModelId2));
+                assertThat(trainedModelStats.getInferenceStats(), is(notNullValue()));
+            } catch (ResponseException ex) {
+                // This could just mean shard failures; fail so that assertBusy retries.
+                fail(ex.getMessage());
+            }
+        });
+    }
+
     public void assertStatsWithCacheMisses(String modelId, long inferenceCount) throws IOException {
         Response statsResponse = client().performRequest(new Request("GET",
             "_ml/trained_models/" + modelId + "/_stats"));
@@ -629,4 +708,9 @@ public class InferenceIngestIT extends ESRestTestCase {
         client().performRequest(request);
     }
 
+    private void putModelAlias(String modelAlias, String newModel) throws IOException {
+        Request request = new Request("PUT", "_ml/trained_models/" + newModel + "/model_aliases/" + modelAlias + "?reassign=true");
+        client().performRequest(request);
+    }
+
 }

+ 3 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java

@@ -79,6 +79,7 @@ import org.elasticsearch.xpack.datastreams.DataStreamsPlugin;
 import org.elasticsearch.xpack.ilm.IndexLifecycle;
 import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
 import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -257,6 +258,8 @@ abstract class MlNativeIntegTestCase extends ESIntegTestCase {
         if (cluster() != null && cluster().size() > 0) {
             List<NamedWriteableRegistry.Entry> entries = new ArrayList<>(ClusterModule.getNamedWriteables());
             entries.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables());
+            entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new));
+            entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom));
             entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new));
             entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, IndexLifecycleMetadata.TYPE, IndexLifecycleMetadata::new));
             entries.add(new NamedWriteableRegistry.Entry(LifecycleType.class, TimeseriesLifecycleType.TYPE,

+ 204 - 56
x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java

@@ -13,19 +13,26 @@ import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.test.rest.ESRestTestCase;
+import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
 
 public class InferenceProcessorIT extends ESRestTestCase {
+    private static final Set<String> NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1");
 
     private static final String MODEL_ID = "a-perfect-regression-model";
+    private final Set<String> createdPipelines = new HashSet<>();
 
     @Before
     public void enableLogging() throws IOException {
@@ -36,8 +43,39 @@ public class InferenceProcessorIT extends ESRestTestCase {
         assertThat(client().performRequest(setTrace).getStatusLine().getStatusCode(), equalTo(200));
     }
 
-    private void putRegressionModel() throws IOException {
+    @SuppressWarnings("unchecked")
+    @After
+    public void cleanup() throws Exception {
+        for (String createdPipeline : createdPipelines) {
+            deletePipeline(createdPipeline);
+        }
+        createdPipelines.clear();
+        waitForStats();
+        final Request getTrainedModels = new Request("GET", "/_ml/trained_models");
+        getTrainedModels.addParameter("size", "10000");
+        final Response trainedModelsResponse = adminClient().performRequest(getTrainedModels);
+        final List<Map<String, Object>> models = (List<Map<String, Object>>) XContentMapValues.extractValue(
+            "trained_model_configs",
+            ESRestTestCase.entityAsMap(trainedModelsResponse)
+        );
+        if (models == null || models.isEmpty()) {
+            return;
+        }
+        for (Map<String, Object> model : models) {
+            String modelId = (String) model.get("model_id");
+            if (NOT_DELETED_TRAINED_MODELS.contains(modelId)) {
+                continue;
+            }
+            adminClient().performRequest(new Request("DELETE", "/_ml/trained_models/" + modelId));
+        }
+    }
+
+    private void putModelAlias(String modelAlias, String newModel) throws IOException {
+        Request request = new Request("PUT", "_ml/trained_models/" + newModel + "/model_aliases/" + modelAlias + "?reassign=true");
+        client().performRequest(request);
+    }
 
+    private void putRegressionModel() throws IOException {
         Request model = new Request("PUT", "_ml/trained_models/" + MODEL_ID);
         model.setJsonEntity(
             "  {\n" +
@@ -66,24 +104,9 @@ public class InferenceProcessorIT extends ESRestTestCase {
     @SuppressWarnings("unchecked")
     public void testCreateAndDeletePipelineWithInferenceProcessor() throws Exception {
         putRegressionModel();
-
-        Request putPipeline = new Request("PUT", "_ingest/pipeline/regression-model-pipeline");
-        putPipeline.setJsonEntity(
-                "          {\n" +
-                "            \"processors\": [\n" +
-                "              {\n" +
-                "                \"inference\" : {\n" +
-                "                  \"model_id\" : \"" + MODEL_ID + "\",\n" +
-                "                  \"inference_config\": {\"regression\": {}},\n" +
-                "                  \"target_field\": \"regression_field\",\n" +
-                "                  \"field_map\": {}\n" +
-                "                }\n" +
-                "              }\n" +
-                "            ]\n" +
-                "          }"
-        );
-
-        assertThat(client().performRequest(putPipeline).getStatusLine().getStatusCode(), equalTo(200));
+        String pipelineId = "regression-model-pipeline";
+        createdPipelines.add(pipelineId);
+        putPipeline(MODEL_ID, pipelineId);
 
         Map<String, Object> statsAsMap = getStats();
         List<Integer> pipelineCount =
@@ -100,8 +123,8 @@ public class InferenceProcessorIT extends ESRestTestCase {
         // using the model will ensure it is loaded and stats will be written before it is deleted
         infer("regression-model-pipeline");
 
-        Request deletePipeline = new Request("DELETE", "_ingest/pipeline/regression-model-pipeline");
-        assertThat(client().performRequest(deletePipeline).getStatusLine().getStatusCode(), equalTo(200));
+        deletePipeline(pipelineId);
+        createdPipelines.remove(pipelineId);
 
         // check stats are updated
         assertBusy(() -> {
@@ -129,9 +152,100 @@ public class InferenceProcessorIT extends ESRestTestCase {
         });
     }
 
+    @SuppressWarnings("unchecked")
+    public void testCreateAndDeletePipelineWithInferenceProcessorByName() throws Exception {
+        putRegressionModel();
+
+        putModelAlias("regression_first", MODEL_ID);
+        putModelAlias("regression_second", MODEL_ID);
+        createdPipelines.add("first_pipeline");
+        putPipeline("regression_first", "first_pipeline");
+        createdPipelines.add("second_pipeline");
+        putPipeline("regression_second", "second_pipeline");
+
+        Map<String, Object> statsAsMap = getStats();
+        List<Integer> pipelineCount =
+            (List<Integer>)XContentMapValues.extractValue("trained_model_stats.pipeline_count", statsAsMap);
+        assertThat(pipelineCount.get(0), equalTo(2));
+
+        List<Map<String, Object>> counts =
+            (List<Map<String, Object>>)XContentMapValues.extractValue("trained_model_stats.ingest.total", statsAsMap);
+        assertThat(counts.get(0).get("count"), equalTo(0));
+        assertThat(counts.get(0).get("time_in_millis"), equalTo(0));
+        assertThat(counts.get(0).get("current"), equalTo(0));
+        assertThat(counts.get(0).get("failed"), equalTo(0));
+
+        // using the model will ensure it is loaded and stats will be written before it is deleted
+        infer("first_pipeline");
+        deletePipeline("first_pipeline");
+        createdPipelines.remove("first_pipeline");
+
+        infer("second_pipeline");
+        deletePipeline("second_pipeline");
+        createdPipelines.remove("second_pipeline");
+
+        // check stats are updated
+        assertBusy(() -> {
+            Map<String, Object> updatedStatsMap = null;
+            try {
+                updatedStatsMap = getStats();
+            } catch (ResponseException e) {
+                // the search may fail because the index is not ready yet in which case retry
+                if (e.getMessage().contains("search_phase_execution_exception")) {
+                    fail("search failed- retry");
+                } else {
+                    throw e;
+                }
+            }
+
+            List<Integer> updatedPipelineCount =
+                (List<Integer>) XContentMapValues.extractValue("trained_model_stats.pipeline_count", updatedStatsMap);
+            assertThat(updatedPipelineCount.get(0), equalTo(0));
+
+            List<Map<String, Object>> inferenceStats =
+                (List<Map<String, Object>>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap);
+            assertNotNull(inferenceStats);
+            assertThat(inferenceStats, hasSize(1));
+            assertThat(inferenceStats.toString(), inferenceStats.get(0).get("inference_count"), equalTo(2));
+        });
+    }
+
+    public void testDeleteModelWhileAliasReferencedByPipeline() throws Exception {
+        putRegressionModel();
+        putModelAlias("regression_first", MODEL_ID);
+        createdPipelines.add("first_pipeline");
+        putPipeline("regression_first", "first_pipeline");
+        Exception ex = expectThrows(Exception.class,
+            () -> client().performRequest(new Request("DELETE", "_ml/trained_models/" + MODEL_ID)));
+        assertThat(ex.getMessage(),
+            containsString("Cannot delete model ["
+                + MODEL_ID
+                + "] as it has a model_alias [regression_first] that is still referenced by ingest processors"));
+        infer("first_pipeline");
+        deletePipeline("first_pipeline");
+        waitForStats();
+    }
+
+    public void testDeleteModelWhileReferencedByPipeline() throws Exception {
+        putRegressionModel();
+        createdPipelines.add("first_pipeline");
+        putPipeline(MODEL_ID, "first_pipeline");
+        Exception ex = expectThrows(Exception.class,
+            () -> client().performRequest(new Request("DELETE", "_ml/trained_models/" + MODEL_ID)));
+        assertThat(ex.getMessage(),
+            containsString("Cannot delete model ["
+                + MODEL_ID
+                + "] as it is still referenced by ingest processors"));
+        infer("first_pipeline");
+        deletePipeline("first_pipeline");
+        waitForStats();
+    }
+
+    @SuppressWarnings("unchecked")
     public void testCreateProcessorWithDeprecatedFields() throws Exception {
         putRegressionModel();
 
+        createdPipelines.add("regression-model-deprecated-pipeline");
         Request putPipeline = new Request("PUT", "_ingest/pipeline/regression-model-deprecated-pipeline");
         putPipeline.setJsonEntity(
                 "{\n" +
@@ -155,14 +269,35 @@ public class InferenceProcessorIT extends ESRestTestCase {
         // using the model will ensure it is loaded and stats will be written before it is deleted
         infer("regression-model-deprecated-pipeline");
 
-        Request deletePipeline = new Request("DELETE", "_ingest/pipeline/regression-model-deprecated-pipeline");
-        Response deleteResponse = client().performRequest(deletePipeline);
-        assertThat(deleteResponse.getStatusLine().getStatusCode(), equalTo(200));
+        deletePipeline("regression-model-deprecated-pipeline");
+        createdPipelines.remove("regression-model-deprecated-pipeline");
+        waitForStats();
+        assertBusy(() -> {
+            Map<String, Object> updatedStatsMap = null;
+            try {
+                updatedStatsMap = getStats();
+            } catch (ResponseException e) {
+                // the search may fail because the index is not ready yet in which case retry
+                if (e.getMessage().contains("search_phase_execution_exception")) {
+                    fail("search failed- retry");
+                } else {
+                    throw e;
+                }
+            }
+
+            List<Integer> updatedPipelineCount =
+                (List<Integer>) XContentMapValues.extractValue("trained_model_stats.pipeline_count", updatedStatsMap);
+            assertThat(updatedPipelineCount.get(0), equalTo(0));
 
-        waitForStatsDoc();
+            List<Map<String, Object>> inferenceStats =
+                (List<Map<String, Object>>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap);
+            assertNotNull(inferenceStats);
+            assertThat(inferenceStats, hasSize(1));
+            assertThat(inferenceStats.get(0).get("inference_count"), equalTo(1));
+        });
     }
 
-    public void infer(String pipelineId) throws IOException {
+    private void infer(String pipelineId) throws IOException {
         Request putDoc = new Request("POST", "any_index/_doc?pipeline=" + pipelineId);
         putDoc.setJsonEntity("{\"field1\": 1, \"field2\": 2}");
 
@@ -170,43 +305,56 @@ public class InferenceProcessorIT extends ESRestTestCase {
         assertThat(response.getStatusLine().getStatusCode(), equalTo(201));
     }
 
-    @SuppressWarnings("unchecked")
-    public void waitForStatsDoc() throws Exception {
-        assertBusy( () -> {
-            Request searchForStats = new Request("GET", ".ml-stats-*/_search?rest_total_hits_as_int");
-            searchForStats.setJsonEntity(
-                    "{\n" +
-                    "  \"query\": {\n" +
-                    "    \"bool\": {\n" +
-                    "      \"filter\": [\n" +
-                    "        {\n" +
-                    "          \"term\": {\n" +
-                    "            \"type\": \"inference_stats\"\n" +
-                    "          }\n" +
-                    "        },\n" +
-                    "        {\n" +
-                    "          \"term\": {\n" +
-                    "            \"model_id\": \"" + MODEL_ID + "\"\n" +
-                    "          }\n" +
-                    "        }\n" +
-                    "      ]\n" +
-                    "    }\n" +
-                    "  }\n" +
-                    "}"
-            );
+    private void putPipeline(String modelId, String pipelineName) throws IOException {
+        Request putPipeline = new Request("PUT", "_ingest/pipeline/" + pipelineName);
+        putPipeline.setJsonEntity(
+            "          {\n" +
+                "            \"processors\": [\n" +
+                "              {\n" +
+                "                \"inference\" : {\n" +
+                "                  \"model_id\" : \"" + modelId + "\",\n" +
+                "                  \"inference_config\": {\"regression\": {}},\n" +
+                "                  \"target_field\": \"regression_field\",\n" +
+                "                  \"field_map\": {}\n" +
+                "                }\n" +
+                "              }\n" +
+                "            ]\n" +
+                "          }"
+        );
 
-            try {
-                Response searchResponse = client().performRequest(searchForStats);
+        assertThat(client().performRequest(putPipeline).getStatusLine().getStatusCode(), equalTo(200));
+    }
 
-                Map<String, Object> responseAsMap = entityAsMap(searchResponse);
-                Map<String, Object> hits = (Map<String, Object>)responseAsMap.get("hits");
-                assertThat(responseAsMap.toString(), hits.get("total"), equalTo(1));
+    private void deletePipeline(String pipelineId) throws IOException {
+        try {
+            Request deletePipeline = new Request("DELETE", "_ingest/pipeline/" + pipelineId);
+            assertThat(client().performRequest(deletePipeline).getStatusLine().getStatusCode(), equalTo(200));
+        } catch (ResponseException ex) {
+            if (ex.getResponse().getStatusLine().getStatusCode() != 404) {
+                throw ex;
+            }
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    private void waitForStats() throws Exception {
+        assertBusy(() -> {
+            Map<String, Object> updatedStatsMap = null;
+            try {
+                ensureGreen(".ml-stats-*");
+                updatedStatsMap = getStats();
             } catch (ResponseException e) {
                 // the search may fail because the index is not ready yet in which case retry
-                if (e.getMessage().contains("search_phase_execution_exception") == false) {
+                if (e.getMessage().contains("search_phase_execution_exception")) {
+                    fail("search failed- retry");
+                } else {
                     throw e;
                 }
             }
+
+            List<Map<String, Object>> inferenceStats =
+                (List<Map<String, Object>>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap);
+            assertNotNull(inferenceStats);
         });
     }
 

+ 12 - 4
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java

@@ -36,6 +36,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefiniti
 import org.elasticsearch.xpack.ml.extractor.DocValueField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
 import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
 import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
@@ -102,11 +103,18 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
             .collect(Collectors.toList()));
         persister.createAndIndexInferenceModelMetadata(modelMetadata);
 
-        PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
-        trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
-        Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
+        PlainActionFuture<Tuple<Long, Map<String, Set<String>>>> getIdsFuture = new PlainActionFuture<>();
+        trainedModelProvider.expandIds(
+            modelId + "*",
+            false,
+            PageParams.defaultParams(),
+            Collections.emptySet(),
+            ModelAliasMetadata.EMPTY,
+            getIdsFuture
+        );
+        Tuple<Long, Map<String, Set<String>>> ids = getIdsFuture.actionGet();
         assertThat(ids.v1(), equalTo(1L));
-        String inferenceModelId = ids.v2().iterator().next();
+        String inferenceModelId = ids.v2().keySet().iterator().next();
 
         PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
         trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture);

+ 17 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -25,6 +25,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.ClusterSettings;
@@ -141,6 +142,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateFilterAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateJobAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction;
 import org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction;
 import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction;
 import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction;
@@ -219,6 +221,7 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateFilterAction;
 import org.elasticsearch.xpack.ml.action.TransportUpdateJobAction;
 import org.elasticsearch.xpack.ml.action.TransportUpdateModelSnapshotAction;
 import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction;
+import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAliasAction;
 import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction;
 import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction;
 import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction;
@@ -239,6 +242,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactor
 import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
 import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation;
@@ -316,6 +320,7 @@ import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction;
+import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAliasAction;
 import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
 import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
 import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
@@ -933,6 +938,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
             new RestGetTrainedModelsStatsAction(),
             new RestPutTrainedModelAction(),
             new RestUpgradeJobModelSnapshotAction(),
+            new RestPutTrainedModelAliasAction(),
             // CAT Handlers
             new RestCatJobsAction(),
             new RestCatTrainedModelsAction(),
@@ -1016,7 +1022,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                 new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class),
                 new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class),
                 new ActionHandler<>(UpgradeJobModelSnapshotAction.INSTANCE, TransportUpgradeJobModelSnapshotAction.class),
-            usageAction,
+                new ActionHandler<>(PutTrainedModelAliasAction.INSTANCE, TransportPutTrainedModelAliasAction.class),
+                usageAction,
                 infoAction);
     }
 
@@ -1121,6 +1128,13 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.add(
+            new NamedXContentRegistry.Entry(
+                Metadata.Custom.class,
+                new ParseField(ModelAliasMetadata.NAME),
+                ModelAliasMetadata::fromXContent
+            )
+        );
         return namedXContent;
     }
 
@@ -1131,6 +1145,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         // Custom metadata
         namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom));
 
         // Persistent tasks params
         namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATAFEED_TASK_NAME,

+ 59 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java

@@ -14,10 +14,12 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
+import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.ingest.IngestMetadata;
@@ -29,11 +31,15 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
@@ -80,13 +86,60 @@ public class TransportDeleteTrainedModelAction
             return;
         }
 
-        trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap(
-            r -> {
-                auditor.info(request.getId(), "trained model deleted");
-                listener.onResponse(AcknowledgedResponse.TRUE);
-            },
+        final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(state);
+        final List<String> modelAliases = new ArrayList<>();
+        for (Map.Entry<String, ModelAliasMetadata.ModelAliasEntry> modelAliasEntry : currentMetadata.modelAliases().entrySet()) {
+            if (modelAliasEntry.getValue().getModelId().equals(id)) {
+                modelAliases.add(modelAliasEntry.getKey());
+            }
+        }
+        for (String modelAlias : modelAliases) {
+            if (referencedModels.contains(modelAlias)) {
+                listener.onFailure(new ElasticsearchStatusException(
+                    "Cannot delete model [{}] as it has a model_alias [{}] that is still referenced by ingest processors",
+                    RestStatus.CONFLICT,
+                    id,
+                    modelAlias));
+                return;
+            }
+        }
+
+        ActionListener<AcknowledgedResponse> nameDeletionListener = ActionListener.wrap(
+            ack -> trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap(
+                    r -> {
+                        auditor.info(request.getId(), "trained model deleted");
+                        listener.onResponse(AcknowledgedResponse.TRUE);
+                    },
+                    listener::onFailure
+            )),
+
             listener::onFailure
-        ));
+        );
+
+        // No model_alias refers to this model, so there is no cluster state to update; simply delete the model
+        if (modelAliases.isEmpty()) {
+            nameDeletionListener.onResponse(AcknowledgedResponse.of(true));
+            return;
+        }
+
+        clusterService.submitStateUpdateTask("delete-trained-model-alias", new AckedClusterStateUpdateTask(request, nameDeletionListener) {
+            @Override // Removes the aliases collected above from the alias metadata of the state being updated.
+            public ClusterState execute(final ClusterState currentState) {
+                final ClusterState.Builder builder = ClusterState.builder(currentState);
+                final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(currentState);
+                if (currentMetadata.modelAliases().isEmpty()) { // no aliases stored any more: return the state unchanged
+                    return currentState;
+                }
+                final Map<String, ModelAliasMetadata.ModelAliasEntry> newMetadata = new HashMap<>(currentMetadata.modelAliases());
+                logger.info("[{}] delete model model_aliases {}", request.getId(), modelAliases);
+                modelAliases.forEach(newMetadata::remove); // NOTE(review): names come from the earlier snapshot, not re-checked against currentState
+                final ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(newMetadata);
+                builder.metadata(Metadata.builder(currentState.getMetadata())
+                    .putCustom(ModelAliasMetadata.NAME, modelAliasMetadata) // replaces the whole custom metadata entry with the reduced map
+                    .build());
+                return builder.build();
+            }
+        });
     }
 
     private Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {

+ 11 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.action;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.tasks.Task;
@@ -18,22 +19,27 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 
 
 public class TransportGetTrainedModelsAction extends HandledTransportAction<Request, Response> {
 
     private final TrainedModelProvider provider;
+    private final ClusterService clusterService;
     @Inject
     public TransportGetTrainedModelsAction(TransportService transportService,
                                            ActionFilters actionFilters,
+                                           ClusterService clusterService,
                                            TrainedModelProvider trainedModelProvider) {
         super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new);
         this.provider = trainedModelProvider;
+        this.clusterService = clusterService;
     }
 
     @Override
@@ -41,7 +47,7 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
 
         Response.Builder responseBuilder = Response.builder();
 
-        ActionListener<Tuple<Long, Set<String>>> idExpansionListener = ActionListener.wrap(
+        ActionListener<Tuple<Long, Map<String, Set<String>>>> idExpansionListener = ActionListener.wrap(
             totalAndIds -> {
                 responseBuilder.setTotalCount(totalAndIds.v1());
 
@@ -58,8 +64,10 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
                 }
 
                 if (request.getIncludes().isIncludeModelDefinition()) {
+                    Map.Entry<String, Set<String>> modelIdAndAliases = totalAndIds.v2().entrySet().iterator().next();
                     provider.getTrainedModel(
-                        totalAndIds.v2().iterator().next(),
+                        modelIdAndAliases.getKey(),
+                        modelIdAndAliases.getValue(),
                         request.getIncludes(),
                         ActionListener.wrap(
                             config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
@@ -80,11 +88,11 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
             },
             listener::onFailure
         );
-
         provider.expandIds(request.getResourceId(),
             request.isAllowNoResources(),
             request.getPageParams(),
             new HashSet<>(request.getTags()),
+            ModelAliasMetadata.fromState(clusterService.state()),
             idExpansionListener);
     }
 

+ 35 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java

@@ -20,6 +20,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.metrics.CounterMetric;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.ingest.IngestMetadata;
 import org.elasticsearch.ingest.IngestService;
 import org.elasticsearch.ingest.IngestStats;
@@ -28,6 +29,7 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
@@ -42,6 +44,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -73,6 +76,7 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
                              GetTrainedModelsStatsAction.Request request,
                              ActionListener<GetTrainedModelsStatsAction.Response> listener) {
 
+        final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterService.state());
         GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
 
         ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(
@@ -84,20 +88,30 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
 
         ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
             nodesStatsResponse -> {
+                Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedIdsWithAliases()
+                    .entrySet()
+                    .stream()
+                    .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey())))
+                    .collect(Collectors.toSet());
+                Map<String, Set<String>> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases(clusterService.state(),
+                    ingestService,
+                    allPossiblePipelineReferences);
                 Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByModelId(nodesStatsResponse,
-                    pipelineIdsByModelIds(clusterService.state(),
-                        ingestService,
-                        responseBuilder.getExpandedIds()));
+                    currentMetadata,
+                    pipelineIdsByModelIdsOrAliases
+                );
                 responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
-                trainedModelProvider.getInferenceStats(responseBuilder.getExpandedIds().toArray(new String[0]), inferenceStatsListener);
+                trainedModelProvider.getInferenceStats(
+                    responseBuilder.getExpandedIdsWithAliases().keySet().toArray(new String[0]),
+                    inferenceStatsListener
+                );
             },
             listener::onFailure
         );
 
-        ActionListener<Tuple<Long, Set<String>>> idsListener = ActionListener.wrap(
+        ActionListener<Tuple<Long, Map<String, Set<String>>>> idsListener = ActionListener.wrap(
             tuple -> {
-                responseBuilder.setExpandedIds(tuple.v2())
-                    .setTotalModelCount(tuple.v1());
+                responseBuilder.setExpandedIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1());
                 String[] ingestNodes = ingestNodes(clusterService.state());
                 NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear()
                     .addMetric(NodesStatsRequest.Metric.INGEST.metricName());
@@ -105,27 +119,36 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
             },
             listener::onFailure
         );
-
         trainedModelProvider.expandIds(request.getResourceId(),
             request.isAllowNoResources(),
             request.getPageParams(),
             Collections.emptySet(),
+            currentMetadata,
             idsListener);
     }
 
     static Map<String, IngestStats> inferenceIngestStatsByModelId(NodesStatsResponse response,
+                                                                  ModelAliasMetadata currentMetadata,
                                                                   Map<String, Set<String>> modelIdToPipelineId) {
 
         Map<String, IngestStats> ingestStatsMap = new HashMap<>(); // result: merged ingest stats keyed by model id
-
-        modelIdToPipelineId.forEach((modelId, pipelineIds) -> {
+        Map<String, Set<String>> trueModelIdToPipelines = modelIdToPipelineId.entrySet()
+            .stream() // re-key: a key that resolves through the alias metadata is stored under the resolved model id
+            .collect(Collectors.toMap(
+                entry -> {
+                    String maybeModelId = currentMetadata.getModelId(entry.getKey());
+                    return maybeModelId == null ? entry.getKey() : maybeModelId; // no alias mapping: the key is used as-is
+                },
+                Map.Entry::getValue,
+                Sets::union // keys that resolve to the same model id have their pipeline sets merged
+            ));
+        trueModelIdToPipelines.forEach((modelId, pipelineIds) -> {
             List<IngestStats> collectedStats = response.getNodes()
                 .stream()
                 .map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds))
                 .collect(Collectors.toList());
             ingestStatsMap.put(modelId, mergeStats(collectedStats)); // one merged entry per model, combined across all nodes
         });
-
         return ingestStatsMap;
     }
 
@@ -139,7 +162,7 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
         return ingestNodes;
     }
 
-    static Map<String, Set<String>> pipelineIdsByModelIds(ClusterState state, IngestService ingestService, Set<String> modelIds) {
+    static Map<String, Set<String>> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set<String> modelIds) {
         IngestMetadata ingestMetadata = state.metadata().custom(IngestMetadata.TYPE);
         Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
         if (ingestMetadata == null) {

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -69,7 +69,9 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
                 typedChainTaskExecutor.execute(ActionListener.wrap(
                     inferenceResultsInterfaces -> {
                         model.release();
-                        listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build());
+                        listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces)
+                            .setModelId(model.getModelId())
+                            .build());
                     },
                     e -> {
                         model.release();

+ 8 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

@@ -38,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
 import java.io.IOException;
@@ -123,6 +124,13 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
             .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
             .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
             .build();
+        if (ModelAliasMetadata.fromState(state).getModelId(trainedModelConfig.getModelId()) != null) {
+            listener.onFailure(ExceptionsHelper.badRequestException(
+                "requested model_id [{}] is the same as an existing model_alias. Model model_aliases and ids must be unique",
+                request.getTrainedModelConfig().getModelId()
+            ));
+            return;
+        }
 
         ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
             r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(

+ 209 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java

@@ -0,0 +1,209 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.action;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
+import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
+import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.block.ClusterBlockException;
+import org.elasticsearch.cluster.block.ClusterBlockLevel;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.logging.HeaderWarning;
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.XPackField;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Predicate;
+
+import static org.elasticsearch.xpack.core.ml.job.messages.Messages.TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY;
+
+public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMasterNodeAction<PutTrainedModelAliasAction.Request> {
+
+    private static final Logger logger = LogManager.getLogger(TransportPutTrainedModelAliasAction.class);
+
+    private final XPackLicenseState licenseState;
+    private final TrainedModelProvider trainedModelProvider;
+    private final InferenceAuditor auditor;
+
+    @Inject // Forwards the transport plumbing to the master-node base class and keeps the collaborators used by masterOperation.
+    public TransportPutTrainedModelAliasAction(
+        TransportService transportService,
+        TrainedModelProvider trainedModelProvider,
+        ClusterService clusterService,
+        ThreadPool threadPool,
+        XPackLicenseState licenseState,
+        ActionFilters actionFilters,
+        InferenceAuditor auditor,
+        IndexNameExpressionResolver indexNameExpressionResolver) {
+        super(
+            PutTrainedModelAliasAction.NAME,
+            transportService,
+            clusterService,
+            threadPool,
+            actionFilters,
+            PutTrainedModelAliasAction.Request::new, // reader used to build the request
+            indexNameExpressionResolver,
+            ThreadPool.Names.SAME // executor name handed to the base class
+        );
+        this.licenseState = licenseState; // used for the license check on the target model
+        this.trainedModelProvider = trainedModelProvider; // used to load the model configs being validated
+        this.auditor = auditor; // receives the warning when model inputs differ significantly
+    }
+
+    @Override // Validates the requested alias against the stored models, then submits the cluster state update that records it.
+    protected void masterOperation(
+        Task task,
+        PutTrainedModelAliasAction.Request request,
+        ClusterState state,
+        ActionListener<AcknowledgedResponse> listener
+    ) throws Exception {
+        final boolean mlSupported = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING);
+        final Predicate<TrainedModelConfig> isLicensed = (model) -> mlSupported || licenseState.isAllowedByLicense(model.getLicenseLevel());
+        final String oldModelId = ModelAliasMetadata.fromState(state).getModelId(request.getModelAlias());
+        // A non-null oldModelId means the alias already refers to a model; re-pointing it requires the request's reassign flag.
+        if (oldModelId != null && (request.isReassign() == false)) {
+            listener.onFailure(ExceptionsHelper.badRequestException(
+                "cannot assign model_alias [{}] to model_id [{}] as model_alias already refers to [{}]. "
+                    +
+                    "Set parameter [reassign] to [true] if model_alias should be reassigned.",
+                request.getModelAlias(),
+                request.getModelId(),
+                oldModelId));
+            return;
+        }
+        Set<String> modelIds = new HashSet<>(); // load the new model, the old model, and any model whose id equals the alias
+        modelIds.add(request.getModelAlias());
+        modelIds.add(request.getModelId());
+        if (oldModelId != null) {
+            modelIds.add(oldModelId);
+        }
+        trainedModelProvider.getTrainedModels(modelIds, GetTrainedModelsAction.Includes.empty(), true, ActionListener.wrap(
+            models -> {
+                TrainedModelConfig newModel = null; // config of the model the alias should point to, if it was found
+                TrainedModelConfig oldModel = null; // config of the model the alias currently points to, if it was found
+                for (TrainedModelConfig config : models) {
+                    if (config.getModelId().equals(request.getModelId())) {
+                        newModel = config;
+                    }
+                    if (config.getModelId().equals(oldModelId)) {
+                        oldModel = config;
+                    }
+                    if (config.getModelId().equals(request.getModelAlias())) { // an alias must not be the id of a stored model
+                        listener.onFailure(
+                            ExceptionsHelper.badRequestException("model_alias cannot be the same as an existing trained model_id")
+                        );
+                        return;
+                    }
+                }
+                if (newModel == null) { // the target model_id was not among the loaded models
+                    listener.onFailure(
+                        ExceptionsHelper.missingTrainedModel(request.getModelId())
+                    );
+                    return;
+                }
+                if (isLicensed.test(newModel) == false) { // neither the ML feature nor the model's own license level is allowed
+                    listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
+                    return;
+                }
+                // if old model is null, none of these validations matter
+                // we should still allow reassignment even if the old model was somehow deleted and the alias still refers to it
+                if (oldModel != null) {
+                    // validate inference configs are the same type. Moving an alias from regression -> classification seems dangerous
+                    if (newModel.getInferenceConfig() != null && oldModel.getInferenceConfig() != null) {
+                        if (newModel.getInferenceConfig().getName().equals(oldModel.getInferenceConfig().getName()) == false) {
+                            listener.onFailure(
+                                ExceptionsHelper.badRequestException(
+                                    "cannot reassign model_alias [{}] to model [{}] "
+                                    + "with inference config type [{}] from model [{}] with type [{}]",
+                                    request.getModelAlias(),
+                                    newModel.getModelId(),
+                                    newModel.getInferenceConfig().getName(),
+                                    oldModel.getModelId(),
+                                    oldModel.getInferenceConfig().getName()
+                                )
+                            );
+                            return;
+                        }
+                    }
+                    // Compare the input field names of both models; a large mismatch only produces a warning, not a failure.
+                    Set<String> oldInputFields = new HashSet<>(oldModel.getInput().getFieldNames());
+                    Set<String> newInputFields = new HashSet<>(newModel.getInput().getFieldNames());
+                    // TODO should we fail in this case??? Both thresholds below use integer division of the old field count.
+                    if (Sets.difference(oldInputFields, newInputFields).size() > (oldInputFields.size() / 2)
+                    || Sets.intersection(newInputFields, oldInputFields).size() < (oldInputFields.size() / 2)) { // NOTE(review): looks implied by the first check - confirm
+                        String warning =  Messages.getMessage(
+                            TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY,
+                            request.getModelId(),
+                            oldModelId);
+                        auditor.warning(oldModelId, warning); // the same text goes to the auditor, the log and a response header warning
+                        logger.warn("[{}] {}", oldModelId, warning);
+                        HeaderWarning.addWarning(warning);
+                    }
+                }
+                clusterService.submitStateUpdateTask("update-model-alias", new AckedClusterStateUpdateTask(request, listener) {
+                    @Override
+                    public ClusterState execute(final ClusterState currentState) {
+                        return updateModelAlias(currentState, request); // applied to the state given to the task, not the one validated above
+                    }
+                });
+
+            },
+            listener::onFailure
+        ));
+    }
+    // Builds a cluster state whose alias metadata maps the request's model_alias to the request's model_id.
+    static ClusterState updateModelAlias(final ClusterState currentState, final PutTrainedModelAliasAction.Request request) {
+        final ClusterState.Builder builder = ClusterState.builder(currentState);
+        final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(currentState);
+        String currentModelId = currentMetadata.getModelId(request.getModelAlias());
+        final Map<String, ModelAliasMetadata.ModelAliasEntry> newMetadata = new HashMap<>(currentMetadata.modelAliases()); // mutable copy
+        if (currentModelId == null) { // only selects the log message; the put below happens in both cases
+            logger.info("creating new model_alias [{}] for model [{}]", request.getModelAlias(), request.getModelId());
+        } else {
+            logger.info(
+                "updating model_alias [{}] to refer to model [{}] from model [{}]",
+                request.getModelAlias(),
+                request.getModelId(),
+                currentModelId
+            );
+        }
+        newMetadata.put(request.getModelAlias(), new ModelAliasMetadata.ModelAliasEntry(request.getModelId())); // add or overwrite
+        final ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(newMetadata);
+        builder.metadata(Metadata.builder(currentState.getMetadata()).putCustom(ModelAliasMetadata.NAME, modelAliasMetadata).build());
+        return builder.build();
+    }
+
+    @Override // Reports a global METADATA_WRITE block on the given cluster state, if there is one.
+    protected ClusterBlockException checkBlock(PutTrainedModelAliasAction.Request request, ClusterState state) {
+        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); // the request itself is not inspected
+    }
+}

+ 222 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ModelAliasMetadata.java

@@ -0,0 +1,222 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.cluster.AbstractDiffable;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.Diff;
+import org.elasticsearch.cluster.DiffableUtils;
+import org.elasticsearch.cluster.NamedDiff;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Custom {@link Metadata} implementation for storing a map of model aliases that point to model IDs.
+ *
+ * The alias map is only exposed through an unmodifiable view. Equality is value-based (see
+ * {@link #equals(Object)}), so two instances holding the same alias to model ID mappings compare equal.
+ */
+public class ModelAliasMetadata implements Metadata.Custom {
+
+    public static final String NAME = "trained_model_alias";
+
+    /** Metadata with no model aliases; returned by {@link #fromState(ClusterState)} when the custom is absent. */
+    public static final ModelAliasMetadata EMPTY = new ModelAliasMetadata(new HashMap<>());
+
+    /**
+     * Reads the model alias metadata stored under {@link #NAME} in the given cluster state.
+     *
+     * @param cs the cluster state to read from
+     * @return the stored metadata, or {@link #EMPTY} when the cluster state has no such custom
+     */
+    public static ModelAliasMetadata fromState(ClusterState cs) {
+        ModelAliasMetadata modelAliasMetadata = cs.metadata().custom(NAME);
+        return modelAliasMetadata == null ? EMPTY : modelAliasMetadata;
+    }
+
+    /**
+     * Reads a diff of this custom from the stream.
+     *
+     * @param in the stream to read from
+     * @return the deserialized diff
+     * @throws IOException if reading from the stream fails
+     */
+    public static NamedDiff<Metadata.Custom> readDiffFrom(StreamInput in) throws IOException {
+        return new ModelAliasMetadataDiff(in);
+    }
+
+    private static final ParseField MODEL_ALIASES = new ParseField("model_aliases");
+    private static final ParseField MODEL_ID = new ParseField("model_id");
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<ModelAliasMetadata, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        // to protect BWC serialization
+        true,
+        args -> new ModelAliasMetadata((Map<String, ModelAliasEntry>)args[0])
+    );
+
+    static {
+        // Inside "model_aliases" every field name is a model alias and its value is parsed as a ModelAliasEntry.
+        PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
+            Map<String, ModelAliasEntry> modelAliases = new HashMap<>();
+            while (p.nextToken() != XContentParser.Token.END_OBJECT) {
+                String modelAlias = p.currentName();
+                modelAliases.put(modelAlias, ModelAliasEntry.fromXContent(p));
+            }
+            return modelAliases;
+        }, MODEL_ALIASES);
+    }
+
+    /**
+     * Parses the metadata from its XContent representation.
+     *
+     * @param parser the parser positioned on the metadata object
+     * @return the parsed metadata
+     */
+    public static ModelAliasMetadata fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final Map<String, ModelAliasEntry> modelAliases;
+
+    /**
+     * @param modelAliases mapping of model alias to its entry; it is wrapped (not copied) in an unmodifiable view
+     */
+    public ModelAliasMetadata(Map<String, ModelAliasEntry> modelAliases) {
+        this.modelAliases = Collections.unmodifiableMap(modelAliases);
+    }
+
+    /**
+     * Deserializes the metadata from the stream, mirroring {@link #writeTo(StreamOutput)}.
+     *
+     * @param in the stream to read from
+     * @throws IOException if reading from the stream fails
+     */
+    public ModelAliasMetadata(StreamInput in) throws IOException {
+        this.modelAliases = Collections.unmodifiableMap(in.readMap(StreamInput::readString, ModelAliasEntry::new));
+    }
+
+    /**
+     * @return an unmodifiable view of the model alias to entry mapping
+     */
+    public Map<String, ModelAliasEntry> modelAliases() {
+        return modelAliases;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        // Written as a fragment: a single "model_aliases" object whose fields are alias -> entry object.
+        builder.startObject(MODEL_ALIASES.getPreferredName());
+        for (Map.Entry<String, ModelAliasEntry> modelAliasEntry : modelAliases.entrySet()) {
+            builder.field(modelAliasEntry.getKey(), modelAliasEntry.getValue());
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public Diff<Metadata.Custom> diff(Metadata.Custom previousState) {
+        return new ModelAliasMetadataDiff((ModelAliasMetadata) previousState, this);
+    }
+
+    @Override
+    public EnumSet<Metadata.XContentContext> context() {
+        return Metadata.ALL_CONTEXTS;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public Version getMinimalSupportedVersion() {
+        // TODO change after backport
+        return Version.V_8_0_0;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeMap(this.modelAliases, StreamOutput::writeString, (stream, val) -> val.writeTo(stream));
+    }
+
+    /**
+     * Resolves a model alias to the model ID it refers to.
+     *
+     * @param modelAlias the alias to resolve
+     * @return the referenced model ID, or {@code null} when the alias is not present
+     */
+    public String getModelId(String modelAlias) {
+        ModelAliasEntry entry = this.modelAliases.get(modelAlias);
+        if (entry == null) {
+            return null;
+        }
+        return entry.modelId;
+    }
+
+    /**
+     * Value equality over the alias map.
+     *
+     * Without this, two instances with identical content (for example one freshly deserialized) would only be
+     * equal by identity.
+     * NOTE(review): cluster state change detection for customs is expected to rely on {@code equals} - confirm
+     * against {@code ClusterChangedEvent#changedCustomMetadataSet}.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ModelAliasMetadata that = (ModelAliasMetadata) o;
+        return Objects.equals(modelAliases, that.modelAliases);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelAliases);
+    }
+
+    /**
+     * Diff between two {@link ModelAliasMetadata} instances, expressed as a diff of their alias maps.
+     */
+    static class ModelAliasMetadataDiff implements NamedDiff<Metadata.Custom> {
+
+        final Diff<Map<String, ModelAliasEntry>> modelAliasesDiff;
+
+        ModelAliasMetadataDiff(ModelAliasMetadata before, ModelAliasMetadata after) {
+            this.modelAliasesDiff = DiffableUtils.diff(before.modelAliases, after.modelAliases, DiffableUtils.getStringKeySerializer());
+        }
+
+        ModelAliasMetadataDiff(StreamInput in) throws IOException {
+            this.modelAliasesDiff = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(),
+                ModelAliasEntry::new, ModelAliasEntry::readDiffFrom);
+        }
+
+        @Override
+        public Metadata.Custom apply(Metadata.Custom part) {
+            return new ModelAliasMetadata(modelAliasesDiff.apply(((ModelAliasMetadata) part).modelAliases));
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            modelAliasesDiff.writeTo(out);
+        }
+    }
+
+    /**
+     * The value side of a model alias mapping; it carries the ID of the model the alias refers to.
+     */
+    public static class ModelAliasEntry extends AbstractDiffable<ModelAliasEntry> implements ToXContentObject {
+        private static final ConstructingObjectParser<ModelAliasEntry, Void> PARSER = new ConstructingObjectParser<>(
+            "model_alias_metadata_alias_entry",
+            // to protect BWC serialization
+            true,
+            args -> new ModelAliasEntry((String)args[0])
+        );
+        static {
+            PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
+        }
+
+        private static Diff<ModelAliasEntry> readDiffFrom(StreamInput in) throws IOException {
+            return readDiffFrom(ModelAliasEntry::new, in);
+        }
+
+        private static ModelAliasEntry fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private final String modelId;
+
+        /**
+         * @param modelId the ID of the model the alias refers to
+         */
+        public ModelAliasEntry(String modelId) {
+            this.modelId = modelId;
+        }
+
+        ModelAliasEntry(StreamInput in) throws IOException {
+            this.modelId = in.readString();
+        }
+
+        /**
+         * @return the ID of the model the alias refers to
+         */
+        public String getModelId() {
+            return modelId;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(MODEL_ID.getPreferredName(), modelId);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(modelId);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ModelAliasEntry modelAliasEntry = (ModelAliasEntry) o;
+            return Objects.equals(modelId, modelAliasEntry.modelId);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(modelId);
+        }
+    }
+}

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -157,7 +157,12 @@ public class InferenceProcessor extends AbstractProcessor {
             throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
         }
         assert response.getInferenceResults().size() == 1;
-        InferenceResults.writeResult(response.getInferenceResults().get(0), ingestDocument, targetField, modelId);
+        InferenceResults.writeResult(
+            response.getInferenceResults().get(0),
+            ingestDocument,
+            targetField,
+            response.getModelId() != null ? response.getModelId() : modelId
+        );
     }
 
     @Override

+ 232 - 53
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

@@ -19,6 +19,7 @@ import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
 import org.elasticsearch.common.cache.Cache;
 import org.elasticsearch.common.cache.CacheBuilder;
+import org.elasticsearch.common.cache.CacheLoader;
 import org.elasticsearch.common.cache.RemovalNotification;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
@@ -37,12 +38,14 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.util.ArrayDeque;
+import java.util.Collections;
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -50,6 +53,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -108,11 +112,14 @@ public class ModelLoadingService implements ClusterStateListener {
         }
     }
 
-
     private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
     private final TrainedModelStatsService modelStatsService;
     private final Cache<String, ModelAndConsumer> localModelCache;
+    // Referenced models can be model aliases or IDs
     private final Set<String> referencedModels = new HashSet<>();
+    private final Map<String, String> modelAliasToId = new HashMap<>();
+    private final Map<String, Set<String>> modelIdToModelAliases = new HashMap<>();
+    private final Map<String, Set<String>> modelIdToUpdatedModelAliases = new HashMap<>();
     private final Map<String, Queue<ActionListener<LocalModel>>> loadingListeners = new HashMap<>();
     private final TrainedModelProvider provider;
     private final Set<String> shouldNotAudit;
@@ -148,8 +155,13 @@ public class ModelLoadingService implements ClusterStateListener {
         this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
     }
 
+    // for testing
+    String getModelId(String modelIdOrAlias) {
+        return modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
+    }
+
     boolean isModelCached(String modelId) {
-        return localModelCache.get(modelId) != null;
+        return localModelCache.get(modelAliasToId.getOrDefault(modelId, modelId)) != null;
     }
 
     /**
@@ -195,11 +207,12 @@ public class ModelLoadingService implements ClusterStateListener {
      * The main difference being that models for search are always cached whereas pipeline models
      * are only cached if they are referenced by an ingest pipeline
      *
-     * @param modelId             the model to get
+     * @param modelIdOrAlias       the model id or model alias to get
      * @param consumer            which feature is requesting the model
      * @param modelActionListener the listener to alert when the model has been retrieved.
      */
-    private void getModel(String modelId, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
+    private void getModel(String modelIdOrAlias, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
+        final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
         ModelAndConsumer cachedModel = localModelCache.get(modelId);
         if (cachedModel != null) {
             cachedModel.consumers.add(consumer);
@@ -210,12 +223,16 @@ public class ModelLoadingService implements ClusterStateListener {
                 return;
             }
             modelActionListener.onResponse(cachedModel.model);
-            logger.trace(() -> new ParameterizedMessage("[{}] loaded from cache", modelId));
+            logger.trace(() -> new ParameterizedMessage("[{}] (model_alias [{}]) loaded from cache", modelId, modelIdOrAlias));
             return;
         }
 
-        if (loadModelIfNecessary(modelId, consumer, modelActionListener)) {
-            logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId));
+        if (loadModelIfNecessary(modelIdOrAlias, consumer, modelActionListener)) {
+            logger.trace(() -> new ParameterizedMessage(
+                "[{}] (model_alias [{}]) is loading or loaded, added new listener to queue",
+                modelId,
+                modelIdOrAlias
+            ));
         }
     }
 
@@ -224,14 +241,15 @@ public class ModelLoadingService implements ClusterStateListener {
      * else if the model is CURRENTLY being loaded the listener is added to be notified when it is loaded
      * else the model load is initiated.
      *
-     * @param modelId The model to get
+     * @param modelIdOrAlias The model to get
      * @param consumer The model consumer
      * @param modelActionListener The listener
      * @return If the model is cached or currently being loaded true is returned. If a new load is started
      * false is returned to indicate a new load event
      */
-    private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
+    private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
         synchronized (loadingListeners) {
+            final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
             ModelAndConsumer cachedModel = localModelCache.get(modelId);
             if (cachedModel != null) {
                 cachedModel.consumers.add(consumer);
@@ -257,13 +275,21 @@ public class ModelLoadingService implements ClusterStateListener {
             if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
                 // The model is requested by a pipeline but not referenced by any ingest pipelines.
                 // This means it is a simulate call and the model should not be cached
+                logger.trace(() -> new ParameterizedMessage(
+                    "[{}] (model_alias [{}]) not actively loading, eager loading without cache",
+                    modelId,
+                    modelIdOrAlias
+                ));
                 loadWithoutCaching(modelId, modelActionListener);
             } else {
-                logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId));
+                logger.trace(() -> new ParameterizedMessage(
+                    "[{}] (model_alias [{}]) attempting to load and cache",
+                    modelId,
+                    modelIdOrAlias
+                ));
                 loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
                 loadModel(modelId, consumer);
             }
-
             return false;
         } // synchronized (loadingListeners)
     }
@@ -304,7 +330,6 @@ public class ModelLoadingService implements ClusterStateListener {
     private void loadWithoutCaching(String modelId, ActionListener<LocalModel> modelActionListener) {
         // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
         // by a simulated pipeline
-        logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
         provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(
             trainedModelConfig -> {
                 // Verify we can pull the model into memory without causing OOM
@@ -377,34 +402,41 @@ public class ModelLoadingService implements ClusterStateListener {
             trainedModelConfig.getLicenseLevel(),
             modelStatsService,
             trainedModelCircuitBreaker);
-        boolean modelAcquired = false;
+        final ModelAndConsumerLoader modelAndConsumerLoader = new ModelAndConsumerLoader(new ModelAndConsumer(loadedModel, consumer));
         synchronized (loadingListeners) {
-            listeners = loadingListeners.remove(modelId);
-            // if there are no listeners, simply release and leave
-            if (listeners == null) {
-                loadedModel.release();
-                return;
-            }
-
+            populateNewModelAlias(modelId);
             // If the model is referenced, that means it is currently in a pipeline somewhere
             // Also, if the consume is a search consumer, we should always cache it
-            if (referencedModels.contains(modelId) || consumer.equals(Consumer.SEARCH)) {
-                // temporarily increase the reference count before adding to
-                // the cache in case the model is evicted before the listeners
-                // are called in which case acquire() would throw.
-                loadedModel.acquire();
-                localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer));
+            if (referencedModels.contains(modelId)
+                || Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()), referencedModels)
+                || consumer.equals(Consumer.SEARCH)) {
+                try {
+                    // The local model may already be in cache. If it is, we don't bother adding it to cache.
+                    // If it isn't, we flip an `isLoaded` flag, and increment the model counter to make sure if it is evicted
+                    // between now and when the listeners access it, the circuit breaker reflects actual usage.
+                    localModelCache.computeIfAbsent(modelId, modelAndConsumerLoader);
+                } catch (ExecutionException ee) {
+                    logger.warn(() -> new ParameterizedMessage("[{}] threw when attempting add to cache", modelId), ee);
+                }
                 shouldNotAudit.remove(modelId);
-                modelAcquired = true;
+            }
+            listeners = loadingListeners.remove(modelId);
+            // if there are no listeners, we should just exit
+            if (listeners == null) {
+                // If we newly added it into cache, release the model so that the circuit breaker can still accurately keep track
+                // of memory
+                if(modelAndConsumerLoader.isLoaded()) {
+                    loadedModel.release();
+                }
+                return;
             }
         } // synchronized (loadingListeners)
         for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
             loadedModel.acquire();
             listener.onResponse(loadedModel);
         }
-        // account for the acquire in the synchronized block above
-        // We cannot simply utilize the same conditionals as `referencedModels` could have changed once we exited the synchronized block
-        if (modelAcquired) {
+        // account for the acquire in the synchronized block above if the model was loaded into the cache
+        if (modelAndConsumerLoader.isLoaded()) {
             loadedModel.release();
         }
     }
@@ -413,6 +445,7 @@ public class ModelLoadingService implements ClusterStateListener {
         Queue<ActionListener<LocalModel>> listeners;
         synchronized (loadingListeners) {
             listeners = loadingListeners.remove(modelId);
+            populateNewModelAlias(modelId);
             if (listeners == null) {
                 return;
             }
@@ -424,6 +457,20 @@ public class ModelLoadingService implements ClusterStateListener {
         }
     }
 
+    private void populateNewModelAlias(String modelId) {
+        Set<String> newModelAliases = modelIdToUpdatedModelAliases.remove(modelId);
+        if (newModelAliases != null && newModelAliases.isEmpty() == false) {
+            logger.trace(() -> new ParameterizedMessage(
+                "[{}] model is now loaded, setting new model_aliases {}",
+                modelId,
+                newModelAliases
+            ));
+            for (String modelAlias: newModelAliases) {
+                modelAliasToId.put(modelAlias, modelId);
+            }
+        }
+    }
+
     private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
         try {
             if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
@@ -438,12 +485,15 @@ public class ModelLoadingService implements ClusterStateListener {
                     INFERENCE_MODEL_CACHE_TTL.getKey());
                 auditIfNecessary(notification.getKey(), msg);
             }
-
-            logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
-                notification.getValue().model.getModelId()));
+            String modelId = modelAliasToId.getOrDefault(notification.getKey(), notification.getKey());
+            logger.trace(() -> new ParameterizedMessage(
+                "Persisting stats for evicted model [{}] (model_aliases {})",
+                modelId,
+                modelIdToModelAliases.getOrDefault(modelId, new HashSet<>())
+            ));
 
             // If the model is no longer referenced, flush the stats to persist as soon as possible
-            notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
+            notification.getValue().model.persistStats(referencedModels.contains(modelId) == false);
         } finally {
             notification.getValue().model.release();
         }
@@ -451,46 +501,112 @@ public class ModelLoadingService implements ClusterStateListener {
 
     @Override
     public void clusterChanged(ClusterChangedEvent event) {
-        // If ingest data has not changed or if the current node is not an ingest node, don't bother caching models
-        if (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false ||
-            event.state().nodes().getLocalNode().isIngestNode() == false) {
+        final boolean prefetchModels = event.state().nodes().getLocalNode().isIngestNode();
+        // If we are not prefetching models and there were no model alias changes, don't bother handling the changes
+        if ((prefetchModels == false)
+            && (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false)
+            && (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME) == false)) {
             return;
         }
 
         ClusterState state = event.state();
         IngestMetadata currentIngestMetadata = state.metadata().custom(IngestMetadata.TYPE);
-        Set<String> allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata);
-        if (allReferencedModelKeys.equals(referencedModels)) {
-            return;
-        }
-        Set<String> referencedModelsBeforeClusterState = null;
+        Set<String> allReferencedModelKeys = event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) ?
+            getReferencedModelKeys(currentIngestMetadata) :
+            new HashSet<>(referencedModels);
+        Set<String> referencedModelsBeforeClusterState;
         Set<String> loadingModelBeforeClusterState = null;
-        Set<String> removedModels = null;
+        Set<String> removedModels;
+        Map<String, Set<String>> addedModelViaAliases = new HashMap<>();
+        Map<String, Set<String>> oldIdToAliases;
         synchronized (loadingListeners) {
+            oldIdToAliases = new HashMap<>(modelIdToModelAliases);
+            Map<String, String> changedAliases = gatherLazyChangedAliasesAndUpdateModelAliases(
+                event,
+                prefetchModels,
+                allReferencedModelKeys
+            );
+
+            // if we are not prefetching, exit now.
+            if (prefetchModels == false) {
+                return;
+            }
+
             referencedModelsBeforeClusterState = new HashSet<>(referencedModels);
             if (logger.isTraceEnabled()) {
                 loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet());
             }
             removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
 
-            // Remove all cached models that are not referenced by any processors
-            // and are not used in search
-            removedModels.forEach(modelId -> {
-                ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
-                if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
-                    localModelCache.invalidate(modelId);
-                }
-            });
             // Remove the models that are no longer referenced
             referencedModels.removeAll(removedModels);
             shouldNotAudit.removeAll(removedModels);
 
+            // Remove all cached models that are not referenced by any processors
+            // and are not used in search
+            for (String modelAliasOrId : removedModels) {
+                String modelId = changedAliases.getOrDefault(modelAliasOrId, modelAliasToId.getOrDefault(modelAliasOrId, modelAliasOrId));
+                // If the "old" model_alias is referenced, we don't want to invalidate. This way the model that now has the model_alias
+                // can be loaded in first
+                boolean oldModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels,
+                    oldIdToAliases.getOrDefault(modelId, Collections.emptySet()));
+                // If the model itself is referenced, we shouldn't evict.
+                boolean modelIsNotReferenced = referencedModels.contains(modelId) == false;
+                // If a model_alias change causes it to NOW be referenced, we shouldn't attempt to evict it
+                boolean newModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels,
+                    modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet()));
+                if (oldModelAliasesNotReferenced && newModelAliasesNotReferenced && modelIsNotReferenced) {
+                    ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
+                    if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
+                        logger.trace("[{} ({})] invalidated from cache", modelId, modelAliasOrId);
+                        localModelCache.invalidate(modelId);
+                    }
+                }
+            }
             // Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels
             allReferencedModelKeys.removeAll(referencedModels);
+            for (String newlyReferencedModel : allReferencedModelKeys) {
+                // check if the model_alias has changed in this round
+                String modelId = changedAliases.getOrDefault(
+                    newlyReferencedModel,
+                    // If the model_alias hasn't changed, get the model id IF it is a model_alias, otherwise we assume it is an id
+                    modelAliasToId.getOrDefault(
+                        newlyReferencedModel,
+                        newlyReferencedModel
+                    )
+                );
+                // Verify that it isn't an old model id but just a new model_alias
+                if (referencedModels.contains(modelId) == false) {
+                    addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(newlyReferencedModel);
+                }
+            }
+            // For any previously referenced model, the model_alias COULD have changed, so it is actually a NEWLY referenced model
+            for (Map.Entry<String, String> modelAliasAndId : changedAliases.entrySet()) {
+                String modelAlias = modelAliasAndId.getKey();
+                String modelId = modelAliasAndId.getValue();
+                if (referencedModels.contains(modelAlias)) {
+                    // we need to load the underlying model since its model_alias is referenced
+                    addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias);
+                    // If we are in cache, keep the old translation for now, it will be updated later
+                    String oldModelId = modelAliasToId.get(modelAlias);
+                    if (oldModelId != null && localModelCache.get(oldModelId) != null) {
+                        modelIdToUpdatedModelAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias);
+                    } else {
+                        // If we are not cached, might as well add the translation right away as new callers will have to load
+                        // from disk anyways.
+                        modelAliasToId.put(modelAlias, modelId);
+                    }
+                } else {
+                    // Add model_alias and id here, since the model_alias wasn't previously referenced,
+                    // no reason to wait on updating the model_alias -> model_id mapping
+                    modelAliasToId.put(modelAlias, modelId);
+                }
+            }
+            // Gather ALL currently referenced model ids
             referencedModels.addAll(allReferencedModelKeys);
 
             // Populate loadingListeners key so we know that we are currently loading the model
-            for (String modelId : allReferencedModelKeys) {
+            for (String modelId : addedModelViaAliases.keySet()) {
                 loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>());
             }
         } // synchronized (loadingListeners)
@@ -503,9 +619,51 @@ public class ModelLoadingService implements ClusterStateListener {
                 logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState,
                     referencedModels);
             }
+            if (oldIdToAliases.equals(modelIdToModelAliases) == false) {
+                logger.trace("model id to alias mappings changed. before {} after {}. Model alias to IDs {}",
+                    oldIdToAliases,
+                    modelIdToModelAliases,
+                    modelAliasToId);
+            }
+            if (addedModelViaAliases.isEmpty() == false) {
+                logger.trace("adding new models via model_aliases and ids: {}", addedModelViaAliases);
+            }
+            if (modelIdToUpdatedModelAliases.isEmpty() == false) {
+                // log the delayed updates themselves; the complete id -> alias mapping is already traced above
+                logger.trace("delayed model aliases to update {}", modelIdToUpdatedModelAliases);
+            }
         }
         removedModels.forEach(this::auditUnreferencedModel);
-        loadModelsForPipeline(allReferencedModelKeys);
+        loadModelsForPipeline(addedModelViaAliases.keySet());
+    }
+
+    /**
+     * Brings {@code modelIdToModelAliases} and {@code modelAliasToId} in line with the
+     * {@link ModelAliasMetadata} of the event's cluster state. Nothing is touched when the event
+     * does not list {@link ModelAliasMetadata#NAME} among its changed custom metadata.
+     *
+     * @param event                  the cluster changed event being processed
+     * @param prefetchModels         when {@code true}, re-pointed aliases that are referenced are deferred
+     *                               instead of being written to {@code modelAliasToId} here
+     * @param allReferencedModelKeys the currently referenced model keys; an alias is deferred only if it is in this set
+     * @return model_alias to new model id, for every alias whose update was deferred (possibly empty)
+     */
+    private Map<String, String> gatherLazyChangedAliasesAndUpdateModelAliases(ClusterChangedEvent event,
+                                                                              boolean prefetchModels,
+                                                                              Set<String> allReferencedModelKeys) {
+        Map<String, String> deferredAliasChanges = new HashMap<>();
+        if (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME) == false) {
+            return deferredAliasChanges;
+        }
+        final Map<String, ModelAliasMetadata.ModelAliasEntry> currentAliases = new HashMap<>(
+            ModelAliasMetadata.fromState(event.state()).modelAliases()
+        );
+        // The reverse mapping (model id -> aliases) is always rebuilt from scratch
+        modelIdToModelAliases.clear();
+        for (Map.Entry<String, ModelAliasMetadata.ModelAliasEntry> entry : currentAliases.entrySet()) {
+            final String alias = entry.getKey();
+            final String newModelId = entry.getValue().getModelId();
+            modelIdToModelAliases.computeIfAbsent(newModelId, k -> new HashSet<>()).add(alias);
+
+            final String knownModelId = modelAliasToId.get(alias);
+            if (knownModelId == null) {
+                // alias not known yet, record it right away
+                modelAliasToId.put(alias, newModelId);
+            } else if (knownModelId.equals(newModelId) == false) {
+                // alias now points at a different model
+                if (prefetchModels && allReferencedModelKeys.contains(alias)) {
+                    // hand the change back to the caller instead of switching the translation now
+                    deferredAliasChanges.put(alias, newModelId);
+                } else {
+                    modelAliasToId.put(alias, newModelId);
+                }
+            }
+        }
+        // Forget aliases that are no longer part of the metadata
+        Set<String> removedAliases = Sets.difference(modelAliasToId.keySet(), currentAliases.keySet());
+        modelAliasToId.keySet().removeAll(removedAliases);
+        return deferredAliasChanges;
     }
 
     private void auditIfNecessary(String modelId, MessageSupplier msg) {
@@ -600,4 +758,25 @@ public class ModelLoadingService implements ClusterStateListener {
             });
         }
     }
+
+    /**
+     * {@link CacheLoader} that returns an already constructed {@link ModelAndConsumer} and records
+     * whether {@link #load(String)} has been invoked.
+     */
+    private static class ModelAndConsumerLoader implements CacheLoader<String, ModelAndConsumer> {
+
+        private final ModelAndConsumer modelAndConsumer;
+        // set by load(); remains false for as long as load() has not been called on this instance
+        private boolean loaded;
+
+        ModelAndConsumerLoader(ModelAndConsumer modelAndConsumer) {
+            this.modelAndConsumer = modelAndConsumer;
+        }
+
+        /** @return {@code true} once {@link #load(String)} has been called */
+        boolean isLoaded() {
+            return this.loaded;
+        }
+
+        @Override
+        public ModelAndConsumer load(String key) throws Exception {
+            this.loaded = true;
+            // NOTE(review): acquire() presumably takes a reference on the model for the cache's hold - confirm
+            this.modelAndConsumer.model.acquire();
+            return this.modelAndConsumer;
+        }
+    }
 }

+ 77 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -80,6 +80,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedMo
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -97,6 +98,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.TreeSet;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -442,6 +444,13 @@ public class TrainedModelProvider {
     public void getTrainedModel(final String modelId,
                                 final GetTrainedModelsAction.Includes includes,
                                 final ActionListener<TrainedModelConfig> finalListener) {
+        getTrainedModel(modelId, Collections.emptySet(), includes, finalListener);
+    }
+
+    public void getTrainedModel(final String modelId,
+                                final Set<String> modelAliases,
+                                final GetTrainedModelsAction.Includes includes,
+                                final ActionListener<TrainedModelConfig> finalListener) {
 
         if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
             try {
@@ -455,6 +464,7 @@ public class TrainedModelProvider {
 
         ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
             modelBuilder -> {
+                modelBuilder.setModelAliases(modelAliases);
                 if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance()
                   || includes.isIncludeHyperparameters()) == false) {
                     finalListener.onResponse(modelBuilder.build());
@@ -570,6 +580,18 @@ public class TrainedModelProvider {
             multiSearchResponseActionListener);
     }
 
+    public void getTrainedModels(Set<String> modelIds,
+                                 GetTrainedModelsAction.Includes includes,
+                                 boolean allowNoResources,
+                                 final ActionListener<List<TrainedModelConfig>> finalListener) {
+        getTrainedModels(
+            modelIds.stream().collect(Collectors.toMap(Function.identity(), _k -> Collections.emptySet())),
+            includes,
+            allowNoResources,
+            finalListener
+        );
+    }
+
     /**
      * Gets all the provided trained config model objects
      *
@@ -577,11 +599,15 @@ public class TrainedModelProvider {
      * This does no expansion on the ids.
      * It assumes that there are fewer than 10k.
      */
-    public void getTrainedModels(Set<String> modelIds,
+    public void getTrainedModels(Map<String, Set<String>> modelIds,
                                  GetTrainedModelsAction.Includes includes,
                                  boolean allowNoResources,
                                  final ActionListener<List<TrainedModelConfig>> finalListener) {
-        QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
+        QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(
+            QueryBuilders
+                .idsQuery()
+                .addIds(modelIds.keySet().toArray(new String[0]))
+        );
 
         SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
             .addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC)
@@ -590,8 +616,8 @@ public class TrainedModelProvider {
             .setSize(modelIds.size())
             .request();
         List<TrainedModelConfig.Builder> configs = new ArrayList<>(modelIds.size());
-        Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
-        Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
+        Set<String> modelsInIndex = Sets.difference(modelIds.keySet(), MODELS_STORED_AS_RESOURCE);
+        Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds.keySet());
         for(String modelId : modelsAsResource) {
             try {
                 configs.add(loadModelFromResource(modelId, true));
@@ -613,12 +639,12 @@ public class TrainedModelProvider {
                 if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance()
                   || includes.isIncludeHyperparameters()) == false) {
                     finalListener.onResponse(modelBuilders.stream()
-                        .map(TrainedModelConfig.Builder::build)
+                        .map(b -> b.setModelAliases(modelIds.get(b.getModelId())).build())
                         .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
                         .collect(Collectors.toList()));
                     return;
                 }
-                this.getTrainedModelMetadata(modelIds, ActionListener.wrap(
+                this.getTrainedModelMetadata(modelIds.keySet(), ActionListener.wrap(
                     metadata ->
                         finalListener.onResponse(modelBuilders.stream()
                             .map(builder -> {
@@ -633,9 +659,8 @@ public class TrainedModelProvider {
                                     if (includes.isIncludeHyperparameters()) {
                                         builder.setHyperparameters(modelMetadata.getHyperparameters());
                                     }
-
                                 }
-                                return builder.build();
+                                return builder.setModelAliases(modelIds.get(builder.getModelId())).build();
                             })
                             .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
                             .collect(Collectors.toList())),
@@ -679,7 +704,7 @@ public class TrainedModelProvider {
                 // We previously expanded the IDs.
                 // If the config has gone missing between then and now we should throw if allowNoResources is false
                 // Otherwise, treat it as if it was never expanded to begin with.
-                Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
+                Set<String> missingConfigs = Sets.difference(modelIds.keySet(), observedIds);
                 if (missingConfigs.isEmpty() == false && allowNoResources == false) {
                     getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
                     return;
@@ -729,8 +754,23 @@ public class TrainedModelProvider {
                           boolean allowNoResources,
                           PageParams pageParams,
                           Set<String> tags,
-                          ActionListener<Tuple<Long, Set<String>>> idsListener) {
+                          ModelAliasMetadata modelAliasMetadata,
+                          ActionListener<Tuple<Long, Map<String, Set<String>>>> idsListener) {
         String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
+        Set<String> expandedIdsFromAliases = new HashSet<>();
+        if (Strings.isAllOrWildcard(tokens) == false) {
+            for (String token : tokens) {
+                if (Regex.isSimpleMatchPattern(token)) {
+                    for (String modelAlias : modelAliasMetadata.modelAliases().keySet()) {
+                        if (Regex.simpleMatch(token, modelAlias)) {
+                            expandedIdsFromAliases.add(modelAliasMetadata.getModelId(modelAlias));
+                        }
+                    }
+                } else if (modelAliasMetadata.getModelId(token) != null) {
+                    expandedIdsFromAliases.add(modelAliasMetadata.getModelId(token));
+                }
+            }
+        }
         Set<String> matchedResourceIds = matchedResourceIds(tokens);
         Set<String> foundResourceIds;
         if (tags.isEmpty()) {
@@ -744,12 +784,17 @@ public class TrainedModelProvider {
                 }
             }
         }
+        expandedIdsFromAliases.addAll(Arrays.asList(tokens));
+
+        // We need to include the translated model alias, and ANY tokens that were not translated
+        String[] tokensForQuery = expandedIdsFromAliases.toArray(new String[0]);
+
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
             .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
                 // If there are no resources, there might be no mapping for the id field.
                 // This makes sure we don't get an error if that happens.
                 .unmappedType("long"))
-            .query(buildExpandIdsQuery(tokens, tags))
+            .query(buildExpandIdsQuery(tokensForQuery, tags))
             // We "buffer" the from and size to take into account models stored as resources.
             // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of
             // a page.
@@ -785,9 +830,28 @@ public class TrainedModelProvider {
                             foundFromDocs.add(idValue.toString());
                         }
                     }
-                    Set<String> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs);
+                    Map<String, Set<String>> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs)
+                        .stream()
+                        .collect(Collectors.toMap(Function.identity(), k -> new HashSet<>()));
+
+                    // We technically have matched on model tokens and any reversed referenced aliases
+                    // We may end up with "over matching" on the aliases (matching on an alias that was not provided)
+                    // But the expanded ID matcher does not care.
+                    Set<String> matchedTokens = new HashSet<>(allFoundIds.keySet());
+
+                    // We should gather ALL model aliases referenced by the given model IDs
+                    // This way the callers have access to them
+                    modelAliasMetadata.modelAliases().forEach((alias, modelIdEntry) -> {
+                        final String modelId = modelIdEntry.getModelId();
+                        if (allFoundIds.containsKey(modelId)) {
+                            allFoundIds.get(modelId).add(alias);
+                            matchedTokens.add(alias);
+                        }
+                    });
+
+                    // Reverse lookup to see what model aliases were matched by their found trained model IDs
                     ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
-                    requiredMatches.filterMatchedIds(allFoundIds);
+                    requiredMatches.filterMatchedIds(matchedTokens);
                     if (requiredMatches.hasUnmatchedIds()) {
                         idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
                     } else {

+ 57 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java

@@ -0,0 +1,57 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import static java.util.Collections.singletonList;
+import static org.elasticsearch.rest.RestRequest.Method.PUT;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+/**
+ * REST handler that turns a {@code PUT} on the trained model alias endpoint into a
+ * {@link PutTrainedModelAliasAction} request.
+ */
+public class RestPutTrainedModelAliasAction extends BaseRestHandler {
+
+    @Override
+    public List<Route> routes() {
+        // model id and model alias are both path parameters
+        final String path = MachineLearning.BASE_PATH
+            + "trained_models/{"
+            + TrainedModelConfig.MODEL_ID.getPreferredName()
+            + "}/model_aliases/{"
+            + PutTrainedModelAliasAction.Request.MODEL_ALIAS
+            + "}";
+        return singletonList(new Route(PUT, path));
+    }
+
+    @Override
+    public String getName() {
+        return "ml_put_trained_model_alias_action";
+    }
+
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+        // Parameters are read here, before the returned consumer runs; reassign falls back to false when absent
+        final String alias = restRequest.param(PutTrainedModelAliasAction.Request.MODEL_ALIAS);
+        final String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
+        final boolean reassignAlias = restRequest.paramAsBoolean(PutTrainedModelAliasAction.Request.REASSIGN, false);
+        return channel -> {
+            PutTrainedModelAliasAction.Request request = new PutTrainedModelAliasAction.Request(alias, id, reassignAlias);
+            client.execute(PutTrainedModelAliasAction.INSTANCE, request, new RestToXContentListener<>(channel));
+        };
+    }
+}

+ 3 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java

@@ -35,6 +35,7 @@ import org.elasticsearch.plugins.IngestPlugin;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.junit.Before;
 
@@ -129,7 +130,6 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
             null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client);
     }
 
-
     public void testInferenceIngestStatsByModelId() {
         List<NodeStats> nodeStatsList = Arrays.asList(
             buildNodeStats(
@@ -198,6 +198,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
             put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2")));
         }};
         Map<String, IngestStats> ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(response,
+            ModelAliasMetadata.EMPTY,
             pipelineIdsByModelIds);
 
         assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2"))));
@@ -238,7 +239,7 @@ public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
         ClusterState clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3);
 
         Map<String, Set<String>> pipelineIdsByModelIds =
-            TransportGetTrainedModelsStatsAction.pipelineIdsByModelIds(clusterState, ingestService, modelIds);
+            TransportGetTrainedModelsStatsAction.pipelineIdsByModelIdsOrAliases(clusterState, ingestService, modelIds);
 
         assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds));
         assertThat(pipelineIdsByModelIds,

+ 44 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

@@ -74,6 +74,7 @@ public class InferenceProcessorTests extends ESTestCase {
                 ClassificationConfig.EMPTY_PARAMS,
                 1.0,
                 1.0)),
+            null,
             true);
         inferenceProcessor.mutateDocument(response, document);
 
@@ -110,6 +111,7 @@ public class InferenceProcessorTests extends ESTestCase {
                 classificationConfig,
                 0.6,
                 0.6)),
+            null,
             true);
         inferenceProcessor.mutateDocument(response, document);
 
@@ -152,6 +154,7 @@ public class InferenceProcessorTests extends ESTestCase {
                 classificationConfig,
                 0.6,
                 0.6)),
+            null,
             true);
         inferenceProcessor.mutateDocument(response, document);
 
@@ -193,6 +196,7 @@ public class InferenceProcessorTests extends ESTestCase {
                 classificationConfig,
                 0.6,
                 0.6)),
+            null,
             true);
         inferenceProcessor.mutateDocument(response, document);
 
@@ -218,14 +222,16 @@ public class InferenceProcessorTests extends ESTestCase {
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
-            Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig)), true);
+            Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig)),
+            null,
+            true);
         inferenceProcessor.mutateDocument(response, document);
 
         assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
     }
 
-    public void testMutateDocumentRegressionWithTopFetures() {
+    public void testMutateDocumentRegressionWithTopFeatures() {
         RegressionConfig regressionConfig = new RegressionConfig("foo", 2);
         RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2);
         InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
@@ -245,7 +251,9 @@ public class InferenceProcessorTests extends ESTestCase {
         featureInfluence.add(new RegressionFeatureImportance("feature_2", -42.0));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
-            Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
+            Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)),
+            null,
+            true);
         inferenceProcessor.mutateDocument(response, document);
 
         assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
@@ -383,7 +391,9 @@ public class InferenceProcessorTests extends ESTestCase {
         assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(false));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
-            Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), true);
+            Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)),
+            null,
+            true);
         inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
             assertThat(doc, is(not(nullValue())));
             assertThat(ex, is(nullValue()));
@@ -392,7 +402,9 @@ public class InferenceProcessorTests extends ESTestCase {
         assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true));
 
         response = new InternalInferModelAction.Response(
-            Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), false);
+            Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)),
+            null,
+            false);
 
         inferenceProcessor.handleResponse(response, document, (doc, ex) -> {
             assertThat(doc, is(not(nullValue())));
@@ -424,11 +436,37 @@ public class InferenceProcessorTests extends ESTestCase {
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
-            Collections.singletonList(new WarningInferenceResults("something broke")), true);
+            Collections.singletonList(new WarningInferenceResults("something broke")), null, true);
         inferenceProcessor.mutateDocument(response, document);
 
         assertThat(document.hasField(targetField), is(false));
         assertThat(document.hasField("ml.warning"), is(true));
         assertThat(document.hasField("ml.my_processor"), is(false));
     }
+
+    /**
+     * The processor is configured with a model alias; the {@code model_id} written to the document
+     * must be the model id carried by the response.
+     */
+    public void testMutateDocumentWithModelIdResult() {
+        final String configuredAlias = "special_model";
+        final String resolvedModelId = "regression-123";
+        InferenceProcessor processor = new InferenceProcessor(client, auditor, "my_processor", null, "ml.my_processor",
+            configuredAlias, new RegressionConfigUpdate("foo", null), Collections.emptyMap());
+
+        Map<String, Object> source = new HashMap<>();
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        RegressionInferenceResults result = new RegressionInferenceResults(0.7, new RegressionConfig("foo"));
+        InternalInferModelAction.Response response =
+            new InternalInferModelAction.Response(Collections.singletonList(result), resolvedModelId, true);
+        processor.mutateDocument(response, document);
+
+        assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
+        assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo(resolvedModelId));
+    }
 }

+ 140 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.ByteSizeValue;
@@ -44,6 +45,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@@ -59,10 +61,12 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
 import static org.hamcrest.Matchers.equalTo;
@@ -282,7 +286,6 @@ public class ModelLoadingServiceTests extends ESTestCase {
         verify(trainedModelProvider, times(5)).getTrainedModelForInference(eq(model3), any());
     }
 
-
     public void testWhenCacheEnabledButNotIngestNode() throws Exception {
         String model1 = "test-uncached-not-ingest-model-1";
         withTrainedModel(model1, 1L);
@@ -538,6 +541,101 @@ public class ModelLoadingServiceTests extends ESTestCase {
         assertEquals(1, model.getReferenceCount());
     }
 
+    /**
+     * A model can be fetched for a pipeline by its id and by its model_alias with a single provider load,
+     * and after the alias is re-pointed both the new model id and the alias are reported as cached,
+     * again with a single provider load of the new model.
+     */
+    public void testGetCachedModelViaModelAliases() throws Exception {
+        final String modelAlias = "loaded_model";
+        String model1 = "test-load-model-1";
+        String model2 = "test-load-model-2";
+        withTrainedModel(model1, 1L);
+        withTrainedModel(model2, 1L);
+
+        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, auditor, threadPool,
+            clusterService, trainedModelStatsService, Settings.EMPTY, "test-node", circuitBreaker);
+
+        // ingest metadata and alias metadata are both flagged as changed; the alias points at model1
+        modelLoadingService.clusterChanged(
+            aliasChangeEvent(true, new String[]{modelAlias}, true, Arrays.asList(Tuple.tuple(model1, modelAlias)))
+        );
+
+        for (int attempt = 0; attempt < 10; attempt++) {
+            // even attempts ask by model id, odd attempts ask by alias
+            String idOrAlias = attempt % 2 == 0 ? model1 : modelAlias;
+            PlainActionFuture<LocalModel> future = new PlainActionFuture<>();
+            modelLoadingService.getModelForPipeline(idOrAlias, future);
+            assertThat(future.get(), is(not(nullValue())));
+        }
+
+        verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any());
+
+        assertTrue(modelLoadingService.isModelCached(model1));
+        assertTrue(modelLoadingService.isModelCached(modelAlias));
+
+        // alias change only: the alias is re-pointed at model2
+        modelLoadingService.clusterChanged(
+            aliasChangeEvent(true, new String[]{modelAlias}, false, Arrays.asList(Tuple.tuple(model2, modelAlias)))
+        );
+
+        for (int attempt = 0; attempt < 10; attempt++) {
+            String idOrAlias = attempt % 2 == 0 ? model2 : modelAlias;
+            PlainActionFuture<LocalModel> future = new PlainActionFuture<>();
+            modelLoadingService.getModelForPipeline(idOrAlias, future);
+            assertThat(future.get(), is(not(nullValue())));
+        }
+
+        verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any());
+        assertTrue(modelLoadingService.isModelCached(model2));
+        assertTrue(modelLoadingService.isModelCached(modelAlias));
+    }
+
+    /**
+     * With events built for a non-ingest node and no referenced models, alias changes must still be
+     * reflected by {@code getModelId}, both for newly added aliases and for a re-pointed one.
+     */
+    public void testAliasesGetUpdatedEvenWhenNotIngestNode() throws IOException {
+        final String modelAlias = "loaded_model";
+        String model1 = "test-load-model-1";
+        String model2 = "test-load-model-2";
+        withTrainedModel(model1, 1L);
+        withTrainedModel(model2, 1L);
+
+        ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, auditor, threadPool,
+            clusterService, trainedModelStatsService, Settings.EMPTY, "test-node", circuitBreaker);
+
+        // first event: a single alias pointing at model1
+        List<Tuple<String, String>> initialAliases = Arrays.asList(Tuple.tuple(model1, modelAlias));
+        modelLoadingService.clusterChanged(aliasChangeEvent(false, new String[0], false, initialAliases));
+
+        assertThat(modelLoadingService.getModelId(modelAlias), equalTo(model1));
+
+        // second event: two more aliases for model1, and the original alias now points at model2
+        List<Tuple<String, String>> updatedAliases = Arrays.asList(
+            Tuple.tuple(model1, "loaded_model_again"),
+            Tuple.tuple(model1, "loaded_model_foo"),
+            Tuple.tuple(model2, modelAlias)
+        );
+        modelLoadingService.clusterChanged(aliasChangeEvent(false, new String[0], false, updatedAliases));
+
+        assertThat(modelLoadingService.getModelId(modelAlias), equalTo(model2));
+        assertThat(modelLoadingService.getModelId("loaded_model_foo"), equalTo(model1));
+        assertThat(modelLoadingService.getModelId("loaded_model_again"), equalTo(model1));
+    }
+
     @SuppressWarnings("unchecked")
     private void withTrainedModel(String modelId, long size) {
         InferenceDefinition definition = mock(InferenceDefinition.class);
@@ -601,6 +699,21 @@ public class ModelLoadingServiceTests extends ESTestCase {
         return ingestChangedEvent(true, modelId);
     }
 
+    /**
+     * Creates a mocked {@link ClusterChangedEvent} that reports the model alias metadata as changed.
+     *
+     * @param isIngestNode      whether the local node in the resulting cluster state has the ingest role
+     * @param modelId           model ids that get an inference pipeline in the ingest metadata
+     * @param ingestToo         when {@code true}, the ingest metadata is reported as changed as well
+     * @param modelIdAndAliases tuples of (model id, model alias) placed in the model alias metadata
+     * @return the mocked event, stubbed for {@code changedCustomMetadataSet()} and {@code state()}
+     */
+    private static ClusterChangedEvent aliasChangeEvent(boolean isIngestNode,
+                                                        String[] modelId,
+                                                        boolean ingestToo,
+                                                        List<Tuple<String, String>> modelIdAndAliases) throws IOException {
+        Set<String> changedCustoms = new HashSet<>();
+        changedCustoms.add(ModelAliasMetadata.NAME);
+        if (ingestToo) {
+            changedCustoms.add(IngestMetadata.TYPE);
+        }
+
+        ClusterChangedEvent changeEvent = mock(ClusterChangedEvent.class);
+        when(changeEvent.changedCustomMetadataSet()).thenReturn(changedCustoms);
+        ClusterState newState = withModelReferencesAndAliasChange(isIngestNode, modelId, modelIdAndAliases);
+        when(changeEvent.state()).thenReturn(newState);
+        return changeEvent;
+    }
+
     private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, String... modelId) throws IOException {
         ClusterChangedEvent event = mock(ClusterChangedEvent.class);
         when(event.changedCustomMetadataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE));
@@ -609,14 +722,17 @@ public class ModelLoadingServiceTests extends ESTestCase {
     }
 
     private static ClusterState buildClusterStateWithModelReferences(boolean isIngestNode, String... modelId) throws IOException {
-        Map<String, PipelineConfiguration> configurations = new HashMap<>(modelId.length);
-        for (String id : modelId) {
-            configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id));
-        }
-        IngestMetadata ingestMetadata = new IngestMetadata(configurations);
+        return builder(isIngestNode).metadata(addIngest(Metadata.builder(), modelId)).build();
+    }
+
+    /**
+     * Builds a cluster state whose metadata carries both the ingest pipelines referencing the given
+     * model ids and the given (model id, model alias) assignments.
+     */
+    private static ClusterState withModelReferencesAndAliasChange(boolean isIngestNode,
+                                                                  String[] modelId,
+                                                                  List<Tuple<String, String>> modelIdAndAliases) throws IOException {
+        ClusterState.Builder stateBuilder = builder(isIngestNode);
+        Metadata.Builder withPipelines = addIngest(Metadata.builder(), modelId);
+        Metadata.Builder withAliases = addAliases(withPipelines, modelIdAndAliases);
+        return stateBuilder.metadata(withAliases).build();
+    }
 
+    private static ClusterState.Builder builder(boolean isIngestNode) {
         return ClusterState.builder(new ClusterName("_name"))
-            .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, ingestMetadata))
             .nodes(DiscoveryNodes.builder().add(
                 new DiscoveryNode("node_name",
                     "node_id",
@@ -625,8 +741,23 @@ public class ModelLoadingServiceTests extends ESTestCase {
                     isIngestNode ? Collections.singleton(DiscoveryNodeRole.INGEST_ROLE) : Collections.emptySet(),
                     Version.CURRENT))
                 .localNodeId("node_id")
-                .build())
-            .build();
+                .build()
+            );
+    }
+
+    /**
+     * Registers ingest metadata on the given builder containing one pipeline per model id, keyed
+     * {@code "pipeline_with_model_" + id}, each configured via
+     * {@code newConfigurationWithInferenceProcessor}.
+     *
+     * @return the result of {@code putCustom} on the supplied builder
+     */
+    private static Metadata.Builder addIngest(Metadata.Builder builder, String... modelId) throws IOException {
+        Map<String, PipelineConfiguration> pipelinesByName = new HashMap<>(modelId.length);
+        for (int i = 0; i < modelId.length; i++) {
+            String referencedModel = modelId[i];
+            String pipelineName = "pipeline_with_model_" + referencedModel;
+            pipelinesByName.put(pipelineName, newConfigurationWithInferenceProcessor(referencedModel));
+        }
+        return builder.putCustom(IngestMetadata.TYPE, new IngestMetadata(pipelinesByName));
+    }
+
+    /**
+     * Registers model alias metadata on the given builder. Each tuple is (model id, model alias);
+     * the alias ({@code v2}) becomes the map key and the model id ({@code v1}) the entry value.
+     * Because {@code Collectors.toMap} is used without a merge function, the same alias appearing
+     * twice results in an {@link IllegalStateException}.
+     */
+    private static Metadata.Builder addAliases(Metadata.Builder builder, List<Tuple<String, String>> modelIdAndAliases) {
+        Map<String, ModelAliasMetadata.ModelAliasEntry> entriesByAlias = modelIdAndAliases.stream()
+            .collect(Collectors.toMap(Tuple::v2, idAndAlias -> new ModelAliasMetadata.ModelAliasEntry(idAndAlias.v1())));
+        return builder.putCustom(ModelAliasMetadata.NAME, new ModelAliasMetadata(entriesByAlias));
     }
 
     private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException {

+ 1 - 0
x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java

@@ -135,6 +135,7 @@ public class Constants {
         "cluster:admin/xpack/ml/filters/update",
         "cluster:admin/xpack/ml/inference/delete",
         "cluster:admin/xpack/ml/inference/put",
+        "cluster:admin/xpack/ml/inference/model_aliases/put",
         "cluster:admin/xpack/ml/job/close",
         "cluster:admin/xpack/ml/job/data/post",
         "cluster:admin/xpack/ml/job/delete",

+ 1 - 1
x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java

@@ -18,12 +18,12 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.plugins.MetadataUpgrader;
 import org.elasticsearch.test.SecuritySettingsSourceField;
-import org.elasticsearch.test.rest.ESRestTestCase;
 import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
 import org.elasticsearch.test.rest.yaml.ClientYamlTestResponse;
 import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
 import org.elasticsearch.xpack.core.ml.MlConfigIndex;
 import org.elasticsearch.xpack.core.ml.MlMetaIndex;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields;

+ 40 - 0
x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json

@@ -0,0 +1,40 @@
+{
+  "ml.put_trained_model_alias":{
+    "documentation":{
+      "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-models-aliases.html",
+      "description":"Creates a new model alias (or reassigns an existing one) to refer to the trained model"
+    },
+    "stability":"beta",
+    "visibility":"public",
+    "headers":{
+      "accept": [ "application/json"],
+      "content_type": ["application/json"]
+    },
+    "url":{
+      "paths":[
+        {
+          "path":"/_ml/trained_models/{model_id}/model_aliases/{model_alias}",
+          "methods":[
+            "PUT"
+          ],
+          "parts":{
+            "model_alias":{
+              "type":"string",
+              "description":"The trained model alias to update"
+            },
+            "model_id": {
+              "type": "string",
+              "description": "The trained model where the model alias should be assigned"
+            }
+          }
+        }
+      ]
+    },
+    "params":{
+      "reassign":{
+        "type":"boolean",
+        "description":"If the model_alias already exists and points to a different model_id, this parameter must be true. Defaults to false."
+      }
+    }
+  }
+}

+ 23 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml

@@ -79,6 +79,12 @@ setup:
                }
             }
           }
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model_alias:
+        model_alias: "my-regression"
+        model_id: "a-unused-regression-model1"
 ---
 "Test get stats given missing trained model":
 
@@ -175,3 +181,20 @@ setup:
   - match: { count: 1 }
   - match: { trained_model_stats.0.model_id: another-regression-model }
   - match: { trained_model_stats.0.pipeline_count: 0 }
+
+
+# test with model alias
+  - do:
+      ml.get_trained_models_stats:
+        model_id: "my-regression"
+
+  - match: { count: 1 }
+  - match: { trained_model_stats.0.model_id: a-unused-regression-model1 }
+
+  - do:
+      ml.get_trained_models_stats:
+        model_id: "my-regression,another-regression-model"
+
+  - match: { count: 2 }
+  - match: { trained_model_stats.0.model_id: a-unused-regression-model1 }
+  - match: { trained_model_stats.1.model_id: another-regression-model }

+ 86 - 24
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -561,30 +561,6 @@ setup:
       ml.delete_trained_model:
         model_id: "missing-trained-model"
 ---
-"Test delete given used trained model":
-  - do:
-      ingest.put_pipeline:
-        id: "regression-model-pipeline"
-        body:  >
-          {
-            "processors": [
-              {
-                "inference" : {
-                  "model_id" : "a-regression-model-0",
-                  "inference_config": {"regression": {}},
-                  "target_field": "regression_field",
-                  "field_map": {}
-                }
-              }
-            ]
-          }
-  - match: { acknowledged: true }
-
-  - do:
-      catch: conflict
-      ml.delete_trained_model:
-        model_id: "a-regression-model-0"
----
 "Test get pre-packaged trained models":
   - do:
       ml.get_trained_models:
@@ -851,3 +827,89 @@ setup:
         model_id: "a-regression-model-1"
         include_model_definition: true
         decompress_definition: false
+---
+"Test put model aliases":
+
+  - do:
+      ml.put_trained_model_alias:
+        model_alias: "regression-model"
+        model_id: "a-regression-model-1"
+  - do:
+      ml.get_trained_models:
+        model_id: "regression-model,a-classification-model"
+
+  - match: { count: 2 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "a-classification-model" }
+  - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
+
+  - do:
+      ml.put_trained_model_alias:
+        model_alias: "regression-model"
+        model_id: "a-regression-model-0"
+        reassign: true
+  - do:
+      ml.get_trained_models:
+        model_id: "regression-model,a-classification-model"
+
+  - match: { count: 2 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "a-classification-model" }
+  - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
+
+  - do:
+      ml.put_trained_model_alias:
+        model_alias: "regression-model-again"
+        model_id: "a-regression-model-0"
+  - do:
+      ml.get_trained_models:
+        model_id: "a-regression-model-*"
+        size: 1
+
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+  - match: { trained_model_configs.0.metadata.model_aliases.0: "regression-model" }
+  - match: { trained_model_configs.0.metadata.model_aliases.1: "regression-model-again" }
+---
+"Test update model alias with model id referring to missing model":
+  - do:
+      catch: missing
+      ml.put_trained_model_alias:
+        model_alias: "regression-model"
+        model_id: "missing-model"
+---
+"Test update model alias with bad alias":
+  - do:
+      catch: /must start with alphanumeric and cannot end with numbers/
+      ml.put_trained_model_alias:
+        model_alias: "regression-model-123123"
+        model_id: "regression-model-123123"
+  - do:
+      catch: bad_request
+      ml.put_trained_model_alias:
+        model_alias: "z-classification-model"
+        model_id: "z-classification-model"
+---
+"Test update model alias where alias exists but old model id is different inference type":
+  - do:
+      ml.put_trained_model_alias:
+        model_alias: "regression-model"
+        model_id: "a-regression-model-0"
+  - do:
+      catch: bad_request
+      ml.put_trained_model_alias:
+        model_alias: "regression-model"
+        model_id: "a-classification-model"
+        reassign: true
+---
+"Test update model alias where alias exists but reassign is false":
+  - do:
+      ml.put_trained_model_alias:
+        model_alias: "regression-model"
+        model_id: "a-regression-model-0"
+  - do:
+      catch: bad_request
+      ml.put_trained_model_alias:
+        model_alias: "regression-model"
+        model_id: "a-regression-model-1"
+        reassign: false