
[ML] handle new model metadata stream from native process (#59725)

This adds the serialization handling for the new model_metadata object from the native process.
Benjamin Trent 5 years ago
parent
commit
b99234beec
18 changed files with 834 additions and 34 deletions
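
For context, the model_metadata object mentioned in the commit message is a result document streamed by the native analytics process. The shape below is a sketch reconstructed from the parsers added in this change (AnalyticsResult, ModelMetadata, TotalFeatureImportance); the feature name, class name, and numeric values are invented for illustration and are not part of the commit.

    {
      "model_metadata": {
        "total_feature_importance": [
          {
            "feature_name": "example_feature",
            "importance": { "mean_magnitude": 2.4, "min": -3.1, "max": 5.7 },
            "classes": [
              { "class_name": "example_class",
                "importance": { "mean_magnitude": 2.4, "min": -3.1, "max": 5.7 } }
            ]
          }
        ]
      }
    }

When such a result arrives, AnalyticsResultProcessor passes it to ChunkedTrainedModelPersister.createAndIndexInferenceModelMetadata, which attaches the current model_id and indexes it as a trained_model_metadata document in the new .ml-inference-000003 index.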
  1. +3 -1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java
  2. +242 -0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java
  3. +112 -0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java
  4. +3 -0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
  5. +45 -1
      x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json
  6. +68 -0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java
  7. +58 -0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java
  8. +0 -2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java
  9. +0 -2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java
  10. +18 -1
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java
  11. +6 -1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java
  12. +77 -12
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java
  13. +25 -4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java
  14. +65 -0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/ModelMetadata.java
  15. +75 -2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
  16. +25 -3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java
  17. +8 -1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java
  18. +4 -4
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java

@@ -16,8 +16,10 @@ public final class InferenceIndexConstants {
      * version: 7.8.0:
      *  - adds inference_config definition to trained model config
      *
+     * version: 7.10.0: 000003
+     *  - adds trained_model_metadata object
      */
-    public static final String INDEX_VERSION = "000002";
+    public static final String INDEX_VERSION = "000003";
     public static final String INDEX_NAME_PREFIX = ".ml-inference-";
     public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*";
     public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION;

+ 242 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java

@@ -0,0 +1,242 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+public class TotalFeatureImportance implements ToXContentObject, Writeable {
+
+    private static final String NAME = "total_feature_importance";
+    public static final ParseField FEATURE_NAME = new ParseField("feature_name");
+    public static final ParseField IMPORTANCE = new ParseField("importance");
+    public static final ParseField CLASSES = new ParseField("classes");
+    public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
+    public static final ParseField MIN = new ParseField("min");
+    public static final ParseField MAX = new ParseField("max");
+
+    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
+    public static final ConstructingObjectParser<TotalFeatureImportance, Void> LENIENT_PARSER = createParser(true);
+    public static final ConstructingObjectParser<TotalFeatureImportance, Void> STRICT_PARSER = createParser(false);
+
+    @SuppressWarnings("unchecked")
+    private static ConstructingObjectParser<TotalFeatureImportance, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<TotalFeatureImportance, Void> parser = new ConstructingObjectParser<>(NAME,
+            ignoreUnknownFields,
+            a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List<ClassImportance>)a[2]));
+        parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
+        parser.declareObject(ConstructingObjectParser.optionalConstructorArg(),
+            ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
+            IMPORTANCE);
+        parser.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(),
+            ignoreUnknownFields ? ClassImportance.LENIENT_PARSER : ClassImportance.STRICT_PARSER,
+            CLASSES);
+        return parser;
+    }
+
+    public static TotalFeatureImportance fromXContent(XContentParser parser, boolean lenient) throws IOException {
+        return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
+    }
+
+    public final String featureName;
+    public final Importance importance;
+    public final List<ClassImportance> classImportances;
+
+    public TotalFeatureImportance(StreamInput in) throws IOException {
+        this.featureName = in.readString();
+        this.importance = in.readOptionalWriteable(Importance::new);
+        this.classImportances = in.readList(ClassImportance::new);
+    }
+
+    TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List<ClassImportance> classImportances) {
+        this.featureName = featureName;
+        this.importance = importance;
+        this.classImportances = classImportances == null ? Collections.emptyList() : classImportances;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(featureName);
+        out.writeOptionalWriteable(importance);
+        out.writeList(classImportances);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(FEATURE_NAME.getPreferredName(), featureName);
+        if (importance != null) {
+            builder.field(IMPORTANCE.getPreferredName(), importance);
+        }
+        if (classImportances.isEmpty() == false) {
+            builder.field(CLASSES.getPreferredName(), classImportances);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TotalFeatureImportance that = (TotalFeatureImportance) o;
+        return Objects.equals(that.importance, importance)
+            && Objects.equals(featureName, that.featureName)
+            && Objects.equals(classImportances, that.classImportances);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureName, importance, classImportances);
+    }
+
+    public static class Importance implements ToXContentObject, Writeable {
+        private static final String NAME = "importance";
+
+        // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
+        public static final ConstructingObjectParser<Importance, Void> LENIENT_PARSER = createParser(true);
+        public static final ConstructingObjectParser<Importance, Void> STRICT_PARSER = createParser(false);
+
+        private static ConstructingObjectParser<Importance, Void> createParser(boolean ignoreUnknownFields) {
+            ConstructingObjectParser<Importance, Void> parser = new ConstructingObjectParser<>(NAME,
+                ignoreUnknownFields,
+                a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
+            parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
+            parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
+            parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
+            return parser;
+        }
+
+        private final double meanMagnitude;
+        private final double min;
+        private final double max;
+
+        public Importance(double meanMagnitude, double min, double max) {
+            this.meanMagnitude = meanMagnitude;
+            this.min = min;
+            this.max = max;
+        }
+
+        public Importance(StreamInput in) throws IOException {
+            this.meanMagnitude = in.readDouble();
+            this.min = in.readDouble();
+            this.max = in.readDouble();
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Importance that = (Importance) o;
+            return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
+                Double.compare(that.min, min) == 0 &&
+                Double.compare(that.max, max) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(meanMagnitude, min, max);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeDouble(meanMagnitude);
+            out.writeDouble(min);
+            out.writeDouble(max);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
+            builder.field(MIN.getPreferredName(), min);
+            builder.field(MAX.getPreferredName(), max);
+            builder.endObject();
+            return builder;
+        }
+    }
+
+    public static class ClassImportance implements ToXContentObject, Writeable {
+        private static final String NAME = "total_class_importance";
+
+        public static final ParseField CLASS_NAME = new ParseField("class_name");
+        public static final ParseField IMPORTANCE = new ParseField("importance");
+
+        // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
+        public static final ConstructingObjectParser<ClassImportance, Void> LENIENT_PARSER = createParser(true);
+        public static final ConstructingObjectParser<ClassImportance, Void> STRICT_PARSER = createParser(false);
+
+        private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
+            ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
+                ignoreUnknownFields,
+                a -> new ClassImportance((String)a[0], (Importance)a[1]));
+            parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME);
+            parser.declareObject(ConstructingObjectParser.constructorArg(),
+                ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
+                IMPORTANCE);
+            return parser;
+        }
+
+        public static ClassImportance fromXContent(XContentParser parser, boolean lenient) throws IOException {
+            return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
+        }
+
+        public final String className;
+        public final Importance importance;
+
+        public ClassImportance(StreamInput in) throws IOException {
+            this.className = in.readString();
+            this.importance = new Importance(in);
+        }
+
+        ClassImportance(String className, Importance importance) {
+            this.className = className;
+            this.importance = importance;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(className);
+            importance.writeTo(out);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASS_NAME.getPreferredName(), className);
+            builder.field(IMPORTANCE.getPreferredName(), importance);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ClassImportance that = (ClassImportance) o;
+            return Objects.equals(that.importance, importance) && Objects.equals(className, that.className);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, importance);
+        }
+
+    }
+}

+ 112 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java

@@ -0,0 +1,112 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+public class TrainedModelMetadata implements ToXContentObject, Writeable {
+
+    public static final String NAME = "trained_model_metadata";
+    public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance");
+    public static final ParseField MODEL_ID = new ParseField("model_id");
+
+    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
+    public static final ConstructingObjectParser<TrainedModelMetadata, Void> LENIENT_PARSER = createParser(true);
+    public static final ConstructingObjectParser<TrainedModelMetadata, Void> STRICT_PARSER = createParser(false);
+
+    @SuppressWarnings("unchecked")
+    private static ConstructingObjectParser<TrainedModelMetadata, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<TrainedModelMetadata, Void> parser = new ConstructingObjectParser<>(NAME,
+            ignoreUnknownFields,
+            a -> new TrainedModelMetadata((String)a[0], (List<TotalFeatureImportance>)a[1]));
+        parser.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
+        parser.declareObjectArray(ConstructingObjectParser.constructorArg(),
+            ignoreUnknownFields ? TotalFeatureImportance.LENIENT_PARSER : TotalFeatureImportance.STRICT_PARSER,
+            TOTAL_FEATURE_IMPORTANCE);
+        return parser;
+    }
+
+    public static TrainedModelMetadata fromXContent(XContentParser parser, boolean lenient) throws IOException {
+        return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
+    }
+
+    public static String docId(String modelId) {
+        return NAME + "-" + modelId;
+    }
+
+    private final List<TotalFeatureImportance> totalFeatureImportances;
+    private final String modelId;
+
+    public TrainedModelMetadata(StreamInput in) throws IOException {
+        this.modelId = in.readString();
+        this.totalFeatureImportances = in.readList(TotalFeatureImportance::new);
+    }
+
+    public TrainedModelMetadata(String modelId, List<TotalFeatureImportance> totalFeatureImportances) {
+        this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
+        this.totalFeatureImportances = Collections.unmodifiableList(totalFeatureImportances);
+    }
+
+    public String getModelId() {
+        return modelId;
+    }
+
+    public String getDocId() {
+        return docId(modelId);
+    }
+
+    public List<TotalFeatureImportance> getTotalFeatureImportances() {
+        return totalFeatureImportances;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TrainedModelMetadata that = (TrainedModelMetadata) o;
+        return Objects.equals(totalFeatureImportances, that.totalFeatureImportances) &&
+            Objects.equals(modelId, that.modelId);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(totalFeatureImportances, modelId);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        out.writeList(totalFeatureImportances);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
+            builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
+        }
+        builder.field(MODEL_ID.getPreferredName(), modelId);
+        builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), totalFeatureImportances);
+        builder.endObject();
+        return builder;
+    }
+}

+ 3 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -92,12 +92,15 @@ public final class Messages {
 
     public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
     public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
+    public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists";
     public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
+    public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]";
     public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
     public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
     public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
         "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
     public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
+    public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]";
     public static final String INFERENCE_CANNOT_DELETE_MODEL =
         "Unable to delete model [{0}]";
     public static final String MODEL_DEFINITION_TRUNCATED =

+ 45 - 1
x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json

@@ -2,7 +2,7 @@
   "order" : 0,
   "version" : ${xpack.ml.version.id},
   "index_patterns" : [
-    ".ml-inference-000002"
+    ".ml-inference-000003"
   ],
   "settings" : {
     "index" : {
@@ -70,6 +70,50 @@
         },
         "inference_config": {
           "enabled": false
+        },
+        "total_feature_importance": {
+          "type": "nested",
+          "dynamic": "false",
+          "properties": {
+            "importance": {
+              "properties": {
+                "min": {
+                  "type": "double"
+                },
+                "max": {
+                  "type": "double"
+                },
+                "mean_magnitude": {
+                  "type": "double"
+                }
+              }
+            },
+            "feature_name": {
+              "type": "keyword"
+            },
+            "classes": {
+              "type": "nested",
+              "dynamic": "false",
+              "properties": {
+                "importance": {
+                  "properties": {
+                    "min": {
+                      "type": "double"
+                    },
+                    "max": {
+                      "type": "double"
+                    },
+                    "mean_magnitude": {
+                      "type": "double"
+                    }
+                  }
+                },
+                "class_name": {
+                  "type": "keyword"
+                }
+              }
+            }
+          }
         }
       }
     }

+ 68 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java

@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCase<TotalFeatureImportance> {
+
+    private boolean lenient;
+
+    public static TotalFeatureImportance randomInstance() {
+        return new TotalFeatureImportance(
+            randomAlphaOfLength(10),
+            randomBoolean() ? null : randomImportance(),
+            randomBoolean() ?
+                null :
+                Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance()))
+                    .limit(randomIntBetween(1, 10))
+                    .collect(Collectors.toList())
+            );
+    }
+
+    private static TotalFeatureImportance.Importance randomImportance() {
+        return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
+    }
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected TotalFeatureImportance createTestInstance() {
+        return randomInstance();
+    }
+
+    @Override
+    protected Writeable.Reader<TotalFeatureImportance> instanceReader() {
+        return TotalFeatureImportance::new;
+    }
+
+    @Override
+    protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException {
+        return TotalFeatureImportance.fromXContent(parser, lenient);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    @Override
+    protected TotalFeatureImportance mutateInstanceForVersion(TotalFeatureImportance instance, Version version) {
+        return instance;
+    }
+}

+ 58 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase<TrainedModelMetadata> {
+
+    private boolean lenient;
+
+    public static TrainedModelMetadata randomInstance() {
+        return new TrainedModelMetadata(
+            randomAlphaOfLength(10),
+            Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList()));
+    }
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected TrainedModelMetadata createTestInstance() {
+        return randomInstance();
+    }
+
+    @Override
+    protected Writeable.Reader<TrainedModelMetadata> instanceReader() {
+        return TrainedModelMetadata::new;
+    }
+
+    @Override
+    protected TrainedModelMetadata doParseInstance(XContentParser parser) throws IOException {
+        return TrainedModelMetadata.fromXContent(parser, lenient);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    @Override
+    protected TrainedModelMetadata mutateInstanceForVersion(TrainedModelMetadata instance, Version version) {
+        return instance;
+    }
+}

+ 0 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -6,7 +6,6 @@
 package org.elasticsearch.xpack.ml.integration;
 
 import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.apache.lucene.util.LuceneTestCase;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.DocWriteRequest;
@@ -75,7 +74,6 @@ import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
 import static org.hamcrest.Matchers.startsWith;
 
-@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1456")
 public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     private static final String BOOLEAN_FIELD = "boolean-field";

+ 0 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -5,7 +5,6 @@
  */
 package org.elasticsearch.xpack.ml.integration;
 
-import org.apache.lucene.util.LuceneTestCase;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.DocWriteRequest;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
@@ -58,7 +57,6 @@ import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.not;
 
-@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1456")
 public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     private static final String NUMERICAL_FEATURE_FIELD = "feature";

+ 18 - 1
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java

@@ -22,8 +22,11 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
+import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata;
 import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
 import org.elasticsearch.xpack.ml.extractor.DocValueField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
@@ -40,8 +43,11 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.startsWith;
 
 public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
 
@@ -76,10 +82,14 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
 
         //Accuracy for size is not tested here
         ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
-        persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
+        persister.createAndIndexInferenceModelConfig(modelSizeInfo);
         for (int i = 0; i < chunks.size(); i++) {
             persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1)));
         }
+        ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
+            .limit(randomIntBetween(1, 10))
+            .collect(Collectors.toList()));
+        persister.createAndIndexInferenceModelMetadata(modelMetadata);
 
         PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
         trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
@@ -93,6 +103,13 @@ public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
         assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
         assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
         assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
+
+        PlainActionFuture<TrainedModelMetadata> getTrainedMetadataFuture = new PlainActionFuture<>();
+        trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture);
+
+        TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet();
+        assertThat(storedMetadata.getModelId(), startsWith(modelId));
+        assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances()));
     }
 
     private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

@@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
+import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
 import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
 import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
@@ -141,12 +142,16 @@ public class AnalyticsResultProcessor {
         }
         ModelSizeInfo modelSize = result.getModelSizeInfo();
         if (modelSize != null) {
-            latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize);
+            latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelConfig(modelSize);
         }
         TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk();
         if (trainedModelDefinitionChunk != null && isCancelled == false) {
             chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk);
         }
+        ModelMetadata modelMetadata = result.getModelMetadata();
+        if (modelMetadata != null) {
+            chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelMetadata);
+        }
         MemoryUsage memoryUsage = result.getMemoryUsage();
         if (memoryUsage != null) {
             processMemoryUsage(memoryUsage);

+ 77 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java

@@ -23,9 +23,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.security.user.XPackUser;
+import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata;
 import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
@@ -75,7 +77,7 @@ public class ChunkedTrainedModelPersister {
     }
 
     public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
-        if (Strings.isNullOrEmpty(this.currentModelId.get())) {
+        if (readyToStoreNewModel.get()) {
             failureHandler.accept(ExceptionsHelper.serverError(
                 "chunked inference model definition is attempting to be stored before trained model configuration"
             ));
@@ -98,7 +100,7 @@ public class ChunkedTrainedModelPersister {
         }
     }
 
-    public String createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) {
+    public String createAndIndexInferenceModelConfig(ModelSizeInfo inferenceModelSize) {
         if (readyToStoreNewModel.compareAndSet(true, false) == false) {
             failureHandler.accept(ExceptionsHelper.serverError(
                 "new inference model is attempting to be stored before completion previous model storage"
@@ -106,19 +108,41 @@ public class ChunkedTrainedModelPersister {
             return null;
         }
         TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModelSize);
-        CountDownLatch latch = storeTrainedModelMetadata(trainedModelConfig);
+        CountDownLatch latch = storeTrainedModelConfig(trainedModelConfig);
         try {
             if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
-                LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId());
+                LOGGER.error("[{}] Timed out (30s) waiting for inference model config to be stored", analytics.getId());
             }
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
             this.readyToStoreNewModel.set(true);
-            failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored"));
+            failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model config to be stored"));
         }
         return trainedModelConfig.getModelId();
     }
 
+    public void createAndIndexInferenceModelMetadata(ModelMetadata modelMetadata) {
+        if (Strings.isNullOrEmpty(this.currentModelId.get())) {
+            failureHandler.accept(ExceptionsHelper.serverError(
+                "inference model metadata is attempting to be stored before trained model configuration"
+            ));
+            return;
+        }
+        TrainedModelMetadata trainedModelMetadata = new TrainedModelMetadata(this.currentModelId.get(),
+            modelMetadata.getFeatureImportances());
+
+
+        CountDownLatch latch = storeTrainedModelMetadata(trainedModelMetadata);
+        try {
+            if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) {
+                LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId());
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored"));
+        }
+    }
+
     private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc) {
         CountDownLatch latch = new CountDownLatch(1);
 
@@ -154,7 +178,6 @@ public class ChunkedTrainedModelPersister {
                     analytics.getId(),
                     this.currentModelId.get());
                 auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]");
-                this.currentModelId.set("");
                 readyToStoreNewModel.set(true);
                 provider.refreshInferenceIndex(refreshListener);
             },
@@ -171,26 +194,68 @@ public class ChunkedTrainedModelPersister {
         provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, storeListener);
         return latch;
     }
-    private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig) {
+
+    private CountDownLatch storeTrainedModelMetadata(TrainedModelMetadata trainedModelMetadata) {
+        CountDownLatch latch = new CountDownLatch(1);
+
+        // Latch is attached to this action as it is the last one to execute.
+        ActionListener<RefreshResponse> refreshListener = new LatchedActionListener<>(ActionListener.wrap(
+            refreshed -> {
+                if (refreshed != null) {
+                    LOGGER.debug(() -> new ParameterizedMessage(
+                        "[{}] refreshed inference index after model metadata store",
+                        analytics.getId()
+                    ));
+                }
+            },
+            e -> LOGGER.warn(
+                new ParameterizedMessage("[{}] failed to refresh inference index after model metadata store", analytics.getId()),
+                e)
+        ), latch);
+
+        // First, store the model metadata, then refresh the inference index
+        ActionListener<Void> storeListener = ActionListener.wrap(
+            r -> {
+                LOGGER.debug(
+                    "[{}] stored trained model metadata with id [{}]",
+                    analytics.getId(),
+                    this.currentModelId.get());
+                readyToStoreNewModel.set(true);
+                provider.refreshInferenceIndex(refreshListener);
+            },
+            e -> {
+                this.readyToStoreNewModel.set(true);
+                failureHandler.accept(ExceptionsHelper.serverError(
+                    "error storing trained model metadata with id [{}]",
+                    e,
+                    trainedModelMetadata.getModelId()));
+                refreshListener.onResponse(null);
+            }
+        );
+        provider.storeTrainedModelMetadata(trainedModelMetadata, storeListener);
+        return latch;
+    }
+
+    private CountDownLatch storeTrainedModelConfig(TrainedModelConfig trainedModelConfig) {
         CountDownLatch latch = new CountDownLatch(1);
         ActionListener<Boolean> storeListener = ActionListener.wrap(
             aBoolean -> {
                 if (aBoolean == false) {
-                    LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId());
+                    LOGGER.error("[{}] Storing trained model config responded false", analytics.getId());
                     readyToStoreNewModel.set(true);
-                    failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false"));
+                    failureHandler.accept(ExceptionsHelper.serverError("storing trained model config responded false"));
                 } else {
-                    LOGGER.debug("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
+                    LOGGER.debug("[{}] Stored trained model config with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
                 }
             },
             e -> {
                 readyToStoreNewModel.set(true);
-                failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]",
+                failureHandler.accept(ExceptionsHelper.serverError("error storing trained model config with id [{}]",
                     e,
                     trainedModelConfig.getModelId()));
             }
         );
-        provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
+        provider.storeTrainedModelConfig(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
         return latch;
     }
 

+ 25 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java

@@ -33,6 +33,7 @@ public class AnalyticsResult implements ToXContentObject {
     private static final ParseField OUTLIER_DETECTION_STATS = new ParseField("outlier_detection_stats");
     private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats");
     private static final ParseField REGRESSION_STATS = new ParseField("regression_stats");
+    private static final ParseField MODEL_METADATA = new ParseField("model_metadata");
 
     public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
             a -> new AnalyticsResult(
@@ -43,7 +44,8 @@ public class AnalyticsResult implements ToXContentObject {
                 (ClassificationStats) a[4],
                 (RegressionStats) a[5],
                 (ModelSizeInfo) a[6],
-                (TrainedModelDefinitionChunk) a[7]
+                (TrainedModelDefinitionChunk) a[7],
+                (ModelMetadata) a[8]
             ));
 
     static {
@@ -55,6 +57,7 @@ public class AnalyticsResult implements ToXContentObject {
         PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS);
         PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO);
         PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinitionChunk.PARSER, COMPRESSED_INFERENCE_MODEL);
+        PARSER.declareObject(optionalConstructorArg(), ModelMetadata.PARSER, MODEL_METADATA);
     }
 
     private final RowResults rowResults;
@@ -65,6 +68,7 @@ public class AnalyticsResult implements ToXContentObject {
     private final RegressionStats regressionStats;
     private final ModelSizeInfo modelSizeInfo;
     private final TrainedModelDefinitionChunk trainedModelDefinitionChunk;
+    private final ModelMetadata modelMetadata;
 
     private AnalyticsResult(@Nullable RowResults rowResults,
                             @Nullable PhaseProgress phaseProgress,
@@ -73,7 +77,8 @@ public class AnalyticsResult implements ToXContentObject {
                             @Nullable ClassificationStats classificationStats,
                             @Nullable RegressionStats regressionStats,
                             @Nullable ModelSizeInfo modelSizeInfo,
-                            @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) {
+                            @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk,
+                            @Nullable ModelMetadata modelMetadata) {
         this.rowResults = rowResults;
         this.phaseProgress = phaseProgress;
         this.memoryUsage = memoryUsage;
@@ -82,6 +87,7 @@ public class AnalyticsResult implements ToXContentObject {
         this.regressionStats = regressionStats;
         this.modelSizeInfo = modelSizeInfo;
         this.trainedModelDefinitionChunk = trainedModelDefinitionChunk;
+        this.modelMetadata = modelMetadata;
     }
 
     public RowResults getRowResults() {
@@ -116,6 +122,10 @@ public class AnalyticsResult implements ToXContentObject {
         return trainedModelDefinitionChunk;
     }
 
+    public ModelMetadata getModelMetadata() {
+        return modelMetadata;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -143,6 +153,9 @@ public class AnalyticsResult implements ToXContentObject {
         if (trainedModelDefinitionChunk != null) {
             builder.field(COMPRESSED_INFERENCE_MODEL.getPreferredName(), trainedModelDefinitionChunk);
         }
+        if (modelMetadata != null) {
+            builder.field(MODEL_METADATA.getPreferredName(), modelMetadata);
+        }
         builder.endObject();
         return builder;
     }
@@ -164,13 +177,14 @@ public class AnalyticsResult implements ToXContentObject {
             && Objects.equals(classificationStats, that.classificationStats)
             && Objects.equals(modelSizeInfo, that.modelSizeInfo)
             && Objects.equals(trainedModelDefinitionChunk, that.trainedModelDefinitionChunk)
+            && Objects.equals(modelMetadata, that.modelMetadata)
             && Objects.equals(regressionStats, that.regressionStats);
     }
 
     @Override
     public int hashCode() {
         return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats,
-            regressionStats, modelSizeInfo, trainedModelDefinitionChunk);
+            regressionStats, modelSizeInfo, trainedModelDefinitionChunk, modelMetadata);
     }
 
     public static Builder builder() {
@@ -187,6 +201,7 @@ public class AnalyticsResult implements ToXContentObject {
         private RegressionStats regressionStats;
         private ModelSizeInfo modelSizeInfo;
         private TrainedModelDefinitionChunk trainedModelDefinitionChunk;
+        private ModelMetadata modelMetadata;
 
         private Builder() {}
 
@@ -230,6 +245,11 @@ public class AnalyticsResult implements ToXContentObject {
             return this;
         }
 
+        public Builder setModelMetadata(ModelMetadata modelMetadata) {
+            this.modelMetadata = modelMetadata;
+            return this;
+        }
+
         public AnalyticsResult build() {
             return new AnalyticsResult(
                 rowResults,
@@ -239,7 +259,8 @@ public class AnalyticsResult implements ToXContentObject {
                 classificationStats,
                 regressionStats,
                 modelSizeInfo,
-                trainedModelDefinitionChunk
+                trainedModelDefinitionChunk,
+                modelMetadata
             );
         }
     }

+ 65 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/ModelMetadata.java

@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.dataframe.process.results;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class ModelMetadata implements ToXContentObject {
+
+    public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<ModelMetadata, Void> PARSER = new ConstructingObjectParser<>(
+        "trained_model_metadata",
+        a -> new ModelMetadata((List<TotalFeatureImportance>) a[0]));
+
+    static {
+        PARSER.declareObjectArray(constructorArg(), TotalFeatureImportance.STRICT_PARSER, TOTAL_FEATURE_IMPORTANCE);
+    }
+
+    private final List<TotalFeatureImportance> featureImportances;
+
+    public ModelMetadata(List<TotalFeatureImportance> featureImportances) {
+        this.featureImportances = featureImportances;
+    }
+
+    public List<TotalFeatureImportance> getFeatureImportances() {
+        return featureImportances;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ModelMetadata that = (ModelMetadata) o;
+        return Objects.equals(featureImportances, that.featureImportances);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureImportances);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), featureImportances);
+        builder.endObject();
+        return builder;
+    }
+
+}

+ 75 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -74,6 +74,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
@@ -146,8 +147,7 @@ public class TrainedModelProvider {
         storeTrainedModelAndDefinition(trainedModelConfig, listener);
     }
 
-    public void storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig,
-                                          ActionListener<Boolean> listener) {
+    public void storeTrainedModelConfig(TrainedModelConfig trainedModelConfig, ActionListener<Boolean> listener) {
         if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) {
             listener.onFailure(new ResourceAlreadyExistsException(
                 Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
@@ -206,6 +206,68 @@ public class TrainedModelProvider {
             ));
     }
 
+    public void storeTrainedModelMetadata(TrainedModelMetadata trainedModelMetadata, ActionListener<Void> listener) {
+        if (MODELS_STORED_AS_RESOURCE.contains(trainedModelMetadata.getModelId())) {
+            listener.onFailure(new ResourceAlreadyExistsException(
+                Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelMetadata.getModelId())));
+            return;
+        }
+        executeAsyncWithOrigin(client,
+            ML_ORIGIN,
+            IndexAction.INSTANCE,
+            createRequest(trainedModelMetadata.getDocId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelMetadata),
+            ActionListener.wrap(
+                indexResponse -> listener.onResponse(null),
+                e -> {
+                    if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) {
+                        listener.onFailure(new ResourceAlreadyExistsException(
+                            Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_METADATA_EXISTS,
+                                trainedModelMetadata.getModelId())));
+                    } else {
+                        listener.onFailure(
+                            new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL_METADATA,
+                                RestStatus.INTERNAL_SERVER_ERROR,
+                                e,
+                                trainedModelMetadata.getModelId()));
+                    }
+                }
+            ));
+    }
+
+    public void getTrainedModelMetadata(String modelId, ActionListener<TrainedModelMetadata> listener) {
+        SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
+            .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
+                .boolQuery()
+                .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
+                .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
+                    TrainedModelMetadata.NAME))))
+            .setSize(1)
+            // First find the latest index
+            .addSort("_index", SortOrder.DESC)
+            .request();
+        executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(
+            searchResponse -> {
+                if (searchResponse.getHits().getHits().length == 0) {
+                    listener.onFailure(new ResourceNotFoundException(
+                        Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
+                    return;
+                }
+                List<TrainedModelMetadata> metadataList = handleHits(searchResponse.getHits().getHits(),
+                    modelId,
+                    this::parseMetadataLenientlyFromSource);
+                listener.onResponse(metadataList.get(0));
+            },
+            e -> {
+                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                    listener.onFailure(new ResourceNotFoundException(
+                        Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
+                    return;
+                }
+                listener.onFailure(e);
+            }
+        ));
+    }
+
     public void refreshInferenceIndex(ActionListener<RefreshResponse> listener) {
         executeAsyncWithOrigin(client,
             ML_ORIGIN,
@@ -927,6 +989,17 @@ public class TrainedModelProvider {
         }
     }
 
+    private TrainedModelMetadata parseMetadataLenientlyFromSource(BytesReference source, String modelId) throws IOException {
+        try (InputStream stream = source.streamInput();
+             XContentParser parser = XContentFactory.xContent(XContentType.JSON)
+                 .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
+            return TrainedModelMetadata.fromXContent(parser, true);
+        } catch (IOException e) {
+            logger.error(new ParameterizedMessage("[{}] failed to parse model metadata", modelId), e);
+            throw e;
+        }
+    }
+
     private IndexRequest createRequest(String docId, String index, ToXContentObject body) {
         return createRequest(new IndexRequest(index), docId, body);
     }

+ 25 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java

@@ -18,7 +18,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata;
 import org.elasticsearch.xpack.core.security.user.XPackUser;
+import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata;
 import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
 import org.elasticsearch.xpack.ml.extractor.DocValueField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedField;
@@ -35,6 +38,8 @@ import org.mockito.Mockito;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
@@ -78,7 +83,7 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
             ActionListener<Boolean> storeListener = (ActionListener<Boolean>) invocationOnMock.getArguments()[1];
             storeListener.onResponse(true);
             return null;
-        }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class));
+        }).when(trainedModelProvider).storeTrainedModelConfig(any(TrainedModelConfig.class), any(ActionListener.class));
 
         doAnswer(invocationOnMock -> {
             ActionListener<Void> storeListener = (ActionListener<Void>) invocationOnMock.getArguments()[1];
@@ -86,22 +91,36 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
             return null;
         }).when(trainedModelProvider).storeTrainedModelDefinitionDoc(any(TrainedModelDefinitionDoc.class), any(ActionListener.class));
 
+        doAnswer(invocationOnMock -> {
+            ActionListener<Void> storeListener = (ActionListener<Void>) invocationOnMock.getArguments()[1];
+            storeListener.onResponse(null);
+            return null;
+        }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelMetadata.class), any(ActionListener.class));
+
         ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig);
         ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
         TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 0, false);
         TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true);
+        ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
+            .limit(randomIntBetween(1, 10))
+            .collect(Collectors.toList()));
 
-        resultProcessor.createAndIndexInferenceModelMetadata(modelSizeInfo);
+        resultProcessor.createAndIndexInferenceModelConfig(modelSizeInfo);
         resultProcessor.createAndIndexInferenceModelDoc(chunk1);
         resultProcessor.createAndIndexInferenceModelDoc(chunk2);
+        resultProcessor.createAndIndexInferenceModelMetadata(modelMetadata);
 
         ArgumentCaptor<TrainedModelConfig> storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class);
-        verify(trainedModelProvider).storeTrainedModelMetadata(storedModelCaptor.capture(), any(ActionListener.class));
+        verify(trainedModelProvider).storeTrainedModelConfig(storedModelCaptor.capture(), any(ActionListener.class));
 
         ArgumentCaptor<TrainedModelDefinitionDoc> storedDocCapture = ArgumentCaptor.forClass(TrainedModelDefinitionDoc.class);
         verify(trainedModelProvider, times(2))
             .storeTrainedModelDefinitionDoc(storedDocCapture.capture(), any(ActionListener.class));
 
+        ArgumentCaptor<TrainedModelMetadata> storedMetadataCaptor = ArgumentCaptor.forClass(TrainedModelMetadata.class);
+        verify(trainedModelProvider, times(1))
+            .storeTrainedModelMetadata(storedMetadataCaptor.capture(), any(ActionListener.class));
+
         TrainedModelConfig storedModel = storedModelCaptor.getValue();
         assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM));
         assertThat(storedModel.getModelId(), containsString(JOB_ID));
@@ -132,6 +151,9 @@ public class ChunkedTrainedModelPersisterTests extends ESTestCase {
         assertThat(storedModel.getModelId(), equalTo(storedDoc1.getModelId()));
         assertThat(storedModel.getModelId(), equalTo(storedDoc2.getModelId()));
 
+        TrainedModelMetadata storedMetadata = storedMetadataCaptor.getValue();
+        assertThat(storedMetadata.getModelId(), equalTo(storedModel.getModelId()));
+
         ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
         verify(auditor).info(eq(JOB_ID), auditCaptor.capture());
         assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID));

+ 8 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsageTests;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
@@ -24,6 +25,8 @@ import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResult> {
 
@@ -38,7 +41,6 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
 
     protected AnalyticsResult createTestInstance() {
         AnalyticsResult.Builder builder = AnalyticsResult.builder();
-
         if (randomBoolean()) {
             builder.setRowResults(RowResultsTests.createRandom());
         }
@@ -64,6 +66,11 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
             String def = randomAlphaOfLengthBetween(100, 1000);
             builder.setTrainedModelDefinitionChunk(new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean()));
         }
+        if (randomBoolean()) {
+            builder.setModelMetadata(new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance)
+                .limit(randomIntBetween(1, 10))
+                .collect(Collectors.toList())));
+        }
         return builder.build();
     }
 

+ 4 - 4
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml

@@ -5,12 +5,12 @@ setup:
         - allowed_warnings
   - do:
       allowed_warnings:
-        - "index [.ml-inference-000002] matches multiple legacy templates [.ml-inference-000002, global], composable templates will only match a single template"
+        - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template"
       headers:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
       index:
         id: trained_model_config-a-unused-regression-model1-0
-        index: .ml-inference-000002
+        index: .ml-inference-000003
         body: >
           {
             "model_id": "a-unused-regression-model1",
@@ -27,7 +27,7 @@ setup:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
       index:
         id: trained_model_config-a-unused-regression-model-0
-        index: .ml-inference-000002
+        index: .ml-inference-000003
         body: >
           {
             "model_id": "a-unused-regression-model",
@@ -43,7 +43,7 @@ setup:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
       index:
         id: trained_model_config-a-used-regression-model-0
-        index: .ml-inference-000002
+        index: .ml-inference-000003
         body: >
           {
             "model_id": "a-used-regression-model",