Browse Source

[ML] Add new include flag to GET inference/<model_id> API for model training metadata (#61922)

Adds new flag include to the get trained models API
The flag initially has two valid values: definition, total_feature_importance.
Consequently, the old include_model_definition flag is now deprecated.
When total_feature_importance is included, the total_feature_importance field is included in the model metadata object.
Including definition is the same as previously setting include_model_definition=true.
Benjamin Trent 5 years ago
parent
commit
fdb7b6d3b5
29 changed files with 820 additions and 163 deletions
  1. 3 3
      client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
  2. 26 8
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java
  3. 208 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java
  4. 2 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
  5. 10 3
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  6. 6 5
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  7. 63 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
  8. 11 9
      docs/java-rest/high-level/ml/get-trained-models.asciidoc
  9. 100 18
      docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc
  10. 17 0
      docs/reference/ml/ml-shared.asciidoc
  11. 58 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java
  12. 23 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
  13. 33 21
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java
  14. 4 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java
  15. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
  16. 28 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java
  17. 8 4
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java
  18. 24 6
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
  19. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  20. 18 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java
  21. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
  22. 3 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
  23. 112 34
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
  24. 11 5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java
  25. 9 9
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java
  26. 2 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java
  27. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java
  28. 3 4
      x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json
  29. 34 2
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

+ 3 - 3
client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

@@ -779,9 +779,9 @@ final class MLRequestConverters {
             params.putParam(GetTrainedModelsRequest.DECOMPRESS_DEFINITION,
                 Boolean.toString(getTrainedModelsRequest.getDecompressDefinition()));
         }
-        if (getTrainedModelsRequest.getIncludeDefinition() != null) {
-            params.putParam(GetTrainedModelsRequest.INCLUDE_MODEL_DEFINITION,
-                Boolean.toString(getTrainedModelsRequest.getIncludeDefinition()));
+        if (getTrainedModelsRequest.getIncludes().isEmpty() == false) {
+            params.putParam(GetTrainedModelsRequest.INCLUDE,
+                Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getIncludes()));
         }
         if (getTrainedModelsRequest.getTags() != null) {
             params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));

+ 26 - 8
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java

@@ -26,21 +26,26 @@ import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.common.Nullable;
 
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 
 public class GetTrainedModelsRequest implements Validatable {
 
+    private static final String DEFINITION = "definition";
+    private static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
     public static final String ALLOW_NO_MATCH = "allow_no_match";
-    public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
     public static final String FOR_EXPORT = "for_export";
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
     public static final String TAGS = "tags";
+    public static final String INCLUDE = "include";
 
     private final List<String> ids;
     private Boolean allowNoMatch;
-    private Boolean includeDefinition;
+    private Set<String> includes = new HashSet<>();
     private Boolean decompressDefinition;
     private Boolean forExport;
     private PageParams pageParams;
@@ -86,19 +91,32 @@ public class GetTrainedModelsRequest implements Validatable {
         return this;
     }
 
-    public Boolean getIncludeDefinition() {
-        return includeDefinition;
+    public Set<String> getIncludes() {
+        return Collections.unmodifiableSet(includes);
+    }
+
+    public GetTrainedModelsRequest includeDefinition() {
+        this.includes.add(DEFINITION);
+        return this;
+    }
+
+    public GetTrainedModelsRequest includeTotalFeatureImportance() {
+        this.includes.add(TOTAL_FEATURE_IMPORTANCE);
+        return this;
     }
 
     /**
      * Whether to include the full model definition.
      *
      * The full model definition can be very large.
-     *
+     * @deprecated Use {@link GetTrainedModelsRequest#includeDefinition()}
      * @param includeDefinition If {@code true}, the definition is included.
      */
+    @Deprecated
     public GetTrainedModelsRequest setIncludeDefinition(Boolean includeDefinition) {
-        this.includeDefinition = includeDefinition;
+        if (includeDefinition != null && includeDefinition) {
+            return this.includeDefinition();
+        }
         return this;
     }
 
@@ -173,13 +191,13 @@ public class GetTrainedModelsRequest implements Validatable {
         return Objects.equals(ids, other.ids)
             && Objects.equals(allowNoMatch, other.allowNoMatch)
             && Objects.equals(decompressDefinition, other.decompressDefinition)
-            && Objects.equals(includeDefinition, other.includeDefinition)
+            && Objects.equals(includes, other.includes)
             && Objects.equals(forExport, other.forExport)
             && Objects.equals(pageParams, other.pageParams);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
+        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includes, forExport);
     }
 }

+ 208 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java

@@ -0,0 +1,208 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+public class TotalFeatureImportance implements ToXContentObject {
+
+    private static final String NAME = "total_feature_importance";
+    public static final ParseField FEATURE_NAME = new ParseField("feature_name");
+    public static final ParseField IMPORTANCE = new ParseField("importance");
+    public static final ParseField CLASSES = new ParseField("classes");
+    public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
+    public static final ParseField MIN = new ParseField("min");
+    public static final ParseField MAX = new ParseField("max");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<TotalFeatureImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
+        true,
+        a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
+        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), Importance.PARSER, IMPORTANCE);
+        PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), ClassImportance.PARSER, CLASSES);
+    }
+
+    public static TotalFeatureImportance fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public final String featureName;
+    public final Importance importance;
+    public final List<ClassImportance> classImportances;
+
+    TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
+        this.featureName = featureName;
+        this.importance = importance;
+        this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(FEATURE_NAME.getPreferredName(), featureName);
+        if (importance != null) {
+            builder.field(IMPORTANCE.getPreferredName(), importance);
+        }
+        if (classImportances.isEmpty() == false) {
+            builder.field(CLASSES.getPreferredName(), classImportances);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TotalFeatureImportance that = (TotalFeatureImportance) o;
+        return Objects.equals(that.importance, importance)
+            && Objects.equals(featureName, that.featureName)
+            && Objects.equals(classImportances, that.classImportances);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureName, importance, classImportances);
+    }
+
+    public static class Importance implements ToXContentObject {
+        private static final String NAME = "importance";
+
+        public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
+            true,
+            a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
+
+        static {
+            PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
+            PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
+            PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
+        }
+
+        private final double meanMagnitude;
+        private final double min;
+        private final double max;
+
+        public Importance(double meanMagnitude, double min, double max) {
+            this.meanMagnitude = meanMagnitude;
+            this.min = min;
+            this.max = max;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Importance that = (Importance) o;
+            return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
+                Double.compare(that.min, min) == 0 &&
+                Double.compare(that.max, max) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(meanMagnitude, min, max);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
+            builder.field(MIN.getPreferredName(), min);
+            builder.field(MAX.getPreferredName(), max);
+            builder.endObject();
+            return builder;
+        }
+    }
+
+    public static class ClassImportance implements ToXContentObject {
+        private static final String NAME = "total_class_importance";
+
+        public static final ParseField CLASS_NAME = new ParseField("class_name");
+        public static final ParseField IMPORTANCE = new ParseField("importance");
+
+        public static final ConstructingObjectParser<ClassImportance, Void> PARSER = new ConstructingObjectParser<>(NAME,
+            true,
+            a -> new ClassImportance(a[0], (Importance)a[1]));
+
+        static {
+            PARSER.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
+                if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                    return p.text();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
+                    return p.numberValue();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
+                    return p.booleanValue();
+                }
+                throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
+            }, CLASS_NAME, ObjectParser.ValueType.VALUE);
+            PARSER.declareObject(ConstructingObjectParser.constructorArg(), Importance.PARSER, IMPORTANCE);
+        }
+
+        public static ClassImportance fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        public final Object className;
+        public final Importance importance;
+
+        ClassImportance(Object className, Importance importance) {
+            this.className = className;
+            this.importance = importance;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASS_NAME.getPreferredName(), className);
+            builder.field(IMPORTANCE.getPreferredName(), importance);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ClassImportance that = (ClassImportance) o;
+            return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, importance);
+        }
+
+    }
+}

+ 2 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

@@ -894,7 +894,7 @@ public class MLRequestConvertersTests extends ESTestCase {
         GetTrainedModelsRequest getRequest = new GetTrainedModelsRequest(modelId1, modelId2, modelId3)
             .setAllowNoMatch(false)
             .setDecompressDefinition(true)
-            .setIncludeDefinition(false)
+            .includeDefinition()
             .setTags("tag1", "tag2")
             .setPageParams(new PageParams(100, 300));
 
@@ -908,7 +908,7 @@ public class MLRequestConvertersTests extends ESTestCase {
                 hasEntry("allow_no_match", "false"),
                 hasEntry("decompress_definition", "true"),
                 hasEntry("tags", "tag1,tag2"),
-                hasEntry("include_model_definition", "false")
+                hasEntry("include", "definition")
             ));
         assertNull(request.getEntity());
     }

+ 10 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -2227,7 +2227,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
         {
             GetTrainedModelsResponse getTrainedModelsResponse = execute(
-                new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(true).setIncludeDefinition(true),
+                new GetTrainedModelsRequest(modelIdPrefix + 0)
+                    .setDecompressDefinition(true)
+                    .includeDefinition()
+                    .includeTotalFeatureImportance(),
                 machineLearningClient::getTrainedModels,
                 machineLearningClient::getTrainedModelsAsync);
 
@@ -2238,7 +2241,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
 
             getTrainedModelsResponse = execute(
-                new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(true),
+                new GetTrainedModelsRequest(modelIdPrefix + 0)
+                    .setDecompressDefinition(false)
+                    .includeTotalFeatureImportance()
+                    .includeDefinition(),
                 machineLearningClient::getTrainedModels,
                 machineLearningClient::getTrainedModelsAsync);
 
@@ -2249,7 +2255,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdPrefix + 0));
 
             getTrainedModelsResponse = execute(
-                new GetTrainedModelsRequest(modelIdPrefix + 0).setDecompressDefinition(false).setIncludeDefinition(false),
+                new GetTrainedModelsRequest(modelIdPrefix + 0)
+                    .setDecompressDefinition(false),
                 machineLearningClient::getTrainedModels,
                 machineLearningClient::getTrainedModelsAsync);
             assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));

+ 6 - 5
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -3694,11 +3694,12 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
             // tag::get-trained-models-request
             GetTrainedModelsRequest request = new GetTrainedModelsRequest("my-trained-model") // <1>
                 .setPageParams(new PageParams(0, 1)) // <2>
-                .setIncludeDefinition(false) // <3>
-                .setDecompressDefinition(false) // <4>
-                .setAllowNoMatch(true) // <5>
-                .setTags("regression") // <6>
-                .setForExport(false); // <7>
+                .includeDefinition() // <3>
+                .includeTotalFeatureImportance() // <4>
+                .setDecompressDefinition(false) // <5>
+                .setAllowNoMatch(true) // <6>
+                .setTags("regression") // <7>
+                .setForExport(false); // <8>
             // end::get-trained-models-request
             request.setTags((List<String>)null);
 

+ 63 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java

@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.metadata;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalFeatureImportance> {
+
+
+    public static TotalFeatureImportance randomInstance() {
+        return new TotalFeatureImportance(
+            randomAlphaOfLength(10),
+            randomBoolean() ? null : randomImportance(),
+            randomBoolean() ?
+                null :
+                Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
+                    .limit(randomIntBetween(1, 10))
+                    .collect(Collectors.toList())
+            );
+    }
+
+    private static TotalFeatureImportance.Importance randomImportance() {
+        return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
+    }
+
+    @Override
+    protected TotalFeatureImportance createTestInstance() {
+        return randomInstance();
+    }
+
+    @Override
+    protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
+        return TotalFeatureImportance.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+}

+ 11 - 9
docs/java-rest/high-level/ml/get-trained-models.asciidoc

@@ -22,26 +22,28 @@ IDs, or the special wildcard `_all` to get all trained models.
 --------------------------------------------------
 include-tagged::{doc-tests-file}[{api}-request]
 --------------------------------------------------
-<1> Constructing a new GET request referencing an existing Trained Model
+<1> Constructing a new GET request referencing an existing trained model
 <2> Set the paging parameters
 <3> Indicate if the complete model definition should be included
-<4> Should the definition be fully decompressed on GET
-<5> Allow empty response if no Trained Models match the provided ID patterns.
-    If false, an error will be thrown if no Trained Models match the
+<4> Indicate if the total feature importance for the features used in training
+    should be included in the model `metadata` field.
+<5> Should the definition be fully decompressed on GET
+<6> Allow empty response if no trained models match the provided ID patterns.
+    If false, an error will be thrown if no trained models match the
     ID patterns.
-<6> An optional list of tags used to narrow the model search. A Trained Model
+<7> An optional list of tags used to narrow the model search. A trained model
     can have many tags or none. The trained models in the response will
     contain all the provided tags.
-<7> Optional boolean value indicating if certain fields should be removed on
-    retrieval. This is useful for getting the trained model in a format that
-    can then be put into another cluster.
+<8> Optional boolean value for requesting the trained model in a format that can
+    then be put into another cluster. Certain fields that can only be set when
+    the model is imported are removed.
 
 include::../execution.asciidoc[]
 
 [id="{upid}-{api}-response"]
 ==== Response
 
-The returned +{response}+ contains the requested Trained Model.
+The returned +{response}+ contains the requested trained model.
 
 ["source","java",subs="attributes,callouts,macros"]
 --------------------------------------------------

+ 100 - 18
docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc

@@ -29,19 +29,19 @@ experimental[]
 [[ml-get-inference-prereq]]
 == {api-prereq-title}
 
-If the {es} {security-features} are enabled, you must have the following 
+If the {es} {security-features} are enabled, you must have the following
 privileges:
 
 * cluster: `monitor_ml`
-  
-For more information, see <<security-privileges>> and 
+
+For more information, see <<security-privileges>> and
 {ml-docs-setup-privileges}.
 
 
 [[ml-get-inference-desc]]
 == {api-description-title}
 
-You can get information for multiple trained models in a single API request by 
+You can get information for multiple trained models in a single API request by
 using a comma-separated list of model IDs or a wildcard expression.
 
 
@@ -49,7 +49,7 @@ using a comma-separated list of model IDs or a wildcard expression.
 == {api-path-parms-title}
 
 `<model_id>`::
-(Optional, string) 
+(Optional, string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
 
@@ -57,12 +57,12 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 == {api-query-parms-title}
 
 `allow_no_match`::
-(Optional, boolean) 
+(Optional, boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match-models]
 
 `decompress_definition`::
 (Optional, boolean)
-Specifies whether the included model definition should be returned as a JSON map 
+Specifies whether the included model definition should be returned as a JSON map
 (`true`) or in a custom compressed format (`false`). Defaults to `true`.
 
 `for_export`::
@@ -72,17 +72,21 @@ retrieval. This allows the model to be in an acceptable format to be retrieved
 and then added to another cluster. Default is false.
 
 `from`::
-(Optional, integer) 
+(Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=from-models]
 
-`include_model_definition`::
-(Optional, boolean)
-Specifies whether the model definition is returned in the response. Defaults to 
-`false`. When `true`, only a single model must match the ID patterns provided. 
-Otherwise, a bad request is returned.
+`include`::
+(Optional, string)
+A comma delimited string of optional fields to include in the response body.
+Valid options are:
+ - `definition`: Includes the model definition
+ - `total_feature_importance`: Includes the total feature importance for the
+   training data set. This field is available in the `metadata` field in the
+   response body.
+Default is empty, indicating including no optional fields.
 
 `size`::
-(Optional, integer) 
+(Optional, integer)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=size-models]
 
 `tags`::
@@ -95,7 +99,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=tags]
 
 `trained_model_configs`::
 (array)
-An array of trained model resources, which are sorted by the `model_id` value in 
+An array of trained model resources, which are sorted by the `model_id` value in
 ascending order.
 +
 .Properties of trained model resources
@@ -133,8 +137,86 @@ The license level of the trained model.
 
 `metadata`:::
 (object)
-An object containing metadata about the trained model. For example, models 
+An object containing metadata about the trained model. For example, models
 created by {dfanalytics} contain `analysis_config` and `input` objects.
+.Properties of metadata
+[%collapsible%open]
+=====
+`total_feature_importance`:::
+(array)
+An array of the total feature importance for each feature used from
+the training data set. This array of objects is returned if {dfanalytics} trained
+the model and the request includes `total_feature_importance` in the `include`
+request parameter.
++
+.Properties of total feature importance
+[%collapsible%open]
+======
+
+`feature_name`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-feature-name]
+
+`importance`:::
+(object)
+A collection of feature importance statistics related to the training data set for this particular feature.
++
+.Properties of feature importance
+[%collapsible%open]
+=======
+`mean_magnitude`:::
+(double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
+
+`max`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
+
+`min`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
+
+=======
+
+`classes`:::
+(array)
+If the trained model is a classification model, feature importance statistics are gathered
+per target class value.
++
+.Properties of class feature importance
+[%collapsible%open]
+
+=======
+
+`class_name`:::
+(string)
+The target class value. Could be a string, boolean, or number.
+
+`importance`:::
+(object)
+A collection of feature importance statistics related to the training data set for this particular feature.
++
+.Properties of feature importance
+[%collapsible%open]
+========
+`mean_magnitude`:::
+(double)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-magnitude]
+
+`max`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-max]
+
+`min`:::
+(int)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-metadata-feature-importance-min]
+
+========
+
+=======
+
+======
+=====
 
 `model_id`:::
 (string)
@@ -154,13 +236,13 @@ The {es} version number in which the trained model was created.
 == {api-response-codes-title}
 
 `400`::
-  If `include_model_definition` is `true`, this code indicates that more than 
+  If `include_model_definition` is `true`, this code indicates that more than
   one models match the ID pattern.
 
 `404` (Missing resources)::
   If `allow_no_match` is `false`, this code indicates that there are no
   resources that match the request or only partial matches for the request.
-  
+
 
 [[ml-get-inference-example]]
 == {api-examples-title}

+ 17 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -785,6 +785,23 @@ prediction. Defaults to the `results_field` value of the {dfanalytics-job} that
 used to train the model, which defaults to `<dependent_variable>_prediction`.
 end::inference-config-results-field-processor[]
 
+tag::inference-metadata-feature-importance-feature-name[]
+The training feature name for which this importance was calculated.
+end::inference-metadata-feature-importance-feature-name[]
+tag::inference-metadata-feature-importance-magnitude[]
+The average magnitude of this feature across all the training data.
+This value is the average of the absolute values of the importance
+for this feature.
+end::inference-metadata-feature-importance-magnitude[]
+tag::inference-metadata-feature-importance-max[]
+The maximum importance value across all the training data for this
+feature.
+end::inference-metadata-feature-importance-max[]
+tag::inference-metadata-feature-importance-min[]
+The minimum importance value across all the training data for this
+feature.
+end::inference-metadata-feature-importance-min[]
+
 tag::influencers[]
 A comma separated list of influencer field names. Typically these can be the by,
 over, or partition fields that are used in the detector configuration. You might

+ 58 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java

@@ -5,19 +5,24 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
 import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
+import java.util.Set;
 
 
 public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
@@ -31,23 +36,60 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
 
     public static class Request extends AbstractGetResourcesRequest {
 
-        public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
+        static final String DEFINITION = "definition";
+        static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
+        private static final Set<String> KNOWN_INCLUDES;
+        static {
+            HashSet<String> includes = new HashSet<>(2, 1.0f);
+            includes.add(DEFINITION);
+            includes.add(TOTAL_FEATURE_IMPORTANCE);
+            KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
+        }
+        public static final ParseField INCLUDE = new ParseField("include");
+        public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
         public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
         public static final ParseField TAGS = new ParseField("tags");
 
-        private final boolean includeModelDefinition;
+        private final Set<String> includes;
         private final List<String> tags;
 
+        @Deprecated
         public Request(String id, boolean includeModelDefinition, List<String> tags) {
             setResourceId(id);
             setAllowNoResources(true);
-            this.includeModelDefinition = includeModelDefinition;
             this.tags = tags == null ? Collections.emptyList() : tags;
+            if (includeModelDefinition) {
+                this.includes = new HashSet<>(Collections.singletonList(DEFINITION));
+            } else {
+                this.includes = Collections.emptySet();
+            }
+        }
+
+        public Request(String id, List<String> tags, Set<String> includes) {
+            setResourceId(id);
+            setAllowNoResources(true);
+            this.tags = tags == null ? Collections.emptyList() : tags;
+            this.includes = includes == null ? Collections.emptySet() : includes;
+            Set<String> unknownIncludes = Sets.difference(this.includes, KNOWN_INCLUDES);
+            if (unknownIncludes.isEmpty() == false) {
+                throw ExceptionsHelper.badRequestException(
+                    "unknown [include] parameters {}. Valid options are {}",
+                    unknownIncludes,
+                    KNOWN_INCLUDES);
+            }
         }
 
         public Request(StreamInput in) throws IOException {
             super(in);
-            this.includeModelDefinition = in.readBoolean();
+            if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
+                this.includes = in.readSet(StreamInput::readString);
+            } else {
+                Set<String> includes = new HashSet<>();
+                if (in.readBoolean()) {
+                    includes.add(DEFINITION);
+                }
+                this.includes = includes;
+            }
             this.tags = in.readStringList();
         }
 
@@ -57,7 +99,11 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         }
 
         public boolean isIncludeModelDefinition() {
-            return includeModelDefinition;
+            return this.includes.contains(DEFINITION);
+        }
+
+        public boolean isIncludeTotalFeatureImportance() {
+            return this.includes.contains(TOTAL_FEATURE_IMPORTANCE);
         }
 
         public List<String> getTags() {
@@ -67,13 +113,17 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
-            out.writeBoolean(includeModelDefinition);
+            if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
+                out.writeCollection(this.includes, StreamOutput::writeString);
+            } else {
+                out.writeBoolean(this.includes.contains(DEFINITION));
+            }
             out.writeStringCollection(tags);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(super.hashCode(), includeModelDefinition, tags);
+            return Objects.hash(super.hashCode(), includes, tags);
         }
 
         @Override
@@ -85,7 +135,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
                 return false;
             }
             Request other = (Request) obj;
-            return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition && Objects.equals(tags, other.tags);
+            return super.equals(obj) && this.includes.equals(other.includes) && Objects.equals(tags, other.tags);
         }
     }
 

+ 23 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;
@@ -39,6 +40,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.action.ValidateActions.addValidationError;
@@ -51,6 +53,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
     public static final String FOR_EXPORT = "for_export";
+    public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
+    private static final Set<String> RESERVED_METADATA_FIELDS = Collections.singleton(TOTAL_FEATURE_IMPORTANCE);
 
     private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
 
@@ -408,7 +412,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
             this.description = config.getDescription();
             this.tags = config.getTags();
-            this.metadata = config.getMetadata();
+            this.metadata = config.getMetadata() == null ? null : new HashMap<>(config.getMetadata());
             this.input = config.getInput();
             this.estimatedOperations = config.estimatedOperations;
             this.estimatedHeapMemory = config.estimatedHeapMemory;
@@ -460,6 +464,18 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return this;
         }
 
+        public Builder setFeatureImportance(List<TotalFeatureImportance> totalFeatureImportance) {
+            if (totalFeatureImportance == null) {
+                return this;
+            }
+            if (this.metadata == null) {
+                this.metadata = new HashMap<>();
+            }
+            this.metadata.put(TOTAL_FEATURE_IMPORTANCE,
+                totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()));
+            return this;
+        }
+
         public Builder setParsedDefinition(TrainedModelDefinition.Builder definition) {
             if (definition == null) {
                 return this;
@@ -616,6 +632,12 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                     ESTIMATED_OPERATIONS.getPreferredName(),
                     validationException);
                 validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
+                if (metadata != null) {
+                    validationException = checkIllegalSetting(
+                        metadata.get(TOTAL_FEATURE_IMPORTANCE),
+                        METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE,
+                        validationException);
+                }
             }
 
             if (validationException != null) {

+ 33 - 21
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java

@@ -20,8 +20,11 @@ import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
 
 public class TotalFeatureImportance implements ToXContentObject, Writeable {
 
@@ -81,16 +84,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        builder.field(FEATURE_NAME.getPreferredName(), featureName);
-        if (importance != null) {
-            builder.field(IMPORTANCE.getPreferredName(), importance);
-        }
-        if (classImportances.isEmpty() == false) {
-            builder.field(CLASSES.getPreferredName(), classImportances);
-        }
-        builder.endObject();
-        return builder;
+        return builder.map(asMap());
     }
 
     @Override
@@ -103,6 +97,18 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             && Objects.equals(classImportances, that.classImportances);
     }
 
+    public Map<String, Object> asMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(FEATURE_NAME.getPreferredName(), featureName);
+        if (importance != null) {
+            map.put(IMPORTANCE.getPreferredName(), importance.asMap());
+        }
+        if (classImportances.isEmpty() == false) {
+            map.put(CLASSES.getPreferredName(), classImportances.stream().map(ClassImportance::asMap).collect(Collectors.toList()));
+        }
+        return map;
+    }
+
     @Override
     public int hashCode() {
         return Objects.hash(featureName, importance, classImportances);
@@ -165,12 +171,15 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
-            builder.field(MIN.getPreferredName(), min);
-            builder.field(MAX.getPreferredName(), max);
-            builder.endObject();
-            return builder;
+            return builder.map(asMap());
+        }
+
+        private Map<String, Object> asMap() {
+            Map<String, Object> map = new LinkedHashMap<>();
+            map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
+            map.put(MIN.getPreferredName(), min);
+            map.put(MAX.getPreferredName(), max);
+            return map;
         }
     }
 
@@ -229,11 +238,14 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(CLASS_NAME.getPreferredName(), className);
-            builder.field(IMPORTANCE.getPreferredName(), importance);
-            builder.endObject();
-            return builder;
+            return builder.map(asMap());
+        }
+
+        private Map<String, Object> asMap() {
+            Map<String, Object> map = new LinkedHashMap<>();
+            map.put(CLASS_NAME.getPreferredName(), className);
+            map.put(IMPORTANCE.getPreferredName(), importance.asMap());
+            return map;
         }
 
         @Override

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java

@@ -53,6 +53,10 @@ public class TrainedModelMetadata implements ToXContentObject, Writeable {
         return NAME + "-" + modelId;
     }
 
+    public static String modelId(String docId) {
+        return docId.substring(NAME.length() + 1);
+    }
+
     private final List<TotalFeatureImportance> totalFeatureImportances;
     private final String modelId;
 

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -103,7 +103,7 @@ public final class Messages {
     public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
         "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
     public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
-    public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]";
+    public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata {0}";
     public static final String INFERENCE_CANNOT_DELETE_MODEL =
         "Unable to delete model [{0}]";
     public static final String MODEL_DEFINITION_TRUNCATED =

+ 28 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java

@@ -5,19 +5,28 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.action.util.PageParams;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
 
-public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class GetTrainedModelsRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
 
     @Override
     protected Request createTestInstance() {
         Request request = new Request(randomAlphaOfLength(20),
-            randomBoolean(),
             randomBoolean() ? null :
-            randomList(10, () -> randomAlphaOfLength(10)));
+            randomList(10, () -> randomAlphaOfLength(10)),
+            randomBoolean() ? null :
+                Stream.generate(() -> randomFrom(Request.DEFINITION, Request.TOTAL_FEATURE_IMPORTANCE))
+                    .limit(4)
+                    .collect(Collectors.toSet()));
         request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
         return request;
     }
@@ -26,4 +35,19 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
     protected Writeable.Reader<Request> instanceReader() {
         return Request::new;
     }
+
+    @Override
+    protected Request mutateInstanceForVersion(Request instance, Version version) {
+        if (version.before(Version.V_7_10_0)) {
+            Set<String> includes = new HashSet<>();
+            if (instance.isIncludeModelDefinition()) {
+                includes.add(Request.DEFINITION);
+            }
+            return new Request(
+                instance.getResourceId(),
+                instance.getTags(),
+                includes);
+        }
+        return instance;
+    }
 }

+ 8 - 4
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java

@@ -42,11 +42,13 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.startsWith;
 
 public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
@@ -95,19 +97,21 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
         trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
         Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
         assertThat(ids.v1(), equalTo(1L));
+        String inferenceModelId = ids.v2().iterator().next();
 
         PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
-        trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);
+        trainedModelProvider.getTrainedModel(inferenceModelId, true, true, getTrainedModelFuture);
 
         TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
         assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
         assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
         assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
+        assertThat(storedConfig.getMetadata(), hasKey("total_feature_importance"));
 
-        PlainActionFuture<TrainedModelMetadata> getTrainedMetadataFuture = new PlainActionFuture<>();
-        trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture);
+        PlainActionFuture<Map<String, TrainedModelMetadata>> getTrainedMetadataFuture = new PlainActionFuture<>();
+        trainedModelProvider.getTrainedModelMetadata(Collections.singletonList(inferenceModelId), getTrainedMetadataFuture);
 
-        TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet();
+        TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet().get(inferenceModelId);
         assertThat(storedMetadata.getModelId(), startsWith(modelId));
         assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances()));
     }

+ 24 - 6
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java

@@ -89,7 +89,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         assertThat(exceptionHolder.get(), is(nullValue()));
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
         assertThat(getConfigHolder.get(), is(not(nullValue())));
         assertThat(getConfigHolder.get(), equalTo(config));
@@ -120,7 +123,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         assertThat(exceptionHolder.get(), is(nullValue()));
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder);
+        blockingCall(listener ->
+            trainedModelProvider.getTrainedModel(modelId, false, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
         assertThat(getConfigHolder.get(), is(not(nullValue())));
         assertThat(getConfigHolder.get(), equalTo(copyWithoutDefinition));
@@ -131,7 +137,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         String modelId = "test-get-missing-trained-model-config";
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(),
             equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
@@ -153,7 +162,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
             .actionGet();
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(),
             equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
@@ -192,7 +204,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         }
 
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(getConfigHolder.get(), is(nullValue()));
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
@@ -237,7 +252,10 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
             }
         }
         AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
-        blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
+        blockingCall(
+            listener -> trainedModelProvider.getTrainedModel(modelId, true, false, listener),
+            getConfigHolder,
+            exceptionHolder);
         assertThat(getConfigHolder.get(), is(nullValue()));
         assertThat(exceptionHolder.get(), is(not(nullValue())));
         assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -946,7 +946,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                 new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
                 new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class),
                 new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class),
-                usageAction,
+            usageAction,
                 infoAction);
     }
 

+ 18 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java

@@ -57,15 +57,25 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
                 }
 
                 if (request.isIncludeModelDefinition()) {
-                    provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap(
-                        config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
-                        listener::onFailure
-                    ));
+                    provider.getTrainedModel(
+                        totalAndIds.v2().iterator().next(),
+                        true,
+                        request.isIncludeTotalFeatureImportance(),
+                        ActionListener.wrap(
+                            config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
+                            listener::onFailure
+                        )
+                    );
                 } else {
-                    provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap(
-                        configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
-                        listener::onFailure
-                    ));
+                    provider.getTrainedModels(
+                        totalAndIds.v2(),
+                        request.isAllowNoResources(),
+                        request.isIncludeTotalFeatureImportance(),
+                        ActionListener.wrap(
+                            configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
+                            listener::onFailure
+                        )
+                    );
                 }
             },
             listener::onFailure

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -82,7 +82,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
             responseBuilder.setLicensed(true);
             this.modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
         } else {
-            trainedModelProvider.getTrainedModel(request.getModelId(), false, ActionListener.wrap(
+            trainedModelProvider.getTrainedModel(request.getModelId(), false, false, ActionListener.wrap(
                 trainedModelConfig -> {
                     responseBuilder.setLicensed(licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
                     if (licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {

+ 3 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

@@ -270,7 +270,7 @@ public class ModelLoadingService implements ClusterStateListener {
     }
 
     private void loadModel(String modelId, Consumer consumer) {
-        provider.getTrainedModel(modelId, false, ActionListener.wrap(
+        provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
             trainedModelConfig -> {
                 trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
                 provider.getTrainedModelForInference(modelId, ActionListener.wrap(
@@ -306,7 +306,7 @@ public class ModelLoadingService implements ClusterStateListener {
         // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called
         // by a simulated pipeline
         logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId));
-        provider.getTrainedModel(modelId, false, ActionListener.wrap(
+        provider.getTrainedModel(modelId, false, false, ActionListener.wrap(
             trainedModelConfig -> {
                 // Verify we can pull the model into memory without causing OOM
                 trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
@@ -434,7 +434,7 @@ public class ModelLoadingService implements ClusterStateListener {
 
             logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]",
                 notification.getValue().model.getModelId()));
-            
+
             // If the model is no longer referenced, flush the stats to persist as soon as possible
             notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false);
         } finally {

+ 112 - 34
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -88,9 +88,11 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.TreeSet;
 import java.util.stream.Collectors;
@@ -234,14 +236,14 @@ public class TrainedModelProvider {
             ));
     }
 
-    public void getTrainedModelMetadata(String modelId, ActionListener<TrainedModelMetadata> listener) {
+    public void getTrainedModelMetadata(Collection<String> modelIds, ActionListener<Map<String, TrainedModelMetadata>> listener) {
         SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
             .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
                 .boolQuery()
-                .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
+                .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
                 .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
                     TrainedModelMetadata.NAME))))
-            .setSize(1)
+            .setSize(10_000)
             // First find the latest index
             .addSort("_index", SortOrder.DESC)
             .request();
@@ -249,18 +251,20 @@ public class TrainedModelProvider {
             searchResponse -> {
                 if (searchResponse.getHits().getHits().length == 0) {
                     listener.onFailure(new ResourceNotFoundException(
-                        Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
+                        Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
                     return;
                 }
-                List<TrainedModelMetadata> metadataList = handleHits(searchResponse.getHits().getHits(),
-                    modelId,
-                    this::parseMetadataLenientlyFromSource);
-                listener.onResponse(metadataList.get(0));
+                HashMap<String, TrainedModelMetadata> map = new HashMap<>();
+                for (SearchHit hit : searchResponse.getHits().getHits()) {
+                    String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId()));
+                    map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId));
+                }
+                listener.onResponse(map);
             },
             e -> {
                 if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                     listener.onFailure(new ResourceNotFoundException(
-                        Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
+                        Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
                     return;
                 }
                 listener.onFailure(e);
@@ -370,7 +374,7 @@ public class TrainedModelProvider {
         // TODO Change this when we get more than just langIdent stored
         if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
             try {
-                TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
+                TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
                 assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
                 listener.onResponse(
                     InferenceDefinition.builder()
@@ -433,18 +437,50 @@ public class TrainedModelProvider {
         ));
     }
 
-    public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
+    public void getTrainedModel(final String modelId,
+                                final boolean includeDefinition,
+                                final boolean includeTotalFeatureImportance,
+                                final ActionListener<TrainedModelConfig> finalListener) {
 
         if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
             try {
-                listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
+                finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
                 return;
             } catch (ElasticsearchException ex) {
-                listener.onFailure(ex);
+                finalListener.onFailure(ex);
                 return;
             }
         }
 
+        ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
+            modelBuilder -> {
+                if (includeTotalFeatureImportance == false) {
+                    finalListener.onResponse(modelBuilder.build());
+                    return;
+                }
+                this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap(
+                    metadata -> {
+                        TrainedModelMetadata modelMetadata = metadata.get(modelId);
+                        if (modelMetadata != null) {
+                            modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
+                        }
+                        finalListener.onResponse(modelBuilder.build());
+                    },
+                    failure -> {
+                        // total feature importance is not necessary for a model to be valid
+                        // we shouldn't fail if it is not found
+                        if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
+                            finalListener.onResponse(modelBuilder.build());
+                            return;
+                        }
+                        finalListener.onFailure(failure);
+                    }
+                ));
+
+            },
+            finalListener::onFailure
+        );
+
         QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
             .idsQuery()
             .addIds(modelId));
@@ -482,11 +518,11 @@ public class TrainedModelProvider {
                 try {
                     builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
                 } catch (ResourceNotFoundException ex) {
-                    listener.onFailure(new ResourceNotFoundException(
+                    getTrainedModelListener.onFailure(new ResourceNotFoundException(
                         Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
                     return;
                 } catch (Exception ex) {
-                    listener.onFailure(ex);
+                    getTrainedModelListener.onFailure(ex);
                     return;
                 }
 
@@ -499,22 +535,22 @@ public class TrainedModelProvider {
                             String compressedString = getDefinitionFromDocs(docs, modelId);
                             builder.setDefinitionFromString(compressedString);
                         } catch (ElasticsearchException elasticsearchException) {
-                            listener.onFailure(elasticsearchException);
+                            getTrainedModelListener.onFailure(elasticsearchException);
                             return;
                         }
 
                     } catch (ResourceNotFoundException ex) {
-                        listener.onFailure(new ResourceNotFoundException(
+                        getTrainedModelListener.onFailure(new ResourceNotFoundException(
                             Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
                         return;
                     } catch (Exception ex) {
-                        listener.onFailure(ex);
+                        getTrainedModelListener.onFailure(ex);
                         return;
                     }
                 }
-                listener.onResponse(builder.build());
+                getTrainedModelListener.onResponse(builder);
             },
-            listener::onFailure
+            getTrainedModelListener::onFailure
         );
 
         executeAsyncWithOrigin(client,
@@ -531,7 +567,10 @@ public class TrainedModelProvider {
      * This does no expansion on the ids.
      * It assumes that there are fewer than 10k.
      */
-    public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) {
+    public void getTrainedModels(Set<String> modelIds,
+                                 boolean allowNoResources,
+                                 boolean includeTotalFeatureImportance,
+                                 final ActionListener<List<TrainedModelConfig>> finalListener) {
         QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
 
         SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
@@ -540,23 +579,63 @@ public class TrainedModelProvider {
             .setQuery(queryBuilder)
             .setSize(modelIds.size())
             .request();
-        List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
+        List<TrainedModelConfig.Builder> configs = new ArrayList<>(modelIds.size());
         Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
         Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
         for(String modelId : modelsAsResource) {
             try {
                 configs.add(loadModelFromResource(modelId, true));
             } catch (ElasticsearchException ex) {
-                listener.onFailure(ex);
+                finalListener.onFailure(ex);
                 return;
             }
         }
         if (modelsInIndex.isEmpty()) {
-            configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
-            listener.onResponse(configs);
+            finalListener.onResponse(configs.stream()
+                .map(TrainedModelConfig.Builder::build)
+                .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
+                .collect(Collectors.toList()));
             return;
         }
 
+        ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
+            modelBuilders -> {
+                if (includeTotalFeatureImportance == false) {
+                    finalListener.onResponse(modelBuilders.stream()
+                        .map(TrainedModelConfig.Builder::build)
+                        .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
+                        .collect(Collectors.toList()));
+                    return;
+                }
+                this.getTrainedModelMetadata(modelIds, ActionListener.wrap(
+                    metadata ->
+                        finalListener.onResponse(modelBuilders.stream()
+                            .map(builder -> {
+                                TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
+                                if (modelMetadata != null) {
+                                    builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
+                                }
+                                return builder.build();
+                            })
+                            .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
+                            .collect(Collectors.toList())),
+                    failure -> {
+                        // total feature importance is not necessary for a model to be valid
+                        // we shouldn't fail if it is not found
+                        if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
+                            finalListener.onResponse(modelBuilders.stream()
+                                .map(TrainedModelConfig.Builder::build)
+                                .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
+                                .collect(Collectors.toList()));
+                            return;
+                        }
+                        finalListener.onFailure(failure);
+                    }
+                ));
+            },
+            finalListener::onFailure
+        );
+
         ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
             searchResponse -> {
                 Set<String> observedIds = new HashSet<>(
@@ -567,12 +646,12 @@ public class TrainedModelProvider {
                     try {
                         if (observedIds.contains(searchHit.getId()) == false) {
                             configs.add(
-                                parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build()
+                                parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId())
                             );
                             observedIds.add(searchHit.getId());
                         }
                     } catch (IOException ex) {
-                        listener.onFailure(
+                        getTrainedModelListener.onFailure(
                             ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
                         return;
                     }
@@ -582,14 +661,13 @@ public class TrainedModelProvider {
                 // Otherwise, treat it as if it was never expanded to begin with.
                 Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
                 if (missingConfigs.isEmpty() == false && allowNoResources == false) {
-                    listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
+                    getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
                     return;
                 }
                 // Ensure sorted even with the injection of locally resourced models
-                configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
-                listener.onResponse(configs);
+                getTrainedModelListener.onResponse(configs);
             },
-            listener::onFailure
+            getTrainedModelListener::onFailure
         );
 
         executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
@@ -638,7 +716,7 @@ public class TrainedModelProvider {
             foundResourceIds = new HashSet<>();
             for(String resourceId : matchedResourceIds) {
                 // Does the model as a resource have all the tags?
-                if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
+                if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) {
                     foundResourceIds.add(resourceId);
                 }
             }
@@ -832,7 +910,7 @@ public class TrainedModelProvider {
         return QueryBuilders.constantScoreQuery(boolQueryBuilder);
     }
 
-    TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
+    TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) {
         URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
         if (resource == null) {
             logger.error("[{}] presumed stored as a resource but not found", modelId);
@@ -847,7 +925,7 @@ public class TrainedModelProvider {
             if (nullOutDefinition) {
                 builder.clearDefinition();
             }
-            return builder.build();
+            return builder;
         } catch (IOException ioEx) {
             logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
             throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);

+ 11 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java

@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.ml.MachineLearning;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -55,12 +56,17 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
         if (Strings.isNullOrEmpty(modelId)) {
             modelId = Metadata.ALL;
         }
-        boolean includeModelDefinition = restRequest.paramAsBoolean(
-            GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
-            false
-        );
         List<String> tags = asList(restRequest.paramAsStringArray(TrainedModelConfig.TAGS.getPreferredName(), Strings.EMPTY_ARRAY));
-        GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition, tags);
+        Set<String> includes = new HashSet<>(
+            asList(
+                restRequest.paramAsStringArray(
+                    GetTrainedModelsAction.Request.INCLUDE.getPreferredName(),
+                    Strings.EMPTY_ARRAY)));
+        final GetTrainedModelsAction.Request request = restRequest.hasParam(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION) ?
+            new GetTrainedModelsAction.Request(modelId,
+                restRequest.paramAsBoolean(GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION, false),
+                tags) :
+            new GetTrainedModelsAction.Request(modelId, tags, includes);
         if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
             request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
                 restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));

+ 9 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java

@@ -437,9 +437,9 @@ public class ModelLoadingServiceTests extends ESTestCase {
         // the loading occurred or which models are currently in the cache due to evictions.
         // Verify that we have at least loaded all three
         assertBusy(() -> {
-            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), any());
-            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), any());
-            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), any());
+            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(false), eq(false), any());
+            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(false), eq(false), any());
+            verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(false), eq(false), any());
         });
         assertBusy(() -> {
             assertThat(circuitBreaker.getUsed(), equalTo(10L));
@@ -553,10 +553,10 @@ public class ModelLoadingServiceTests extends ESTestCase {
         }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any());
         doAnswer(invocationOnMock -> {
             @SuppressWarnings("rawtypes")
-            ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
+            ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
             listener.onResponse(trainedModelConfig);
             return null;
-        }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
+        }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
     }
 
     @SuppressWarnings("unchecked")
@@ -564,20 +564,20 @@ public class ModelLoadingServiceTests extends ESTestCase {
         if (randomBoolean()) {
             doAnswer(invocationOnMock -> {
                 @SuppressWarnings("rawtypes")
-                ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
+                ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
                 listener.onFailure(new ResourceNotFoundException(
                     Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
                 return null;
-            }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
+            }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
         } else {
             TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
             when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(0L);
             doAnswer(invocationOnMock -> {
                 @SuppressWarnings("rawtypes")
-                ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
+                ActionListener listener = (ActionListener) invocationOnMock.getArguments()[3];
                 listener.onResponse(trainedModelConfig);
                 return null;
-            }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), any());
+            }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(false), eq(false), any());
             doAnswer(invocationOnMock -> {
                 @SuppressWarnings("rawtypes")
                 ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1];

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java

@@ -57,14 +57,14 @@ public class TrainedModelProviderTests extends ESTestCase {
         TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
         for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
             PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
-            trainedModelProvider.getTrainedModel(modelId, true, future);
+            trainedModelProvider.getTrainedModel(modelId, true, false, future);
             TrainedModelConfig configWithDefinition = future.actionGet();
 
             assertThat(configWithDefinition.getModelId(), equalTo(modelId));
             assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
 
             PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
-            trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition);
+            trainedModelProvider.getTrainedModel(modelId, false, false, futureNoDefinition);
             TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
 
             assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java

@@ -33,7 +33,7 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
         TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
         PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
         // Should be OK as we don't make any client calls
-        trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future);
+        trainedModelProvider.getTrainedModel("lang_ident_model_1", true, false, future);
         TrainedModelConfig config = future.actionGet();
 
         config.ensureParsedDefinition(xContentRegistry());

+ 3 - 4
x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json

@@ -34,11 +34,10 @@
         "description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
         "default":true
       },
-      "include_model_definition":{
-        "type":"boolean",
+      "include":{
+        "type":"string",
         "required":false,
-        "description":"Should the full model definition be included in the results. These definitions can be large. So be cautious when including them. Defaults to false.",
-        "default":false
+        "description":"A comma-separate list of fields to optionally include. Valid options are 'definition' and 'total_feature_importance'. Default is none."
       },
       "decompress_definition":{
         "type":"boolean",

+ 34 - 2
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -1,6 +1,24 @@
 setup:
   - skip:
-      features: headers
+      features:
+        - headers
+        - allowed_warnings
+  - do:
+      allowed_warnings:
+        - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template"
+      headers:
+        Content-Type: application/json
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      index:
+        id: trained_model_metadata-a-regression-model-0
+        index: .ml-inference-000003
+        body:
+          model_id: "a-regression-model-0"
+          doc_type: "trained_model_metadata"
+          total_feature_importance:
+            - { feature_name: "foo", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
+            - { feature_name: "bar", importance: { mean_magnitude: 6.0,min: -3.0,max: 3.0 }}
+
   - do:
       headers:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
@@ -548,6 +566,20 @@ setup:
   - match: { count: 12 }
   - match: { trained_model_configs.0.model_id: "a-regression-model-1" }
 ---
+"Test get models with include total feature importance":
+  - do:
+      ml.get_trained_models:
+        model_id: "a-regression-model-*"
+        include: "total_feature_importance"
+  - match: { count: 2 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+  - is_true: trained_model_configs.0.metadata.total_feature_importance
+  - length: { trained_model_configs.0.metadata.total_feature_importance: 2 }
+  - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
+  - is_false: trained_model_configs.1.metadata.total_feature_importance
+
+---
 "Test delete given unused trained model":
   - do:
       ml.delete_trained_model:
@@ -824,7 +856,7 @@ setup:
       ml.get_trained_models:
         model_id: "a-regression-model-1"
         for_export: true
-        include_model_definition: true
+        include: "definition"
         decompress_definition: false
 
   - match: { trained_model_configs.0.description: "empty model for tests" }