Selaa lähdekoodia

[ML] Remove top level importance from classification inference results (#62486)

As we have decided top level importance for classification is not useful,
it has been removed from the results from the training job. This commit
also removes them from inference.
Dimitris Athanasiou 5 vuotta sitten
vanhempi
commit
bba49aa64f
20 muutettua tiedostoa jossa 647 lisäystä ja 258 poistoa
  1. 7 5
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java
  2. 1 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java
  3. 26 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/AbstractFeatureImportance.java
  4. 32 93
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java
  5. 42 12
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java
  6. 160 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportance.java
  7. 88 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportance.java
  8. 47 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java
  9. 2 21
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java
  10. 17 16
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
  11. 70 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java
  12. 6 16
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java
  13. 0 49
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java
  14. 77 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportanceTests.java
  15. 34 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportanceTests.java
  16. 5 5
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java
  17. 9 9
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java
  18. 4 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java
  19. 6 6
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java
  20. 14 9
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

+ 7 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java

@@ -47,7 +47,7 @@ public class FeatureImportance implements ToXContentObject {
 
     static {
         PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
-        PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
+        PARSER.declareDouble(optionalConstructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
         PARSER.declareObjectArray(optionalConstructorArg(),
             (p, c) -> ClassImportance.fromXContent(p),
             new ParseField(FeatureImportance.CLASSES));
@@ -58,10 +58,10 @@ public class FeatureImportance implements ToXContentObject {
     }
 
     private final List<ClassImportance> classImportance;
-    private final double importance;
+    private final Double importance;
     private final String featureName;
 
-    public FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
+    public FeatureImportance(String featureName, Double importance, List<ClassImportance> classImportance) {
         this.featureName = Objects.requireNonNull(featureName);
         this.importance = importance;
         this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
@@ -71,7 +71,7 @@ public class FeatureImportance implements ToXContentObject {
         return classImportance;
     }
 
-    public double getImportance() {
+    public Double getImportance() {
         return importance;
     }
 
@@ -83,7 +83,9 @@ public class FeatureImportance implements ToXContentObject {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(FEATURE_NAME, featureName);
-        builder.field(IMPORTANCE, importance);
+        if (importance != null) {
+            builder.field(IMPORTANCE, importance);
+        }
         if (classImportance != null && classImportance.isEmpty() == false) {
             builder.field(CLASSES, classImportance);
         }

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java

@@ -32,7 +32,7 @@ public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImpo
     protected FeatureImportance createTestInstance() {
         return new FeatureImportance(
             randomAlphaOfLength(10),
-            randomDoubleBetween(-10.0, 10.0, false),
+            randomBoolean() ? null : randomDoubleBetween(-10.0, 10.0, false),
             randomBoolean() ? null :
                 Stream.generate(() -> randomAlphaOfLength(10))
                     .limit(randomLongBetween(2, 10))

+ 26 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/AbstractFeatureImportance.java

@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+
+abstract class AbstractFeatureImportance implements Writeable, ToXContentObject {
+
+    public abstract String getFeatureName();
+
+    public abstract Map<String, Object> toMap();
+
+    @Override
+    public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        return builder.map(toMap());
+    }
+}

+ 32 - 93
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java → x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportance.java

@@ -5,7 +5,6 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
-import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -26,157 +25,101 @@ import java.util.stream.Collectors;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
-public class FeatureImportance implements Writeable, ToXContentObject {
+public class ClassificationFeatureImportance extends AbstractFeatureImportance {
 
     private final List<ClassImportance> classImportance;
-    private final double importance;
     private final String featureName;
-    static final String IMPORTANCE = "importance";
+
     static final String FEATURE_NAME = "feature_name";
     static final String CLASSES = "classes";
 
-    public static FeatureImportance forRegression(String featureName, double importance) {
-        return new FeatureImportance(featureName, importance, null);
-    }
-
-    public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
-        return new FeatureImportance(featureName,
-            importance,
-            classImportance);
-    }
-
-    public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
-        return new FeatureImportance(featureName,
-            classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
-            classImportance);
-    }
-
     @SuppressWarnings("unchecked")
-    private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
-        new ConstructingObjectParser<>("feature_importance",
-            a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
+    private static final ConstructingObjectParser<ClassificationFeatureImportance, Void> PARSER =
+        new ConstructingObjectParser<>("classification_feature_importance",
+            a -> new ClassificationFeatureImportance((String) a[0], (List<ClassImportance>) a[1])
         );
 
     static {
-        PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
-        PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
+        PARSER.declareString(constructorArg(), new ParseField(ClassificationFeatureImportance.FEATURE_NAME));
         PARSER.declareObjectArray(optionalConstructorArg(),
             (p, c) -> ClassImportance.fromXContent(p),
-            new ParseField(FeatureImportance.CLASSES));
+            new ParseField(ClassificationFeatureImportance.CLASSES));
     }
 
-    public static FeatureImportance fromXContent(XContentParser parser) {
+    public static ClassificationFeatureImportance fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
+    public ClassificationFeatureImportance(String featureName, List<ClassImportance> classImportance) {
         this.featureName = Objects.requireNonNull(featureName);
-        this.importance = importance;
-        this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
+        this.classImportance = classImportance == null ? Collections.emptyList() : Collections.unmodifiableList(classImportance);
     }
 
-    public FeatureImportance(StreamInput in) throws IOException {
+    public ClassificationFeatureImportance(StreamInput in) throws IOException {
         this.featureName = in.readString();
-        this.importance = in.readDouble();
-        if (in.readBoolean()) {
-            if (in.getVersion().before(Version.V_7_10_0)) {
-                Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
-                this.classImportance = ClassImportance.fromMap(classImportance);
-            } else {
-                this.classImportance = in.readList(ClassImportance::new);
-            }
-        } else {
-            this.classImportance = null;
-        }
+        this.classImportance = in.readList(ClassImportance::new);
     }
 
     public List<ClassImportance> getClassImportance() {
         return classImportance;
     }
 
-    public double getImportance() {
-        return importance;
-    }
-
+    @Override
     public String getFeatureName() {
         return featureName;
     }
 
+    public double getTotalImportance() {
+        if (classImportance.size() == 2) {
+            // Binary classification. We can return the first class importance here
+            return Math.abs(classImportance.get(0).getImportance());
+        }
+        return classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum();
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        out.writeString(this.featureName);
-        out.writeDouble(this.importance);
-        out.writeBoolean(this.classImportance != null);
-        if (this.classImportance != null) {
-            if (out.getVersion().before(Version.V_7_10_0)) {
-                out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
-            } else {
-                out.writeList(this.classImportance);
-            }
-        }
+        out.writeString(featureName);
+        out.writeList(classImportance);
     }
 
+    @Override
     public Map<String, Object> toMap() {
         Map<String, Object> map = new LinkedHashMap<>();
         map.put(FEATURE_NAME, featureName);
-        map.put(IMPORTANCE, importance);
-        if (classImportance != null) {
+        if (classImportance.isEmpty() == false) {
             map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
         }
         return map;
     }
 
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        builder.field(FEATURE_NAME, featureName);
-        builder.field(IMPORTANCE, importance);
-        if (classImportance != null && classImportance.isEmpty() == false) {
-            builder.field(CLASSES, classImportance);
-        }
-        builder.endObject();
-        return builder;
-    }
-
     @Override
     public boolean equals(Object object) {
         if (object == this) { return true; }
         if (object == null || getClass() != object.getClass()) { return false; }
-        FeatureImportance that = (FeatureImportance) object;
+        ClassificationFeatureImportance that = (ClassificationFeatureImportance) object;
         return Objects.equals(featureName, that.featureName)
-            && Objects.equals(importance, that.importance)
             && Objects.equals(classImportance, that.classImportance);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(featureName, importance, classImportance);
+        return Objects.hash(featureName, classImportance);
     }
 
     public static class ClassImportance implements Writeable, ToXContentObject {
 
         static final String CLASS_NAME = "class_name";
+        static final String IMPORTANCE = "importance";
 
         private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
-            new ConstructingObjectParser<>("feature_importance_class_importance",
-                a -> new ClassImportance((String) a[0], (Double) a[1])
+            new ConstructingObjectParser<>("classification_feature_importance_class_importance",
+                a -> new ClassImportance(a[0], (Double) a[1])
             );
 
         static {
             PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
-            PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
-        }
-
-        private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
-            return new ClassImportance(entry.getKey(), entry.getValue());
-        }
-
-        private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
-            return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
-        }
-
-        private static Map<String, Double> toMap(List<ClassImportance> importances) {
-            return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
+            PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
         }
 
         public static ClassImportance fromXContent(XContentParser parser) {
@@ -219,11 +162,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(CLASS_NAME, className);
-            builder.field(IMPORTANCE, importance);
-            builder.endObject();
-            return builder;
+            return builder.map(toMap());
         }
 
         @Override

+ 42 - 12
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -14,9 +15,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldTyp
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.Map;
 import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.stream.Collectors;
 
@@ -33,12 +34,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
     private final Double predictionProbability;
     private final Double predictionScore;
     private final List<TopClassEntry> topClasses;
+    private final List<ClassificationFeatureImportance> featureImportance;
     private final PredictionFieldType predictionFieldType;
 
     public ClassificationInferenceResults(double value,
                                           String classificationLabel,
                                           List<TopClassEntry> topClasses,
-                                          List<FeatureImportance> featureImportance,
+                                          List<ClassificationFeatureImportance> featureImportance,
                                           InferenceConfig config,
                                           Double predictionProbability,
                                           Double predictionScore) {
@@ -54,13 +56,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
     private ClassificationInferenceResults(double value,
                                            String classificationLabel,
                                            List<TopClassEntry> topClasses,
-                                           List<FeatureImportance> featureImportance,
+                                           List<ClassificationFeatureImportance> featureImportance,
                                            ClassificationConfig classificationConfig,
                                            Double predictionProbability,
                                            Double predictionScore) {
-        super(value,
-            SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
-                classificationConfig.getNumTopFeatureImportanceValues()));
+        super(value);
         this.classificationLabel = classificationLabel;
         this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
         this.topNumClassesField = classificationConfig.getTopClassesResultsField();
@@ -68,10 +68,30 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         this.predictionFieldType = classificationConfig.getPredictionFieldType();
         this.predictionProbability = predictionProbability;
         this.predictionScore = predictionScore;
+        this.featureImportance = takeTopFeatureImportances(featureImportance, classificationConfig.getNumTopFeatureImportanceValues());
+    }
+
+    static List<ClassificationFeatureImportance> takeTopFeatureImportances(List<ClassificationFeatureImportance> featureImportances,
+                                                                           int numTopFeatures) {
+        if (featureImportances == null || featureImportances.isEmpty()) {
+            return Collections.emptyList();
+        }
+        return featureImportances.stream()
+            .sorted((l, r)-> Double.compare(r.getTotalImportance(), l.getTotalImportance()))
+            .limit(numTopFeatures)
+            .collect(Collectors.toUnmodifiableList());
     }
 
     public ClassificationInferenceResults(StreamInput in) throws IOException {
         super(in);
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            this.featureImportance = in.readList(ClassificationFeatureImportance::new);
+        } else {
+            this.featureImportance = in.readList(LegacyFeatureImportance::new)
+                .stream()
+                .map(LegacyFeatureImportance::forClassification)
+                .collect(Collectors.toList());
+        }
         this.classificationLabel = in.readOptionalString();
         this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
         this.topNumClassesField = in.readString();
@@ -93,9 +113,18 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         return predictionFieldType;
     }
 
+    public List<ClassificationFeatureImportance> getFeatureImportance() {
+        return featureImportance;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeList(featureImportance);
+        } else {
+            out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromClassification).collect(Collectors.toList()));
+        }
         out.writeOptionalString(classificationLabel);
         out.writeCollection(topClasses);
         out.writeString(topNumClassesField);
@@ -118,7 +147,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
             && Objects.equals(predictionFieldType, that.predictionFieldType)
             && Objects.equals(predictionProbability, that.predictionProbability)
             && Objects.equals(predictionScore, that.predictionScore)
-            && Objects.equals(getFeatureImportance(), that.getFeatureImportance());
+            && Objects.equals(featureImportance, that.featureImportance);
     }
 
     @Override
@@ -130,7 +159,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
             topNumClassesField,
             predictionProbability,
             predictionScore,
-            getFeatureImportance(),
+            featureImportance,
             predictionFieldType);
     }
 
@@ -165,8 +194,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         if (predictionScore != null) {
             map.put(PREDICTION_SCORE, predictionScore);
         }
-        if (getFeatureImportance().isEmpty() == false) {
-            map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
+        if (featureImportance.isEmpty() == false) {
+            map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(ClassificationFeatureImportance::toMap)
+                .collect(Collectors.toList()));
         }
         return map;
     }
@@ -188,8 +218,8 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         if (predictionScore != null) {
             builder.field(PREDICTION_SCORE, predictionScore);
         }
-        if (getFeatureImportance().size() > 0) {
-            builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
+        if (featureImportance.isEmpty() == false) {
+            builder.field(FEATURE_IMPORTANCE, featureImportance);
         }
         return builder;
     }

+ 160 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportance.java

@@ -0,0 +1,160 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * This class captures serialization of feature importance for
+ * classification and regression prior to version 7.10.
+ */
+public class LegacyFeatureImportance implements Writeable {
+
+    public static LegacyFeatureImportance fromClassification(ClassificationFeatureImportance classificationFeatureImportance) {
+        return new LegacyFeatureImportance(
+            classificationFeatureImportance.getFeatureName(),
+            classificationFeatureImportance.getTotalImportance(),
+            classificationFeatureImportance.getClassImportance().stream().map(classImportance -> new ClassImportance(
+                classImportance.getClassName(), classImportance.getImportance())).collect(Collectors.toList())
+        );
+    }
+
+    public static LegacyFeatureImportance fromRegression(RegressionFeatureImportance regressionFeatureImportance) {
+        return new LegacyFeatureImportance(
+            regressionFeatureImportance.getFeatureName(),
+            regressionFeatureImportance.getImportance(),
+            null
+        );
+    }
+
+    private final List<ClassImportance> classImportance;
+    private final double importance;
+    private final String featureName;
+
+    LegacyFeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
+        this.featureName = Objects.requireNonNull(featureName);
+        this.importance = importance;
+        this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
+    }
+
+    public LegacyFeatureImportance(StreamInput in) throws IOException {
+        this.featureName = in.readString();
+        this.importance = in.readDouble();
+        if (in.readBoolean()) {
+            if (in.getVersion().before(Version.V_7_10_0)) {
+                Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
+                this.classImportance = ClassImportance.fromMap(classImportance);
+            } else {
+                this.classImportance = in.readList(ClassImportance::new);
+            }
+        } else {
+            this.classImportance = null;
+        }
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(featureName);
+        out.writeDouble(importance);
+        out.writeBoolean(classImportance != null);
+        if (classImportance != null) {
+            if (out.getVersion().before(Version.V_7_10_0)) {
+                out.writeMap(ClassImportance.toMap(classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
+            } else {
+                out.writeList(classImportance);
+            }
+        }
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        LegacyFeatureImportance that = (LegacyFeatureImportance) object;
+        return Objects.equals(featureName, that.featureName)
+            && Objects.equals(importance, that.importance)
+            && Objects.equals(classImportance, that.classImportance);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureName, importance, classImportance);
+    }
+
+    public RegressionFeatureImportance forRegression() {
+        assert classImportance == null;
+        return new RegressionFeatureImportance(featureName, importance);
+    }
+
+    public ClassificationFeatureImportance forClassification() {
+        assert classImportance != null;
+        return new ClassificationFeatureImportance(featureName, classImportance.stream().map(
+            aClassImportance -> new ClassificationFeatureImportance.ClassImportance(
+                aClassImportance.className, aClassImportance.importance)).collect(Collectors.toList()));
+    }
+
+    public static class ClassImportance implements Writeable {
+
+        private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
+            return new ClassImportance(entry.getKey(), entry.getValue());
+        }
+
+        private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
+            return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
+        }
+
+        private static Map<String, Double> toMap(List<ClassImportance> importances) {
+            return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
+        }
+
+        private final Object className;
+        private final double importance;
+
+        public ClassImportance(Object className, double importance) {
+            this.className = className;
+            this.importance = importance;
+        }
+
+        public ClassImportance(StreamInput in) throws IOException {
+            this.className = in.readGenericValue();
+            this.importance = in.readDouble();
+        }
+
+        double getImportance() {
+            return importance;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeGenericValue(className);
+            out.writeDouble(importance);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ClassImportance that = (ClassImportance) o;
+            return Double.compare(that.importance, importance) == 0 &&
+                Objects.equals(className, that.className);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, importance);
+        }
+    }
+}

+ 88 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportance.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class RegressionFeatureImportance extends AbstractFeatureImportance {
+
+    private final double importance;
+    private final String featureName;
+    static final String IMPORTANCE = "importance";
+    static final String FEATURE_NAME = "feature_name";
+
+    private static final ConstructingObjectParser<RegressionFeatureImportance, Void> PARSER =
+        new ConstructingObjectParser<>("regression_feature_importance",
+            a -> new RegressionFeatureImportance((String) a[0], (Double) a[1])
+        );
+
+    static {
+        PARSER.declareString(constructorArg(), new ParseField(RegressionFeatureImportance.FEATURE_NAME));
+        PARSER.declareDouble(constructorArg(), new ParseField(RegressionFeatureImportance.IMPORTANCE));
+    }
+
+    public static RegressionFeatureImportance fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    public RegressionFeatureImportance(String featureName, double importance) {
+        this.featureName = Objects.requireNonNull(featureName);
+        this.importance = importance;
+    }
+
+    public RegressionFeatureImportance(StreamInput in) throws IOException {
+        this.featureName = in.readString();
+        this.importance = in.readDouble();
+    }
+
+    public double getImportance() {
+        return importance;
+    }
+
+    @Override
+    public String getFeatureName() {
+        return featureName;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(featureName);
+        out.writeDouble(importance);
+    }
+
+    @Override
+    public Map<String, Object> toMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(FEATURE_NAME, featureName);
+        map.put(IMPORTANCE, importance);
+        return map;
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        RegressionFeatureImportance that = (RegressionFeatureImportance) object;
+        return Objects.equals(featureName, that.featureName)
+            && Objects.equals(importance, that.importance);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureName, importance);
+    }
+}

+ 47 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -24,14 +25,19 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     public static final String NAME = "regression";
 
     private final String resultsField;
+    private final List<RegressionFeatureImportance> featureImportance;
 
     public RegressionInferenceResults(double value, InferenceConfig config) {
         this(value, config, Collections.emptyList());
     }
 
-    public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
-        this(value, ((RegressionConfig)config).getResultsField(),
-            ((RegressionConfig)config).getNumTopFeatureImportanceValues(), featureImportance);
+    public RegressionInferenceResults(double value, InferenceConfig config, List<RegressionFeatureImportance> featureImportance) {
+        this(
+            value,
+            ((RegressionConfig)config).getResultsField(),
+            ((RegressionConfig)config).getNumTopFeatureImportanceValues(),
+            featureImportance
+        );
     }
 
     public RegressionInferenceResults(double value, String resultsField) {
@@ -39,28 +45,56 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     }
 
     public RegressionInferenceResults(double value, String resultsField,
-                                      List<FeatureImportance> featureImportance) {
+                                      List<RegressionFeatureImportance> featureImportance) {
         this(value, resultsField, featureImportance.size(), featureImportance);
     }
 
     public RegressionInferenceResults(double value, String resultsField, int topNFeatures,
-                                       List<FeatureImportance> featureImportance) {
-        super(value,
-            SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, topNFeatures));
+                                       List<RegressionFeatureImportance> featureImportance) {
+        super(value);
         this.resultsField = resultsField;
+        this.featureImportance = takeTopFeatureImportances(featureImportance, topNFeatures);
+    }
+
+    static List<RegressionFeatureImportance> takeTopFeatureImportances(List<RegressionFeatureImportance> featureImportances,
+                                                                       int numTopFeatures) {
+        if (featureImportances == null || featureImportances.isEmpty()) {
+            return Collections.emptyList();
+        }
+        return featureImportances.stream()
+            .sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
+            .limit(numTopFeatures)
+            .collect(Collectors.toUnmodifiableList());
     }
 
     public RegressionInferenceResults(StreamInput in) throws IOException {
         super(in);
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            this.featureImportance = in.readList(RegressionFeatureImportance::new);
+        } else {
+            this.featureImportance = in.readList(LegacyFeatureImportance::new)
+                .stream()
+                .map(LegacyFeatureImportance::forRegression)
+                .collect(Collectors.toList());
+        }
         this.resultsField = in.readString();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeList(featureImportance);
+        } else {
+            out.writeList(featureImportance.stream().map(LegacyFeatureImportance::fromRegression).collect(Collectors.toList()));
+        }
         out.writeString(resultsField);
     }
 
+    public List<RegressionFeatureImportance> getFeatureImportance() {
+        return featureImportance;
+    }
+
     @Override
     public boolean equals(Object object) {
         if (object == this) { return true; }
@@ -68,12 +102,12 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
         RegressionInferenceResults that = (RegressionInferenceResults) object;
         return Objects.equals(value(), that.value())
             && Objects.equals(this.resultsField, that.resultsField)
-            && Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
+            && Objects.equals(this.featureImportance, that.featureImportance);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(value(), resultsField, getFeatureImportance());
+        return Objects.hash(value(), resultsField, featureImportance);
     }
 
     @Override
@@ -85,8 +119,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     public Map<String, Object> asMap() {
         Map<String, Object> map = new LinkedHashMap<>();
         map.put(resultsField, value());
-        if (getFeatureImportance().isEmpty() == false) {
-            map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
+        if (featureImportance.isEmpty() == false) {
+            map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList()));
         }
         return map;
     }
@@ -94,8 +128,8 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.field(resultsField, value());
-        if (getFeatureImportance().size() > 0) {
-            builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
+        if (featureImportance.isEmpty() == false) {
+            builder.field(FEATURE_IMPORTANCE, featureImportance);
         }
         return builder;
     }

+ 2 - 21
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java

@@ -9,44 +9,26 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.List;
-import java.util.stream.Collectors;
 
 public abstract class SingleValueInferenceResults implements InferenceResults {
 
     public static final String FEATURE_IMPORTANCE = "feature_importance";
 
     private final double value;
-    private final List<FeatureImportance> featureImportance;
-
-    static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
-        if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
-            return unsortedFeatureImportances;
-        }
-        return unsortedFeatureImportances.stream()
-            .sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
-            .limit(numTopFeatures)
-            .collect(Collectors.toList());
-    }
+
 
     SingleValueInferenceResults(StreamInput in) throws IOException {
         value = in.readDouble();
-        this.featureImportance = in.readList(FeatureImportance::new);
     }
 
-    SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
+    SingleValueInferenceResults(double value) {
         this.value = value;
-        this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
     }
 
     public Double value() {
         return value;
     }
 
-    public List<FeatureImportance> getFeatureImportance() {
-        return featureImportance;
-    }
 
     public String valueAsString() {
         return String.valueOf(value);
@@ -55,7 +37,6 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeDouble(value);
-        out.writeList(this.featureImportance);
     }
 
 }

+ 17 - 16
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -7,7 +7,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.collect.Tuple;
-import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -130,17 +131,18 @@ public final class InferenceHelpers {
         return originalFeatureImportance;
     }
 
-    public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
-        List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
-        featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
+    public static List<RegressionFeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
+        List<RegressionFeatureImportance> importances = new ArrayList<>(featureImportance.size());
+        featureImportance.forEach((k, v) -> importances.add(new RegressionFeatureImportance(k, v[0])));
         return importances;
     }
 
-    public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
-                                                                                   final int predictedValue,
-                                                                                   @Nullable List<String> classificationLabels,
-                                                                                   @Nullable PredictionFieldType predictionFieldType) {
-        List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
+    public static List<ClassificationFeatureImportance> transformFeatureImportanceClassification(
+            Map<String, double[]> featureImportance,
+            final int predictedValue,
+            @Nullable List<String> classificationLabels,
+            @Nullable PredictionFieldType predictionFieldType) {
+        List<ClassificationFeatureImportance> importances = new ArrayList<>(featureImportance.size());
         final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
         featureImportance.forEach((k, v) -> {
             // This indicates logistic regression (binary classification)
@@ -152,27 +154,26 @@ public final class InferenceHelpers {
                 final int otherClass = 1 - predictedValue;
                 String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
                 String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
-                importances.add(FeatureImportance.forBinaryClassification(k,
-                    v[0],
+                importances.add(new ClassificationFeatureImportance(k,
                     Arrays.asList(
-                        new FeatureImportance.ClassImportance(
+                        new ClassificationFeatureImportance.ClassImportance(
                             fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
                             v[0]),
-                        new FeatureImportance.ClassImportance(
+                        new ClassificationFeatureImportance.ClassImportance(
                             fieldType.transformPredictedValue((double)otherClass, otherLabel),
                             -v[0])
                     )));
             } else {
-                List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
+                List<ClassificationFeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
                 // If the classificationLabels exist, their length must match leaf_value length
                 assert classificationLabels == null || classificationLabels.size() == v.length;
                 for (int i = 0; i < v.length; i++) {
                     String label = classificationLabels == null ? null : classificationLabels.get(i);
-                    classImportance.add(new FeatureImportance.ClassImportance(
+                    classImportance.add(new ClassificationFeatureImportance.ClassImportance(
                         fieldType.transformPredictedValue((double)i, label),
                         v[i]));
                 }
-                importances.add(FeatureImportance.forClassification(k, classImportance));
+                importances.add(new ClassificationFeatureImportance(k, classImportance));
             }
         });
         return importances;

+ 70 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationFeatureImportanceTests.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
+
+public class ClassificationFeatureImportanceTests extends AbstractSerializingTestCase<ClassificationFeatureImportance> {
+
+    @Override
+    protected ClassificationFeatureImportance doParseInstance(XContentParser parser) throws IOException {
+        return ClassificationFeatureImportance.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<ClassificationFeatureImportance> instanceReader() {
+        return ClassificationFeatureImportance::new;
+    }
+
+    @Override
+    protected ClassificationFeatureImportance createTestInstance() {
+        return createRandomInstance();
+    }
+
+    public static ClassificationFeatureImportance createRandomInstance() {
+        return new ClassificationFeatureImportance(
+            randomAlphaOfLength(10),
+            Stream.generate(() -> randomAlphaOfLength(10))
+                .limit(randomLongBetween(2, 10))
+                .map(name -> new ClassificationFeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
+                .collect(Collectors.toList()));
+    }
+
+    public void testGetTotalImportance_GivenBinary() {
+        ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance(
+            "binary",
+            Arrays.asList(
+                new ClassificationFeatureImportance.ClassImportance("a", 0.15),
+                new ClassificationFeatureImportance.ClassImportance("not-a", -0.15)
+            )
+        );
+
+        assertThat(featureImportance.getTotalImportance(), equalTo(0.15));
+    }
+
+    public void testGetTotalImportance_GivenMulticlass() {
+        ClassificationFeatureImportance featureImportance = new ClassificationFeatureImportance(
+            "multiclass",
+            Arrays.asList(
+                new ClassificationFeatureImportance.ClassImportance("a", 0.15),
+                new ClassificationFeatureImportance.ClassImportance("b", -0.05),
+                new ClassificationFeatureImportance.ClassImportance("c", 0.30)
+            )
+        );
+
+        assertThat(featureImportance.getTotalImportance(), closeTo(0.50, 0.00000001));
+    }
+}

+ 6 - 16
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java

@@ -18,7 +18,6 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -29,10 +28,6 @@ import static org.hamcrest.Matchers.hasSize;
 public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {
 
     public static ClassificationInferenceResults createRandomResults() {
-        Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
-            FeatureImportanceTests::randomClassification :
-            FeatureImportanceTests::randomRegression;
-
         ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
         Double value = randomDouble();
         if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) {
@@ -47,7 +42,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
                     .limit(randomIntBetween(0, 10))
                     .collect(Collectors.toList()),
             randomBoolean() ? null :
-                Stream.generate(featureImportanceCtor)
+                Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance)
                     .limit(randomIntBetween(1, 10))
                     .collect(Collectors.toList()),
             config,
@@ -123,11 +118,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
     }
 
     public void testWriteResultsWithImportance() {
-        Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
-            FeatureImportanceTests::randomClassification :
-            FeatureImportanceTests::randomRegression;
-
-        List<FeatureImportance> importanceList = Stream.generate(featureImportanceCtor)
+        List<ClassificationFeatureImportance> importanceList = Stream.generate(ClassificationFeatureImportanceTests::createRandomInstance)
             .limit(5)
             .collect(Collectors.toList());
         ClassificationInferenceResults result = new ClassificationInferenceResults(0.0,
@@ -146,18 +137,17 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
             "result_field.feature_importance",
             List.class);
         assertThat(writtenImportance, hasSize(3));
-        importanceList.sort((l, r) -> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
+        importanceList.sort((l, r) -> Double.compare(Math.abs(r.getTotalImportance()), Math.abs(l.getTotalImportance())));
         for (int i = 0; i < 3; i++) {
             Map<String, Object> objectMap = writtenImportance.get(i);
-            FeatureImportance importance = importanceList.get(i);
+            ClassificationFeatureImportance importance = importanceList.get(i);
             assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
-            assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
             @SuppressWarnings("unchecked")
             List<Map<String, Object>> classImportances = (List<Map<String, Object>>)objectMap.get("classes");
             if (importance.getClassImportance() != null) {
                 for (int j = 0; j < importance.getClassImportance().size(); j++) {
                     Map<String, Object> classMap = classImportances.get(j);
-                    FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j);
+                    ClassificationFeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j);
                     assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName()));
                     assertThat(classMap.get("importance"), equalTo(classImportance.getImportance()));
                 }
@@ -212,7 +202,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
         expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
         assertEquals(expected, stringRep);
 
-        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
+        ClassificationFeatureImportance fi = new ClassificationFeatureImportance("foo", Collections.emptyList());
         TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
         result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
             Collections.singletonList(fi), config,

+ 0 - 49
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java

@@ -1,49 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-package org.elasticsearch.xpack.core.ml.inference.results;
-
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.test.AbstractSerializingTestCase;
-
-import java.io.IOException;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureImportance> {
-
-    public static FeatureImportance createRandomInstance() {
-        return randomBoolean() ? randomClassification() : randomRegression();
-    }
-
-    static FeatureImportance randomRegression() {
-        return FeatureImportance.forRegression(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
-    }
-
-    static FeatureImportance randomClassification() {
-        return FeatureImportance.forClassification(
-            randomAlphaOfLength(10),
-            Stream.generate(() -> randomAlphaOfLength(10))
-                .limit(randomLongBetween(2, 10))
-                .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
-                .collect(Collectors.toList()));
-    }
-
-    @Override
-    protected FeatureImportance createTestInstance() {
-        return createRandomInstance();
-    }
-
-    @Override
-    protected Writeable.Reader<FeatureImportance> instanceReader() {
-        return FeatureImportance::new;
-    }
-
-    @Override
-    protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
-        return FeatureImportance.fromXContent(parser);
-    }
-}

+ 77 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/LegacyFeatureImportanceTests.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class LegacyFeatureImportanceTests extends AbstractWireSerializingTestCase<LegacyFeatureImportance> {
+
+    public static LegacyFeatureImportance createRandomInstance() {
+        return createRandomInstance(randomBoolean());
+    }
+
+    public static LegacyFeatureImportance createRandomInstance(boolean hasClasses) {
+        double importance = randomDouble();
+        List<LegacyFeatureImportance.ClassImportance> classImportances = null;
+        if (hasClasses) {
+            classImportances = Stream.generate(() -> randomAlphaOfLength(10))
+                .limit(randomLongBetween(2, 10))
+                .map(featureName -> new LegacyFeatureImportance.ClassImportance(featureName, randomDouble()))
+                .collect(Collectors.toList());
+
+            importance = classImportances.stream().mapToDouble(LegacyFeatureImportance.ClassImportance::getImportance).map(Math::abs).sum();
+        }
+        return new LegacyFeatureImportance(randomAlphaOfLength(10), importance, classImportances);
+    }
+
+    @Override
+    protected LegacyFeatureImportance createTestInstance() {
+        return createRandomInstance();
+    }
+
+    @Override
+    protected Writeable.Reader<LegacyFeatureImportance> instanceReader() {
+        return LegacyFeatureImportance::new;
+    }
+
+    public void testClassificationConversion() {
+        {
+            ClassificationFeatureImportance classificationFeatureImportance = ClassificationFeatureImportanceTests.createRandomInstance();
+            LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromClassification(classificationFeatureImportance);
+            ClassificationFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forClassification();
+            assertThat(convertedFeatureImportance, equalTo(classificationFeatureImportance));
+        }
+        {
+            LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(true);
+            ClassificationFeatureImportance classificationFeatureImportance = legacyFeatureImportance.forClassification();
+            LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromClassification(
+                classificationFeatureImportance);
+            assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance));
+        }
+    }
+
+    public void testRegressionConversion() {
+        {
+            RegressionFeatureImportance regressionFeatureImportance = RegressionFeatureImportanceTests.createRandomInstance();
+            LegacyFeatureImportance legacyFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance);
+            RegressionFeatureImportance convertedFeatureImportance = legacyFeatureImportance.forRegression();
+            assertThat(convertedFeatureImportance, equalTo(regressionFeatureImportance));
+        }
+        {
+            LegacyFeatureImportance legacyFeatureImportance = createRandomInstance(false);
+            RegressionFeatureImportance regressionFeatureImportance = legacyFeatureImportance.forRegression();
+            LegacyFeatureImportance convertedFeatureImportance = LegacyFeatureImportance.fromRegression(regressionFeatureImportance);
+            assertThat(convertedFeatureImportance, equalTo(legacyFeatureImportance));
+        }
+    }
+}

+ 34 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionFeatureImportanceTests.java

@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+public class RegressionFeatureImportanceTests extends AbstractSerializingTestCase<RegressionFeatureImportance> {
+
+    @Override
+    protected RegressionFeatureImportance doParseInstance(XContentParser parser) throws IOException {
+        return RegressionFeatureImportance.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<RegressionFeatureImportance> instanceReader() {
+        return RegressionFeatureImportance::new;
+    }
+
+    @Override
+    protected RegressionFeatureImportance createTestInstance() {
+        return createRandomInstance();
+    }
+
+    public static RegressionFeatureImportance createRandomInstance() {
+        return new RegressionFeatureImportance(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
+    }
+}

+ 5 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java

@@ -29,8 +29,8 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
     public static RegressionInferenceResults createRandomResults() {
         return new RegressionInferenceResults(randomDouble(),
             RegressionConfigTests.randomRegressionConfig(),
-            randomBoolean() ? null :
-                Stream.generate(FeatureImportanceTests::randomRegression)
+            randomBoolean() ? Collections.emptyList() :
+                Stream.generate(RegressionFeatureImportanceTests::createRandomInstance)
                     .limit(randomIntBetween(1, 10))
                     .collect(Collectors.toList()));
     }
@@ -50,7 +50,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
     }
 
     public void testWriteResultsWithImportance() {
-        List<FeatureImportance> importanceList = Stream.generate(FeatureImportanceTests::randomRegression)
+        List<RegressionFeatureImportance> importanceList = Stream.generate(RegressionFeatureImportanceTests::createRandomInstance)
             .limit(5)
             .collect(Collectors.toList());
         RegressionInferenceResults result = new RegressionInferenceResults(0.3,
@@ -68,7 +68,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
         importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
         for (int i = 0; i < 3; i++) {
             Map<String, Object> objectMap = writtenImportance.get(i);
-            FeatureImportance importance = importanceList.get(i);
+            RegressionFeatureImportance importance = importanceList.get(i);
             assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
             assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
             assertThat(objectMap.size(), equalTo(2));
@@ -92,7 +92,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
         String expected = "{\"" + resultsField + "\":1.0}";
         assertEquals(expected, stringRep);
 
-        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
+        RegressionFeatureImportance fi = new RegressionFeatureImportance("foo", 1.0);
         result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
         stringRep = Strings.toString(result);
         expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";

+ 9 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java

@@ -18,8 +18,8 @@ import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 
 import java.io.IOException;
@@ -138,9 +138,9 @@ public class InferenceDefinitionTests extends ESTestCase {
         ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
         assertThat(results.valueAsString(), equalTo("second"));
         assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
-        assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
+        assertThat(results.getFeatureImportance().get(0).getTotalImportance(), closeTo(0.944, 0.001));
         assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1"));
-        assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
+        assertThat(results.getFeatureImportance().get(1).getTotalImportance(), closeTo(0.199, 0.001));
     }
 
     public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws IOException {
@@ -159,20 +159,20 @@ public class InferenceDefinitionTests extends ESTestCase {
 
         ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
         assertThat(results.valueAsString(), equalTo("second"));
-        FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
+        ClassificationFeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
         assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
-        assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
-        for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
+        assertThat(featureImportance1.getTotalImportance(), closeTo(0.944, 0.001));
+        for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
             if (classImportance.getClassName().equals("second")) {
                 assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
             } else {
                 assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
             }
         }
-        FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
+        ClassificationFeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
         assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
-        assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
-        for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
+        assertThat(featureImportance2.getTotalImportance(), closeTo(0.199, 0.001));
+        for (ClassificationFeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
             if (classImportance.getClassName().equals("second")) {
                 assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
             } else {

+ 4 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java

@@ -14,10 +14,11 @@ import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
 import org.elasticsearch.search.aggregations.ParsedAggregation;
 import org.elasticsearch.test.InternalAggregationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
-import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -106,7 +107,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
         } else if (result instanceof RegressionInferenceResults) {
             RegressionInferenceResults regression = (RegressionInferenceResults) result;
             assertEquals(regression.value(), parsed.getValue());
-            List<FeatureImportance> featureImportance = regression.getFeatureImportance();
+            List<RegressionFeatureImportance> featureImportance = regression.getFeatureImportance();
             if (featureImportance.isEmpty()) {
                 featureImportance = null;
             }
@@ -115,7 +116,7 @@ public class InternalInferenceAggregationTests extends InternalAggregationTestCa
             ClassificationInferenceResults classification = (ClassificationInferenceResults) result;
             assertEquals(classification.predictedValue(), parsed.getValue());
 
-            List<FeatureImportance> featureImportance = classification.getFeatureImportance();
+            List<ClassificationFeatureImportance> featureImportance = classification.getFeatureImportance();
             if (featureImportance.isEmpty()) {
                 featureImportance = null;
             }

+ 6 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java

@@ -13,7 +13,6 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.search.aggregations.ParsedAggregation;
-import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -21,6 +20,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConf
 
 import java.io.IOException;
 import java.util.List;
+import java.util.Map;
 
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 import static org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults.PREDICTION_PROBABILITY;
@@ -45,7 +45,7 @@ public class ParsedInference extends ParsedAggregation {
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
         new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
-            args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
+            args -> new ParsedInference(args[0], (List<Map<String, Object>>) args[1],
                 (List<TopClassEntry>) args[2], (String) args[3], (Double) args[4], (Double) args[5]));
 
     static {
@@ -65,7 +65,7 @@ public class ParsedInference extends ParsedAggregation {
             }
             return o;
         }, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
-        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p),
+        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> p.map(),
             new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE));
         PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
             new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
@@ -82,14 +82,14 @@ public class ParsedInference extends ParsedAggregation {
     }
 
     private final Object value;
-    private final List<FeatureImportance> featureImportance;
+    private final List<Map<String, Object>> featureImportance;
     private final List<TopClassEntry> topClasses;
     private final String warning;
     private final Double predictionProbability;
     private final Double predictionScore;
 
     ParsedInference(Object value,
-                    List<FeatureImportance> featureImportance,
+                    List<Map<String, Object>> featureImportance,
                     List<TopClassEntry> topClasses,
                     String warning,
                     Double predictionProbability,
@@ -106,7 +106,7 @@ public class ParsedInference extends ParsedAggregation {
         return value;
     }
 
-    public List<FeatureImportance> getFeatureImportance() {
+    public List<Map<String, Object>> getFeatureImportance() {
         return featureImportance;
     }
 

+ 14 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

@@ -9,8 +9,9 @@ import org.elasticsearch.client.Client;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -136,9 +137,11 @@ public class InferenceProcessorTests extends ESTestCase {
         classes.add(new TopClassEntry("foo", 0.6, 0.6));
         classes.add(new TopClassEntry("bar", 0.4, 0.4));
 
-        List<FeatureImportance> featureInfluence = new ArrayList<>();
-        featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
-        featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
+        List<ClassificationFeatureImportance> featureInfluence = new ArrayList<>();
+        featureInfluence.add(new ClassificationFeatureImportance("feature_1",
+            Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_a", 1.13))));
+        featureInfluence.add(new ClassificationFeatureImportance("feature_2",
+            Collections.singletonList(new ClassificationFeatureImportance.ClassImportance("class_b", -42.0))));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new ClassificationInferenceResults(1.0,
@@ -153,10 +156,12 @@ public class InferenceProcessorTests extends ESTestCase {
 
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
         assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
         assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.class_name", String.class), equalTo("class_b"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.classes.0.importance", Double.class), equalTo(-42.0));
         assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.class_name", String.class), equalTo("class_a"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.classes.0.importance", Double.class), equalTo(1.13));
     }
 
     @SuppressWarnings("unchecked")
@@ -234,9 +239,9 @@ public class InferenceProcessorTests extends ESTestCase {
         Map<String, Object> ingestMetadata = new HashMap<>();
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
-        List<FeatureImportance> featureInfluence = new ArrayList<>();
-        featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
-        featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
+        List<RegressionFeatureImportance> featureInfluence = new ArrayList<>();
+        featureInfluence.add(new RegressionFeatureImportance("feature_1", 1.13));
+        featureInfluence.add(new RegressionFeatureImportance("feature_2", -42.0));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);