Explorar o código

[ML] adds multi-class feature importance support (#53803)

Adds multi-class feature importance calculation. 

Feature importance objects are now mapped as follows
(logistic) Regression:
```
{
   "feature_name": "feature_0",
   "importance": -1.3
}
```
Multi-class [class names are `foo`, `bar`, `baz`]
```
{
   "feature_name": "feature_0",
   "importance": 2.0, // sum(abs()) of class importances
   "foo": 1.0,
   "bar": 0.5,
   "baz": -0.5
}
```

For users to get the full benefit of aggregating and searching for feature importance, they should update their index mapping as follows (before turning this option on in their pipelines)
```
 "ml.inference.feature_importance": {
          "type": "nested",
          "dynamic": true,
          "properties": {
            "feature_name": {
              "type": "keyword"
            },
            "importance": {
              "type": "double"
            }
          }
        }
```
The mapping field name is as follows
`ml.<inference.target_field>.<inference.tag>.feature_importance`
if `inference.tag` is not provided in the processor definition, it is not part of the field path.
`inference.target_field` is defaulted to `ml.inference`.
//cc @lcawl ^ Where should we document this?

If this makes it in for 7.7, there shouldn't be any BWC worries around feature_importance at inference time, as 7.7 is the first version to have it.
Benjamin Trent hai 5 anos
pai
achega
756a297ea6
Modificáronse 18 ficheiros con 410 adicións e 138 borrados
  1. 7 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java
  2. 97 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java
  3. 3 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
  4. 9 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java
  5. 16 15
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java
  6. 35 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
  7. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
  8. 22 15
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java
  9. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
  10. 34 39
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
  11. 44 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java
  12. 44 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java
  13. 5 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java
  14. 37 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java
  15. 21 21
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java
  16. 9 9
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
  17. 8 4
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
  18. 15 10
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

+ 7 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

@@ -35,13 +35,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
                                           String classificationLabel,
                                           List<TopClassEntry> topClasses,
                                           InferenceConfig config) {
-        this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
+        this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
     }
 
     public ClassificationInferenceResults(double value,
                                           String classificationLabel,
                                           List<TopClassEntry> topClasses,
-                                          Map<String, Double> featureImportance,
+                                          List<FeatureImportance> featureImportance,
                                           InferenceConfig config) {
         this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
     }
@@ -49,7 +49,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
     private ClassificationInferenceResults(double value,
                                            String classificationLabel,
                                            List<TopClassEntry> topClasses,
-                                           Map<String, Double> featureImportance,
+                                           List<FeatureImportance> featureImportance,
                                            ClassificationConfig classificationConfig) {
         super(value,
             SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
@@ -118,7 +118,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
                 topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
         }
         if (getFeatureImportance().size() > 0) {
-            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
+            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
+                .stream()
+                .map(FeatureImportance::toMap)
+                .collect(Collectors.toList()));
         }
     }
 

+ 97 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java

@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+
+public class FeatureImportance implements Writeable {
+
+    private final Map<String, Double> classImportance;
+    private final double importance;
+    private final String featureName;
+    private static final String IMPORTANCE = "importance";
+    private static final String FEATURE_NAME = "feature_name";
+
+    public static FeatureImportance forRegression(String featureName, double importance) {
+        return new FeatureImportance(featureName, importance, null);
+    }
+
+    public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
+        return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
+    }
+
+    private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
+        this.featureName = Objects.requireNonNull(featureName);
+        this.importance = importance;
+        this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
+    }
+
+    public FeatureImportance(StreamInput in) throws IOException {
+        this.featureName = in.readString();
+        this.importance = in.readDouble();
+        if (in.readBoolean()) {
+            this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
+        } else {
+            this.classImportance = null;
+        }
+    }
+
+    public Map<String, Double> getClassImportance() {
+        return classImportance;
+    }
+
+    public double getImportance() {
+        return importance;
+    }
+
+    public String getFeatureName() {
+        return featureName;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(this.featureName);
+        out.writeDouble(this.importance);
+        out.writeBoolean(this.classImportance != null);
+        if (this.classImportance != null) {
+            out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
+        }
+    }
+
+    public Map<String, Object> toMap() {
+        Map<String, Object> map = new LinkedHashMap<>();
+        map.put(FEATURE_NAME, featureName);
+        map.put(IMPORTANCE, importance);
+        if (classImportance != null) {
+            classImportance.forEach(map::put);
+        }
+        return map;
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        FeatureImportance that = (FeatureImportance) object;
+        return Objects.equals(featureName, that.featureName)
+            && Objects.equals(importance, that.importance)
+            && Objects.equals(classImportance, that.classImportance);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureName, importance, classImportance);
+    }
+
+}

+ 3 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

@@ -18,9 +18,9 @@ public class RawInferenceResults implements InferenceResults {
     public static final String NAME = "raw";
 
     private final double[] value;
-    private final Map<String, Double> featureImportance;
+    private final Map<String, double[]> featureImportance;
 
-    public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
+    public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
         this.value = value;
         this.featureImportance = featureImportance;
     }
@@ -29,7 +29,7 @@ public class RawInferenceResults implements InferenceResults {
         return value;
     }
 
-    public Map<String, Double> getFeatureImportance() {
+    public Map<String, double[]> getFeatureImportance() {
         return featureImportance;
     }
 

+ 9 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

@@ -14,8 +14,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.Map;
+import java.util.List;
 import java.util.Objects;
+import java.util.stream.Collectors;
 
 public class RegressionInferenceResults extends SingleValueInferenceResults {
 
@@ -24,14 +25,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     private final String resultsField;
 
     public RegressionInferenceResults(double value, InferenceConfig config) {
-        this(value, (RegressionConfig) config, Collections.emptyMap());
+        this(value, (RegressionConfig) config, Collections.emptyList());
     }
 
-    public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
+    public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
         this(value, (RegressionConfig)config, featureImportance);
     }
 
-    private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
+    private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
         super(value,
             SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
                 regressionConfig.getNumTopFeatureImportanceValues()));
@@ -70,7 +71,10 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
         ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
         document.setFieldValue(parentResultField + "." + this.resultsField, value());
         if (getFeatureImportance().size() > 0) {
-            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
+            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
+                .stream()
+                .map(FeatureImportance::toMap)
+                .collect(Collectors.toList()));
         }
     }
 

+ 16 - 15
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java

@@ -8,45 +8,46 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.LinkedHashMap;
-import java.util.Map;
+import java.util.List;
+import java.util.stream.Collectors;
 
 public abstract class SingleValueInferenceResults implements InferenceResults {
 
     private final double value;
-    private final Map<String, Double> featureImportance;
+    private final List<FeatureImportance> featureImportance;
 
-    static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
-        return unsortedFeatureImportances.entrySet()
-            .stream()
-            .sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
+    static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
+        if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
+            return unsortedFeatureImportances;
+        }
+        return unsortedFeatureImportances.stream()
+            .sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
             .limit(numTopFeatures)
-            .collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
+            .collect(Collectors.toList());
     }
 
     SingleValueInferenceResults(StreamInput in) throws IOException {
         value = in.readDouble();
         if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
-            this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
+            this.featureImportance = in.readList(FeatureImportance::new);
         } else {
-            this.featureImportance = Collections.emptyMap();
+            this.featureImportance = Collections.emptyList();
         }
     }
 
-    SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
+    SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
         this.value = value;
-        this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
+        this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
     }
 
     public Double value() {
         return value;
     }
 
-    public Map<String, Double> getFeatureImportance() {
+    public List<FeatureImportance> getFeatureImportance() {
         return featureImportance;
     }
 
@@ -58,7 +59,7 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
     public void writeTo(StreamOutput out) throws IOException {
         out.writeDouble(value);
         if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
-            out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
+            out.writeList(this.featureImportance);
         }
     }
 

+ 35 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -8,12 +8,14 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -100,18 +102,46 @@ public final class InferenceHelpers {
         return null;
     }
 
-    public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
-                                                               Map<String, Double> featureImportances) {
+    public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
+                                                                 Map<String, double[]> featureImportances) {
         if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
             return featureImportances;
         }
 
-        Map<String, Double> originalFeatureImportance = new HashMap<>();
+        Map<String, double[]> originalFeatureImportance = new HashMap<>();
         featureImportances.forEach((feature, importance) -> {
             String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
-            originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
+            originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : sumDoubleArrays(importance, v1));
         });
-
         return originalFeatureImportance;
     }
+
+    public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
+                                                                     @Nullable List<String> classificationLabels) {
+        List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
+        featureImportance.forEach((k, v) -> {
+            // This indicates regression, or logistic regression
+            // If the length > 1, we assume multi-class classification.
+            if (v.length == 1) {
+                importances.add(FeatureImportance.forRegression(k, v[0]));
+            } else {
+                Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
+                // If the classificationLabels exist, their length must match leaf_value length
+                assert classificationLabels == null || classificationLabels.size() == v.length;
+                for (int i = 0; i < v.length; i++) {
+                    classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
+                }
+                importances.add(FeatureImportance.forClassification(k, classImportance));
+            }
+        });
+        return importances;
+    }
+
+    public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
+        assert sumTo != null && inc != null && sumTo.length == inc.length;
+        for (int i = 0; i < inc.length; i++) {
+            sumTo[i] += inc[i];
+        }
+        return sumTo;
+    }
 }

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

@@ -60,9 +60,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
      * NOTE: Must be thread safe
      * @param fields The fields inferring against
      * @param featureDecoder A Map translating processed feature names to their original feature names
-     * @return A {@code Map<String, Double>} mapping each featureName to its importance
+     * @return A {@code Map<String, double[]>} mapping each featureName to its importance
      */
-    Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
+    Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
 
     default Version getMinimalCompatibilityVersion() {
         return Version.V_7_6_0;

+ 22 - 15
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

@@ -45,6 +45,8 @@ import java.util.OptionalDouble;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
 
 public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
 
@@ -139,7 +141,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
                 "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
         }
         double[][] inferenceResults = new double[this.models.size()][];
-        List<Map<String, Double>> featureInfluence = new ArrayList<>();
+        List<Map<String, double[]>> featureInfluence = new ArrayList<>();
         int i = 0;
         NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
         for (TrainedModel model : models) {
@@ -152,7 +154,9 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
             }
         }
         double[] processed = outputAggregator.processValues(inferenceResults);
-        return buildResults(processed, featureInfluence, config, featureDecoderMap);
+        return buildResults(processed,
+            decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
+            config);
     }
 
     @Override
@@ -161,19 +165,19 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
     }
 
     private InferenceResults buildResults(double[] processedInferences,
-                                          List<Map<String, Double>> featureInfluence,
-                                          InferenceConfig config,
-                                          Map<String, String> featureDecoderMap) {
+                                          Map<String, double[]> featureInfluence,
+                                          InferenceConfig config) {
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
-            return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
-                InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
+            return new RawInferenceResults(
+                new double[] {outputAggregator.aggregate(processedInferences)},
+                featureInfluence);
         }
         switch(targetType) {
             case REGRESSION:
                 return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
                     config,
-                    InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
+                    transformFeatureImportance(featureInfluence, null));
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
                 assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@@ -186,7 +190,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
                 return new ClassificationInferenceResults((double)topClasses.v1(),
                     classificationLabel(topClasses.v1(), classificationLabels),
                     topClasses.v2(),
-                    InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
+                    transformFeatureImportance(featureInfluence, classificationLabels),
                     config);
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
@@ -313,20 +317,23 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
         return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
     }
 
-    Map<String, Double> featureImportance(Map<String, Object> fields) {
+    Map<String, double[]> featureImportance(Map<String, Object> fields) {
         return featureImportance(fields, Collections.emptyMap());
     }
 
     @Override
-    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
-        Map<String, Double> collapsed = mergeFeatureImportances(models.stream()
+    public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+        Map<String, double[]> collapsed = mergeFeatureImportances(models.stream()
             .map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
             .collect(Collectors.toList()));
-        return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed);
+        return decodeFeatureImportances(featureDecoder, collapsed);
     }
 
-    private static Map<String, Double> mergeFeatureImportances(List<Map<String, Double>> featureImportances) {
-        return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll);
+    private static Map<String, double[]> mergeFeatureImportances(List<Map<String, double[]>> featureImportances) {
+        return featureImportances.stream()
+            .collect(HashMap::new,
+                (a, b) -> b.forEach((k, v) -> a.merge(k, v, InferenceHelpers::sumDoubleArrays)),
+                Map::putAll);
     }
 
     public static Builder builder() {

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

@@ -142,7 +142,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
         return new ClassificationInferenceResults(topClasses.v1(),
             LANGUAGE_NAMES.get(topClasses.v1()),
             topClasses.v2(),
-            Collections.emptyMap(),
+            Collections.emptyList(),
             classificationConfig);
     }
 
@@ -170,7 +170,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
     }
 
     @Override
-    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+    public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
         throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
     }
 

+ 34 - 39
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

@@ -91,8 +91,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     private final List<String> classificationLabels;
     private final CachedSupplier<Double> highestOrderCategory;
     // populated lazily when feature importance is calculated
-    private double[] nodeEstimates;
     private Integer maxDepth;
+    private Integer leafSize;
 
     Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
         this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@@ -137,7 +137,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
             .collect(Collectors.toList());
 
-        Map<String, Double> featureImportance = config.requestingImportance() ?
+        Map<String, double[]> featureImportance = config.requestingImportance() ?
             featureImportance(features, featureDecoderMap) :
             Collections.emptyMap();
 
@@ -149,7 +149,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return buildResult(node.getLeafValue(), featureImportance, config);
     }
 
-    private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
+    private InferenceResults buildResult(double[] value, Map<String, double[]> featureImportance, InferenceConfig config) {
         assert value != null && value.length > 0;
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
@@ -166,10 +166,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
                 return new ClassificationInferenceResults(topClasses.v1(),
                     classificationLabel(topClasses.v1(), classificationLabels),
                     topClasses.v2(),
-                    featureImportance,
+                    InferenceHelpers.transformFeatureImportance(featureImportance, classificationLabels),
                     config);
             case REGRESSION:
-                return new RegressionInferenceResults(value[0], config, featureImportance);
+                return new RegressionInferenceResults(value[0],
+                    config,
+                    InferenceHelpers.transformFeatureImportance(featureImportance, null));
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
         }
@@ -283,7 +285,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     }
 
     @Override
-    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+    public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
         if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
             throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
         }
@@ -293,9 +295,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return featureImportance(features, featureDecoder);
     }
 
-    private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
-        calculateNodeEstimatesIfNeeded();
-        double[] featureImportance = new double[fieldValues.size()];
+    private Map<String, double[]> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
+        calculateDepthAndLeafValueSize();
+        double[][] featureImportance = new double[fieldValues.size()][leafSize];
+        for (int i = 0; i < fieldValues.size(); i++) {
+            featureImportance[i] = new double[leafSize];
+        }
         int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
         ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
         for (int i = 0; i < arrSize; i++) {
@@ -303,24 +308,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         }
         double[] scale = new double[arrSize];
         ShapPath initialPath = new ShapPath(elements, scale);
-        shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
+        shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
         return InferenceHelpers.decodeFeatureImportances(featureDecoder,
             IntStream.range(0, featureImportance.length)
                 .boxed()
                 .collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
     }
 
-    private void calculateNodeEstimatesIfNeeded() {
-        if (this.nodeEstimates != null && this.maxDepth != null) {
+    private void calculateDepthAndLeafValueSize() {
+        if (this.maxDepth != null && this.leafSize != null) {
             return;
         }
         synchronized (this) {
-            if (this.nodeEstimates != null && this.maxDepth != null) {
+            if (this.maxDepth != null && this.leafSize != null) {
                 return;
             }
-            double[] estimates = new double[nodes.size()];
-            this.maxDepth = fillNodeEstimates(estimates, 0, 0);
-            this.nodeEstimates = estimates;
+            this.maxDepth = getDepth(0, 0);
         }
     }
 
@@ -331,23 +334,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
      * side first and then ported to the Java side.
      */
     private void shapRecursive(List<Double> processedFeatures,
-                               double[] nodeValues,
                                ShapPath parentSplitPath,
                                int nodeIndex,
                                double parentFractionZero,
                                double parentFractionOne,
                                int parentFeatureIndex,
-                               double[] featureImportance,
+                               double[][] featureImportance,
                                int nextIndex) {
         ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
         TreeNode currNode = nodes.get(nodeIndex);
         nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
         if (currNode.isLeaf()) {
-            double leafValue = nodeValues[nodeIndex];
+            double[] leafValue = currNode.getLeafValue();
             for (int i = 1; i < nextIndex; ++i) {
-                double scale = splitPath.sumUnwoundPath(i, nextIndex);
                 int inputColumnIndex = splitPath.featureIndex(i);
-                featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
+                double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
+                for (int j = 0; j < leafValue.length; j++) {
+                    featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
+                }
             }
         } else {
             int hotIndex = currNode.compare(processedFeatures);
@@ -365,41 +369,32 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 
             double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
             double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
-            shapRecursive(processedFeatures, nodeValues, splitPath,
+            shapRecursive(processedFeatures, splitPath,
                 hotIndex, incomingFractionZero * hotFractionZero,
                 incomingFractionOne, splitFeature, featureImportance, nextIndex);
-            shapRecursive(processedFeatures, nodeValues, splitPath,
+            shapRecursive(processedFeatures, splitPath,
                 coldIndex, incomingFractionZero * coldFractionZero,
                 0.0, splitFeature, featureImportance, nextIndex);
         }
     }
 
     /**
-     * This recursively populates the provided {@code double[]} with the node estimated values
+     * Gets the depth of the tree and sets leafSize if it is null
      *
-     * Used when calculating feature importance.
-     * @param nodeEstimates Array to update in place with the node estimated values
      * @param nodeIndex Current node index
      * @param depth Current depth
      * @return The current max depth
      */
-    private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
+    private int getDepth(int nodeIndex, int depth) {
         TreeNode node = nodes.get(nodeIndex);
         if (node.isLeaf()) {
-            // TODO multi-value????
-            nodeEstimates[nodeIndex] = node.getLeafValue()[0];
+            if (leafSize == null) {
+                this.leafSize = node.getLeafValue().length;
+            }
             return 0;
         }
-
-        int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
-        int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
-        long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
-        long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
-        long divisor = leftWeight + rightWeight;
-        double averageValue = divisor == 0 ?
-            0.0 :
-            (leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
-        nodeEstimates[nodeIndex] = averageValue;
+        int depthLeft = getDepth(node.getLeftChild(), depth + 1);
+        int depthRight = getDepth(node.getRightChild(), depth + 1);
         return Math.max(depthLeft, depthRight) + 1;
     }
 

+ 44 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java

@@ -16,20 +16,30 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 
 public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {
 
     public static ClassificationInferenceResults createRandomResults() {
+        Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
+            FeatureImportanceTests::randomClassification :
+            FeatureImportanceTests::randomRegression;
+
         return new ClassificationInferenceResults(randomDouble(),
             randomBoolean() ? null : randomAlphaOfLength(10),
             randomBoolean() ? null :
                 Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
                     .limit(randomIntBetween(0, 10))
                     .collect(Collectors.toList()),
+            randomBoolean() ? null :
+                Stream.generate(featureImportanceCtor)
+                    .limit(randomIntBetween(1, 10))
+                    .collect(Collectors.toList()),
             ClassificationConfigTests.randomClassificationConfig());
     }
 
@@ -81,6 +91,40 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
         assertThat(document.getFieldValue("result_field.my_results", String.class), equalTo("foo"));
     }
 
+    public void testWriteResultsWithImportance() {
+        Supplier<FeatureImportance> featureImportanceCtor = randomBoolean() ?
+            FeatureImportanceTests::randomClassification :
+            FeatureImportanceTests::randomRegression;
+
+        List<FeatureImportance> importanceList = Stream.generate(featureImportanceCtor)
+            .limit(5)
+            .collect(Collectors.toList());
+        ClassificationInferenceResults result = new ClassificationInferenceResults(0.0,
+            "foo",
+            Collections.emptyList(),
+            importanceList,
+            new ClassificationConfig(0, "predicted_value", "top_classes", 3));
+        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
+        result.writeResult(document, "result_field");
+
+        assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
+        @SuppressWarnings("unchecked")
+        List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
+            "result_field.feature_importance",
+            List.class);
+        assertThat(writtenImportance, hasSize(3));
+        importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
+        for (int i = 0; i < 3; i++) {
+            Map<String, Object> objectMap = writtenImportance.get(i);
+            FeatureImportance importance = importanceList.get(i);
+            assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
+            assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
+            if (importance.getClassImportance() != null) {
+                importance.getClassImportance().forEach((k, v) -> assertThat(objectMap.get(k), equalTo(v)));
+            }
+        }
+    }
+
     @Override
     protected ClassificationInferenceResults createTestInstance() {
         return createRandomResults();

+ 44 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java

@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class FeatureImportanceTests extends AbstractWireSerializingTestCase<FeatureImportance> {
+
+    public static FeatureImportance createRandomInstance() {
+        return randomBoolean() ? randomClassification() : randomRegression();
+    }
+
+    static FeatureImportance randomRegression() {
+        return FeatureImportance.forRegression(randomAlphaOfLength(10), randomDoubleBetween(-10.0, 10.0, false));
+    }
+
+    static FeatureImportance randomClassification() {
+        return FeatureImportance.forClassification(
+            randomAlphaOfLength(10),
+            Stream.generate(() -> randomAlphaOfLength(10))
+                .limit(randomLongBetween(2, 10))
+                .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
+
+    }
+
+    @Override
+    protected FeatureImportance createTestInstance() {
+        return createRandomInstance();
+    }
+
+    @Override
+    protected Writeable.Reader<FeatureImportance> instanceReader() {
+        return FeatureImportance::new;
+    }
+}

+ 5 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java

@@ -22,7 +22,8 @@ public class RawInferenceResultsTests extends ESTestCase {
         for (int i = 0; i < n; i++) {
             results[i] = randomDouble();
         }
-        return new RawInferenceResults(results, randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
+        return new RawInferenceResults(results,
+            randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", new double[]{1.08}));
     }
 
     public void testEqualityAndHashcode() {
@@ -31,7 +32,9 @@ public class RawInferenceResultsTests extends ESTestCase {
         for (int i = 0; i < n; i++) {
             results[i] = randomDouble();
         }
-        Map<String, Double> importance = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08);
+        Map<String, double[]> importance = randomBoolean() ?
+            Collections.emptyMap() :
+            Collections.singletonMap("foo", new double[]{1.08, 42.0});
         RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
         RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
         assertThat(lft, equalTo(rgt));

+ 37 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java

@@ -8,19 +8,28 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
 
 import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 
 
 public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase<RegressionInferenceResults> {
 
     public static RegressionInferenceResults createRandomResults() {
-        return new RegressionInferenceResults(randomDouble(), RegressionConfigTests.randomRegressionConfig());
+        return new RegressionInferenceResults(randomDouble(),
+            RegressionConfigTests.randomRegressionConfig(),
+            randomBoolean() ? null :
+                Stream.generate(FeatureImportanceTests::randomRegression)
+                    .limit(randomIntBetween(1, 10))
+                    .collect(Collectors.toList()));
     }
 
     public void testWriteResults() {
@@ -31,6 +40,32 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
         assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
     }
 
+    public void testWriteResultsWithImportance() {
+        List<FeatureImportance> importanceList = Stream.generate(FeatureImportanceTests::randomRegression)
+            .limit(5)
+            .collect(Collectors.toList());
+        RegressionInferenceResults result = new RegressionInferenceResults(0.3,
+            new RegressionConfig("predicted_value", 3),
+            importanceList);
+        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
+        result.writeResult(document, "result_field");
+
+        assertThat(document.getFieldValue("result_field.predicted_value", Double.class), equalTo(0.3));
+        @SuppressWarnings("unchecked")
+        List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
+            "result_field.feature_importance",
+            List.class);
+        assertThat(writtenImportance, hasSize(3));
+        importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
+        for (int i = 0; i < 3; i++) {
+            Map<String, Object> objectMap = writtenImportance.get(i);
+            FeatureImportance importance = importanceList.get(i);
+            assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
+            assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
+            assertThat(objectMap.size(), equalTo(2));
+        }
+    }
+
     @Override
     protected RegressionInferenceResults createTestInstance() {
         return createRandomResults();

+ 21 - 21
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -684,45 +684,45 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             .build();
 
 
-        Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        Map<String, double[]> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
-        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
-        assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
-        assertThat(featureImportance.get("bar"), closeTo( -0.12444978, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(0.0798679, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo( -0.12444978, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
-        assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
-        assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(1.80491886, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(-0.4355742, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
-        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
-        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
-        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
-        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
 
         featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
-        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
-        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(0.1451914, eps));
     }
 
 

+ 9 - 9
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

@@ -390,22 +390,22 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
                 TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
                 TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
 
-        Map<String, Double> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
+        Map<String, double[]> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
             Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));
 
         featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(-5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));
 
         featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(-2.5, eps));
 
         featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
-        assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
-        assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
+        assertThat(featureImportance.get("foo")[0], closeTo(5.0, eps));
+        assertThat(featureImportance.get("bar")[0], closeTo(2.5, eps));
     }
 
     public void testMaxFeatureIndex() {

+ 8 - 4
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

@@ -156,8 +156,10 @@ public class InferenceIngestIT extends ESRestTestCase {
         String responseString = EntityUtils.toString(response.getEntity());
         assertThat(responseString, containsString("\"predicted_value\":\"second\""));
         assertThat(responseString, containsString("\"predicted_value\":1.0"));
-        assertThat(responseString, containsString("\"col2\":0.944"));
-        assertThat(responseString, containsString("\"col1\":0.19999"));
+        assertThat(responseString, containsString("\"feature_name\":\"col1\""));
+        assertThat(responseString, containsString("\"feature_name\":\"col2\""));
+        assertThat(responseString, containsString("\"importance\":0.944"));
+        assertThat(responseString, containsString("\"importance\":0.19999"));
 
         String sourceWithMissingModel = "{\n" +
             "  \"pipeline\": {\n" +
@@ -221,8 +223,10 @@ public class InferenceIngestIT extends ESRestTestCase {
         Response response = client().performRequest(simulateRequest(source));
         String responseString = EntityUtils.toString(response.getEntity());
         assertThat(responseString, containsString("\"predicted_value\":\"second\""));
-        assertThat(responseString, containsString("\"col2\":0.944"));
-        assertThat(responseString, containsString("\"col1\":0.19999"));
+        assertThat(responseString, containsString("\"feature_name\":\"col1\""));
+        assertThat(responseString, containsString("\"feature_name\":\"col2\""));
+        assertThat(responseString, containsString("\"importance\":0.944"));
+        assertThat(responseString, containsString("\"importance\":0.19999"));
     }
 
     public void testSimulateLangIdent() throws IOException {

+ 15 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

@@ -10,6 +10,7 @@ import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@@ -120,9 +121,9 @@ public class InferenceProcessorTests extends ESTestCase {
         classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
         classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
 
-        Map<String, Double> featureInfluence = new HashMap<>();
-        featureInfluence.put("feature_1", 1.13);
-        featureInfluence.put("feature_2", -42.0);
+        List<FeatureImportance> featureInfluence = new ArrayList<>();
+        featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
+        featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new ClassificationInferenceResults(1.0,
@@ -135,8 +136,10 @@ public class InferenceProcessorTests extends ESTestCase {
 
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
         assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
     }
 
     @SuppressWarnings("unchecked")
@@ -205,9 +208,9 @@ public class InferenceProcessorTests extends ESTestCase {
         Map<String, Object> ingestMetadata = new HashMap<>();
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
-        Map<String, Double> featureInfluence = new HashMap<>();
-        featureInfluence.put("feature_1", 1.13);
-        featureInfluence.put("feature_2", -42.0);
+        List<FeatureImportance> featureInfluence = new ArrayList<>();
+        featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
+        featureInfluence.add(FeatureImportance.forRegression("feature_2", -42.0));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
@@ -215,8 +218,10 @@ public class InferenceProcessorTests extends ESTestCase {
 
         assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
-        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.importance", Double.class), equalTo(-42.0));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.0.feature_name", String.class), equalTo("feature_2"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.importance", Double.class), equalTo(1.13));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.1.feature_name", String.class), equalTo("feature_1"));
     }
 
     public void testGenerateRequestWithEmptyMapping() {