
[ML] Adds feature importance option to inference processor (#52218)

This adds machine learning model feature importance calculations to the inference processor. 

The new configuration flag matches the data frame analytics parameter name: `num_top_feature_importance_values`
Example:
```
"inference": {
   "field_mappings": {},
   "model_id": "my_model",
   "inference_config": {
      "regression": {
         "num_top_feature_importance_values": 3
      }
   }
}
```

This will write to the document as follows:
```
"inference" : {
   "feature_importance" : { 
      "FlightTimeMin" : -76.90955548511226,
      "FlightDelayType" : 114.13514762158526,
      "DistanceMiles" : 13.731580450792187
   },
   "predicted_value" : 108.33165831875137,
   "model_id" : "my_model"
}
```
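
Internally the option parses into the matching inference config object. A minimal runnable sketch of the regression case, assuming the `RegressionConfig` changes from this PR are on the classpath:
```java
import java.util.HashMap;
import java.util.Map;

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;

public class RegressionConfigSketch {
    public static void main(String[] args) {
        // Options as they would arrive from the processor's inference_config block.
        Map<String, Object> options = new HashMap<>();
        options.put("num_top_feature_importance_values", 3);

        RegressionConfig config = RegressionConfig.fromMap(options);
        System.out.println(config.getNumTopFeatureImportanceValues()); // 3
        // Any value greater than zero switches the calculation on.
        System.out.println(config.requestingImportance());             // true
    }
}
```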

This is done by calculating the [SHAP values](https://arxiv.org/abs/1802.03888).
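
For orientation (standard SHAP background, not specific to this PR): SHAP values decompose a prediction additively,
```latex
f(x) = \phi_0 + \sum_{i=1}^{M} \phi_i(x)
```
where `\phi_0` is the model bias and `\phi_i` is the importance attributed to feature `i` for this document. Note that the document above only carries the top `num_top_feature_importance_values` entries ranked by absolute magnitude, so the displayed values alone need not sum to `predicted_value`.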

It requires that models have `number_samples` populated for each tree node, which is not available for models created before 7.7.

Additionally, if the inference config requests feature importance and not all nodes in the cluster have been upgraded yet, pipeline creation is rejected. This safeguards mixed-version environments where only some ingest nodes have been upgraded.
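
Concretely, that check falls out of the version gating added to both configs (see the `RegressionConfig` and `ClassificationConfig` diffs below): requesting any feature importance values raises the config's minimal supported version, which pipeline validation compares against the cluster.
```java
@Override
public Version getMinimalSupportedVersion() {
    return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
}
```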

NOTE: the algorithm is a Java port of the one laid out in ml-cpp: https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc

usability blocked by: https://github.com/elastic/ml-cpp/pull/991
Benjamin Trent, 5 years ago
parent commit 20f54272f0
29 changed files with 980 additions and 104 deletions
  1. docs/reference/ingest/processors/inference.asciidoc (+12 -0)
  2. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java (+26 -2)
  3. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java (+6 -0)
  4. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java (+5 -0)
  5. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java (+7 -0)
  6. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java (+5 -0)
  7. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java (+5 -0)
  8. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java (+29 -9)
  9. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java (+7 -5)
  10. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java (+21 -6)
  11. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java (+28 -1)
  12. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java (+45 -9)
  13. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java (+4 -0)
  14. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java (+17 -0)
  15. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java (+12 -3)
  16. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java (+40 -6)
  17. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ShapPath.java (+162 -0)
  18. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java (+21 -1)
  19. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java (+46 -10)
  20. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java (+17 -1)
  21. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java (+142 -9)
  22. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java (+3 -7)
  23. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java (+3 -1)
  24. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java (+2 -1)
  25. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java (+2 -1)
  26. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java (+155 -19)
  27. x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java (+75 -12)
  28. x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java (+18 -1)
  29. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java (+65 -0)

+ 12 - 0
docs/reference/ingest/processors/inference.asciidoc

@@ -44,6 +44,12 @@ include::common-options.asciidoc[]
 Specifies the field to which the inference prediction is written. Defaults to 
 `predicted_value`.
 
+`num_top_feature_importance_values`::::
+(Optional, integer)
+Specifies the maximum number of
+{ml-docs}/dfa-regression.html#dfa-regression-feature-importance[feature
+importance] values per document. By default, it is zero and no feature importance
+calculation occurs.
 
 [discrete]
 [[inference-processor-classification-opt]]
@@ -63,6 +69,12 @@ Specifies the number of top class predictions to return. Defaults to 0.
 Specifies the field to which the top classes are written. Defaults to 
 `top_classes`.
 
+`num_top_feature_importance_values`::::
+(Optional, integer)
+Specifies the maximum number of
+{ml-docs}/dfa-classification.html#dfa-classification-feature-importance[feature
+importance] values per document. By default, it is zero and no feature importance
+calculation occurs.
 
 [discrete]
 [[inference-processor-config-example]]

+ 26 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

@@ -32,6 +32,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -73,6 +74,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
 
     private final TrainedModel trainedModel;
     private final List<PreProcessor> preProcessors;
+    private Map<String, String> decoderMap;
 
     private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
         this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
@@ -115,13 +117,35 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
         return preProcessors;
     }
 
-    private void preProcess(Map<String, Object> fields) {
+    void preProcess(Map<String, Object> fields) {
         preProcessors.forEach(preProcessor -> preProcessor.process(fields));
     }
 
     public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
         preProcess(fields);
-        return trainedModel.infer(fields, config);
+        if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
+            throw ExceptionsHelper.badRequestException(
+                "Feature importance is not supported for the configured model of type [{}]",
+                trainedModel.getName());
+        }
+        return trainedModel.infer(fields,
+            config,
+            config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
+    }
+
+    private Map<String, String> getDecoderMap() {
+        if (decoderMap != null) {
+            return decoderMap;
+        }
+        synchronized (this) {
+            if (decoderMap != null) {
+                return decoderMap;
+            }
+            this.decoderMap = preProcessors.stream()
+                .map(PreProcessor::reverseLookup)
+                .collect(HashMap::new, Map::putAll, Map::putAll);
+            return decoderMap;
+        }
     }
 
     @Override

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/CustomWordEmbedding.java

@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.customwordembeddi
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -235,6 +236,11 @@ public class CustomWordEmbedding implements LenientlyParsedPreProcessor, Strictl
         fields.put(destField, concatEmbeddings(processedFeatures));
     }
 
+    @Override
+    public Map<String, String> reverseLookup() {
+        return Collections.singletonMap(destField, fieldName);
+    }
+
     @Override
     public long ramBytesUsed() {
         long size = SHALLOW_SIZE;

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java

@@ -97,6 +97,11 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
         return featureName;
     }
 
+    @Override
+    public Map<String, String> reverseLookup() {
+        return Collections.singletonMap(featureName, field);
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java

@@ -18,8 +18,10 @@ import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
 
 /**
  * PreProcessor for one hot encoding a set of categorical values for a given field.
@@ -80,6 +82,11 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
         return hotMap;
     }
 
+    @Override
+    public Map<String, String> reverseLookup() {
+        return hotMap.entrySet().stream().collect(Collectors.toMap(HashMap.Entry::getValue, (entry) -> field));
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java

@@ -24,4 +24,9 @@ public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accou
      * @param fields The fields and their values to process
      */
     void process(Map<String, Object> fields);
+
+    /**
+     * @return Reverse lookup map to match resulting features to their original feature name
+     */
+    Map<String, String> reverseLookup();
 }
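
For intuition, a hedged sketch of the map a one-hot encoder's `reverseLookup` yields; the field and category names are illustrative only:
```java
import java.util.HashMap;
import java.util.Map;

public class ReverseLookupSketch {
    public static void main(String[] args) {
        // A one-hot encoder's hotMap: category value -> derived feature name.
        Map<String, String> hotMap = new HashMap<>();
        hotMap.put("Weather", "FlightDelayType_Weather");
        hotMap.put("Carrier", "FlightDelayType_Carrier");

        // reverseLookup() inverts that: each derived feature name points back
        // to the original field, so importances can be re-attributed to it.
        Map<String, String> reverse = new HashMap<>();
        hotMap.forEach((category, feature) -> reverse.put(feature, "FlightDelayType"));

        // e.g. {FlightDelayType_Weather=FlightDelayType, FlightDelayType_Carrier=FlightDelayType}
        System.out.println(reverse);
    }
}
```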

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java

@@ -108,6 +108,11 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
         return featureName;
     }
 
+    @Override
+    public Map<String, String> reverseLookup() {
+        return Collections.singletonMap(featureName, field);
+    }
+
     @Override
     public String getName() {
         return NAME.getPreferredName();

+ 29 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

@@ -35,9 +35,25 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
                                           String classificationLabel,
                                           List<TopClassEntry> topClasses,
                                           InferenceConfig config) {
-        super(value);
-        assert config instanceof ClassificationConfig;
-        ClassificationConfig classificationConfig = (ClassificationConfig)config;
+        this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
+    }
+
+    public ClassificationInferenceResults(double value,
+                                          String classificationLabel,
+                                          List<TopClassEntry> topClasses,
+                                          Map<String, Double> featureImportance,
+                                          InferenceConfig config) {
+        this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
+    }
+
+    private ClassificationInferenceResults(double value,
+                                           String classificationLabel,
+                                           List<TopClassEntry> topClasses,
+                                           Map<String, Double> featureImportance,
+                                           ClassificationConfig classificationConfig) {
+        super(value,
+            SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
+                classificationConfig.getNumTopFeatureImportanceValues()));
         this.classificationLabel = classificationLabel;
         this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
         this.topNumClassesField = classificationConfig.getTopClassesResultsField();
@@ -74,16 +90,17 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         if (object == this) { return true; }
         if (object == null || getClass() != object.getClass()) { return false; }
         ClassificationInferenceResults that = (ClassificationInferenceResults) object;
-        return Objects.equals(value(), that.value()) &&
-            Objects.equals(classificationLabel, that.classificationLabel) &&
-            Objects.equals(resultsField, that.resultsField) &&
-            Objects.equals(topNumClassesField, that.topNumClassesField) &&
-            Objects.equals(topClasses, that.topClasses);
+        return Objects.equals(value(), that.value())
+            && Objects.equals(classificationLabel, that.classificationLabel)
+            && Objects.equals(resultsField, that.resultsField)
+            && Objects.equals(topNumClassesField, that.topNumClassesField)
+            && Objects.equals(topClasses, that.topClasses)
+            && Objects.equals(getFeatureImportance(), that.getFeatureImportance());
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField);
+        return Objects.hash(value(), classificationLabel, topClasses, resultsField, topNumClassesField, getFeatureImportance());
     }
 
     @Override
@@ -100,6 +117,9 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
             document.setFieldValue(parentResultField + "." + topNumClassesField,
                 topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
         }
+        if (getFeatureImportance().size() > 0) {
+            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
+        }
     }
 
     @Override

+ 7 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

@@ -10,18 +10,19 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.ingest.IngestDocument;
 
 import java.io.IOException;
+import java.util.Map;
 import java.util.Objects;
 
 public class RawInferenceResults extends SingleValueInferenceResults {
 
     public static final String NAME = "raw";
 
-    public RawInferenceResults(double value) {
-        super(value);
+    public RawInferenceResults(double value, Map<String, Double> featureImportance) {
+        super(value, featureImportance);
     }
 
     public RawInferenceResults(StreamInput in) throws IOException {
-        super(in.readDouble());
+        super(in);
     }
 
     @Override
@@ -34,12 +35,13 @@ public class RawInferenceResults extends SingleValueInferenceResults {
         if (object == this) { return true; }
         if (object == null || getClass() != object.getClass()) { return false; }
         RawInferenceResults that = (RawInferenceResults) object;
-        return Objects.equals(value(), that.value());
+        return Objects.equals(value(), that.value())
+            && Objects.equals(getFeatureImportance(), that.getFeatureImportance());
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(value());
+        return Objects.hash(value(), getFeatureImportance());
     }
 
     @Override

+ 21 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

@@ -13,6 +13,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
 import java.util.Objects;
 
 public class RegressionInferenceResults extends SingleValueInferenceResults {
@@ -22,14 +24,22 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     private final String resultsField;
 
     public RegressionInferenceResults(double value, InferenceConfig config) {
-        super(value);
-        assert config instanceof RegressionConfig;
-        RegressionConfig regressionConfig = (RegressionConfig)config;
+        this(value, (RegressionConfig) config, Collections.emptyMap());
+    }
+
+    public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
+        this(value, (RegressionConfig)config, featureImportance);
+    }
+
+    private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
+        super(value,
+            SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
+                regressionConfig.getNumTopFeatureImportanceValues()));
         this.resultsField = regressionConfig.getResultsField();
     }
 
     public RegressionInferenceResults(StreamInput in) throws IOException {
-        super(in.readDouble());
+        super(in);
         this.resultsField = in.readString();
     }
 
@@ -44,12 +54,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
         if (object == this) { return true; }
         if (object == null || getClass() != object.getClass()) { return false; }
         RegressionInferenceResults that = (RegressionInferenceResults) object;
-        return Objects.equals(value(), that.value()) && Objects.equals(this.resultsField, that.resultsField);
+        return Objects.equals(value(), that.value())
+            && Objects.equals(this.resultsField, that.resultsField)
+            && Objects.equals(this.getFeatureImportance(), that.getFeatureImportance());
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(value(), resultsField);
+        return Objects.hash(value(), resultsField, getFeatureImportance());
     }
 
     @Override
@@ -57,6 +69,9 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
         ExceptionsHelper.requireNonNull(document, "document");
         ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
         document.setFieldValue(parentResultField + "." + this.resultsField, value());
+        if (getFeatureImportance().size() > 0) {
+            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
+        }
     }
 
     @Override

+ 28 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java

@@ -5,27 +5,51 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
 
 public abstract class SingleValueInferenceResults implements InferenceResults {
 
     private final double value;
+    private final Map<String, Double> featureImportance;
+
+    static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
+        return unsortedFeatureImportances.entrySet()
+            .stream()
+            .sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
+            .limit(numTopFeatures)
+            .collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
+    }
 
     SingleValueInferenceResults(StreamInput in) throws IOException {
         value = in.readDouble();
+        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+            this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
+        } else {
+            this.featureImportance = Collections.emptyMap();
+        }
     }
 
-    SingleValueInferenceResults(double value) {
+    SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
         this.value = value;
+        this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
     }
 
     public Double value() {
         return value;
     }
 
+    public Map<String, Double> getFeatureImportance() {
+        return featureImportance;
+    }
+
     public String valueAsString() {
         return String.valueOf(value);
     }
@@ -33,6 +57,9 @@ public abstract class SingleValueInferenceResults implements InferenceResults {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeDouble(value);
+        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+            out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
+        }
     }
 
 }
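
To make the truncation concrete: a runnable sketch of the same sort-and-limit pipeline as `takeTopFeatureImportances`, fed with the importance values from the example document above and keeping the top two:
```java
import java.util.LinkedHashMap;
import java.util.Map;

public class TopImportancesDemo {
    public static void main(String[] args) {
        Map<String, Double> importances = new LinkedHashMap<>();
        importances.put("FlightTimeMin", -76.90955548511226);
        importances.put("FlightDelayType", 114.13514762158526);
        importances.put("DistanceMiles", 13.731580450792187);

        // Rank by absolute magnitude, keep the top N, preserve that order.
        Map<String, Double> top2 = importances.entrySet().stream()
            .sorted((l, r) -> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
            .limit(2)
            .collect(LinkedHashMap::new, (m, e) -> m.put(e.getKey(), e.getValue()), LinkedHashMap::putAll);

        // {FlightDelayType=114.13514762158526, FlightTimeMin=-76.90955548511226}
        System.out.println(top2);
    }
}
```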

+ 45 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java

@@ -31,33 +31,39 @@ public class ClassificationConfig implements InferenceConfig {
     public static final ParseField RESULTS_FIELD = new ParseField("results_field");
     public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
+    public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
     private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
 
-    public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD);
+    public static ClassificationConfig EMPTY_PARAMS =
+        new ClassificationConfig(0, DEFAULT_RESULTS_FIELD, DEFAULT_TOP_CLASSES_RESULTS_FIELD, null);
 
     private final int numTopClasses;
     private final String topClassesResultsField;
     private final String resultsField;
+    private final int numTopFeatureImportanceValues;
 
     public static ClassificationConfig fromMap(Map<String, Object> map) {
         Map<String, Object> options = new HashMap<>(map);
         Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
         String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULTS_FIELD.getPreferredName());
         String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
+        Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
+
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
         }
-        return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField);
+        return new ClassificationConfig(numTopClasses, resultsField, topClassesResultsField, featureImportance);
     }
 
     private static final ConstructingObjectParser<ClassificationConfig, Void> PARSER =
             new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ClassificationConfig(
-                    (Integer) args[0], (String) args[1], (String) args[2]));
+                    (Integer) args[0], (String) args[1], (String) args[2], (Integer) args[3]));
 
     static {
         PARSER.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
         PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
         PARSER.declareString(optionalConstructorArg(), TOP_CLASSES_RESULTS_FIELD);
+        PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
     }
 
     public static ClassificationConfig fromXContent(XContentParser parser) {
@@ -65,19 +71,33 @@ public class ClassificationConfig implements InferenceConfig {
     }
 
     public ClassificationConfig(Integer numTopClasses) {
-        this(numTopClasses, null, null);
+        this(numTopClasses, null, null, null);
     }
 
     public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField) {
+        this(numTopClasses, resultsField, topClassesResultsField, 0);
+    }
+
+    public ClassificationConfig(Integer numTopClasses, String resultsField, String topClassesResultsField, Integer featureImportance) {
         this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
         this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULTS_FIELD : topClassesResultsField;
         this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
+        if (featureImportance != null && featureImportance < 0) {
+            throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
+                "] must be greater than or equal to 0");
+        }
+        this.numTopFeatureImportanceValues = featureImportance == null ? 0 : featureImportance;
     }
 
     public ClassificationConfig(StreamInput in) throws IOException {
         this.numTopClasses = in.readInt();
         this.topClassesResultsField = in.readString();
         this.resultsField = in.readString();
+        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+            this.numTopFeatureImportanceValues = in.readVInt();
+        } else {
+            this.numTopFeatureImportanceValues = 0;
+        }
     }
 
     public int getNumTopClasses() {
@@ -92,11 +112,23 @@ public class ClassificationConfig implements InferenceConfig {
         return resultsField;
     }
 
+    public int getNumTopFeatureImportanceValues() {
+        return numTopFeatureImportanceValues;
+    }
+
+    @Override
+    public boolean requestingImportance() {
+        return numTopFeatureImportanceValues > 0;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeInt(numTopClasses);
         out.writeString(topClassesResultsField);
         out.writeString(resultsField);
+        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+            out.writeVInt(numTopFeatureImportanceValues);
+        }
     }
 
     @Override
@@ -104,14 +136,15 @@ public class ClassificationConfig implements InferenceConfig {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         ClassificationConfig that = (ClassificationConfig) o;
-        return Objects.equals(numTopClasses, that.numTopClasses) &&
-            Objects.equals(topClassesResultsField, that.topClassesResultsField) &&
-            Objects.equals(resultsField, that.resultsField);
+        return Objects.equals(numTopClasses, that.numTopClasses)
+            && Objects.equals(topClassesResultsField, that.topClassesResultsField)
+            && Objects.equals(resultsField, that.resultsField)
+            && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(numTopClasses, topClassesResultsField, resultsField);
+        return Objects.hash(numTopClasses, topClassesResultsField, resultsField, numTopFeatureImportanceValues);
     }
 
     @Override
@@ -122,6 +155,9 @@ public class ClassificationConfig implements InferenceConfig {
         }
         builder.field(TOP_CLASSES_RESULTS_FIELD.getPreferredName(), topClassesResultsField);
         builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        if (numTopFeatureImportanceValues > 0) {
+            builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
+        }
         builder.endObject();
         return builder;
     }
@@ -143,7 +179,7 @@ public class ClassificationConfig implements InferenceConfig {
 
     @Override
     public Version getMinimalSupportedVersion() {
-        return MIN_SUPPORTED_VERSION;
+        return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
     }
 
 }

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java

@@ -18,4 +18,8 @@ public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
      * All nodes in the cluster must be at least this version
      */
     Version getMinimalSupportedVersion();
+
+    default boolean requestingImportance() {
+        return false;
+    }
 }

+ 17 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -13,7 +13,9 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -98,4 +100,19 @@ public final class InferenceHelpers {
         }
         return null;
     }
+
+    public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
+                                                               Map<String, Double> featureImportances) {
+        if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
+            return featureImportances;
+        }
+
+        Map<String, Double> originalFeatureImportance = new HashMap<>();
+        featureImportances.forEach((feature, importance) -> {
+            String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
+            originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
+        });
+
+        return originalFeatureImportance;
+    }
 }
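
A worked example of what `decodeFeatureImportances` does with one-hot derived features (hypothetical names and values):
```java
import java.util.HashMap;
import java.util.Map;

public class DecodeImportancesDemo {
    public static void main(String[] args) {
        // SHAP values keyed by processed feature name.
        Map<String, Double> featureImportances = new HashMap<>();
        featureImportances.put("FlightDelayType_Weather", 80.0);
        featureImportances.put("FlightDelayType_Carrier", 34.1);
        featureImportances.put("FlightTimeMin", -76.9);

        // Decoder built from the preprocessors' reverseLookup() maps.
        Map<String, String> decoder = new HashMap<>();
        decoder.put("FlightDelayType_Weather", "FlightDelayType");
        decoder.put("FlightDelayType_Carrier", "FlightDelayType");

        // Importances of features derived from one field sum back onto it;
        // unmapped features keep their own name.
        Map<String, Double> original = new HashMap<>();
        featureImportances.forEach((feature, importance) ->
            original.merge(decoder.getOrDefault(feature, feature), importance, Double::sum));

        System.out.println(original); // FlightDelayType ~= 114.1, FlightTimeMin = -76.9
    }
}
```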

+ 12 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java

@@ -16,9 +16,12 @@ import java.io.IOException;
  */
 public class NullInferenceConfig implements InferenceConfig {
 
-    public static final NullInferenceConfig INSTANCE = new NullInferenceConfig();
+    private final boolean requestingFeatureImportance;
 
-    private NullInferenceConfig() { }
+
+    public NullInferenceConfig(boolean requestingFeatureImportance) {
+        this.requestingFeatureImportance = requestingFeatureImportance;
+    }
 
     @Override
     public boolean isTargetTypeSupported(TargetType targetType) {
@@ -37,6 +40,7 @@ public class NullInferenceConfig implements InferenceConfig {
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        throw new UnsupportedOperationException("Unable to serialize NullInferenceConfig objects");
     }
 
     @Override
@@ -46,6 +50,11 @@ public class NullInferenceConfig implements InferenceConfig {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        return builder;
+        throw new UnsupportedOperationException("Unable to write xcontent from NullInferenceConfig objects");
+    }
+
+    @Override
+    public boolean requestingImportance() {
+        return requestingFeatureImportance;
     }
 }

+ 40 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java

@@ -26,24 +26,27 @@ public class RegressionConfig implements InferenceConfig {
     public static final ParseField NAME = new ParseField("regression");
     private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
     public static final ParseField RESULTS_FIELD = new ParseField("results_field");
+    public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
     private static final String DEFAULT_RESULTS_FIELD = "predicted_value";
 
-    public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD);
+    public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null);
 
     public static RegressionConfig fromMap(Map<String, Object> map) {
         Map<String, Object> options = new HashMap<>(map);
         String resultsField = (String)options.remove(RESULTS_FIELD.getPreferredName());
+        Integer featureImportance = (Integer)options.remove(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName());
         if (options.isEmpty() == false) {
             throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
         }
-        return new RegressionConfig(resultsField);
+        return new RegressionConfig(resultsField, featureImportance);
     }
 
     private static final ConstructingObjectParser<RegressionConfig, Void> PARSER =
-            new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0]));
+            new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new RegressionConfig((String) args[0], (Integer)args[1]));
 
     static {
         PARSER.declareString(optionalConstructorArg(), RESULTS_FIELD);
+        PARSER.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES);
     }
 
     public static RegressionConfig fromXContent(XContentParser parser) {
@@ -51,19 +54,43 @@ public class RegressionConfig implements InferenceConfig {
     }
 
     private final String resultsField;
+    private final int numTopFeatureImportanceValues;
 
     public RegressionConfig(String resultsField) {
+        this(resultsField, 0);
+    }
+
+    public RegressionConfig(String resultsField, Integer numTopFeatureImportanceValues) {
         this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
+        if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) {
+            throw new IllegalArgumentException("[" + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName() +
+                "] must be greater than or equal to 0");
+        }
+        this.numTopFeatureImportanceValues = numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
     }
 
     public RegressionConfig(StreamInput in) throws IOException {
         this.resultsField = in.readString();
+        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+            this.numTopFeatureImportanceValues = in.readVInt();
+        } else {
+            this.numTopFeatureImportanceValues = 0;
+        }
+    }
+
+    public int getNumTopFeatureImportanceValues() {
+        return numTopFeatureImportanceValues;
     }
 
     public String getResultsField() {
         return resultsField;
     }
 
+    @Override
+    public boolean requestingImportance() {
+        return numTopFeatureImportanceValues > 0;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME.getPreferredName();
@@ -72,6 +99,9 @@ public class RegressionConfig implements InferenceConfig {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(resultsField);
+        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+            out.writeVInt(numTopFeatureImportanceValues);
+        }
     }
 
     @Override
@@ -83,6 +113,9 @@ public class RegressionConfig implements InferenceConfig {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        if (numTopFeatureImportanceValues > 0) {
+            builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues);
+        }
         builder.endObject();
         return builder;
     }
@@ -92,12 +125,13 @@ public class RegressionConfig implements InferenceConfig {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         RegressionConfig that = (RegressionConfig)o;
-        return Objects.equals(this.resultsField, that.resultsField);
+        return Objects.equals(this.resultsField, that.resultsField)
+            && Objects.equals(this.numTopFeatureImportanceValues, that.numTopFeatureImportanceValues);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(resultsField);
+        return Objects.hash(resultsField, numTopFeatureImportanceValues);
     }
 
     @Override
@@ -107,7 +141,7 @@ public class RegressionConfig implements InferenceConfig {
 
     @Override
     public Version getMinimalSupportedVersion() {
-        return MIN_SUPPORTED_VERSION;
+        return numTopFeatureImportanceValues > 0 ? Version.V_7_7_0 : MIN_SUPPORTED_VERSION;
     }
 
 }

+ 162 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ShapPath.java

@@ -0,0 +1,162 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+
+/**
+ * Ported from https://github.com/elastic/ml-cpp/blob/master/include/maths/CTreeShapFeatureImportance.h Path struct
+ */
+public class ShapPath  {
+    private static final double DBL_EPSILON = Double.MIN_VALUE;
+
+    private final PathElement[] pathElements;
+    private final double[] scale;
+    private final int elementAndScaleOffset;
+
+    public ShapPath(ShapPath parentPath, int nextIndex) {
+        this.elementAndScaleOffset = parentPath.elementAndScaleOffset + nextIndex;
+        this.pathElements = parentPath.pathElements;
+        this.scale = parentPath.scale;
+        for (int i = 0; i < nextIndex; i++) {
+            pathElements[elementAndScaleOffset + i].featureIndex = parentPath.getElement(i).featureIndex;
+            pathElements[elementAndScaleOffset + i].fractionZeros = parentPath.getElement(i).fractionZeros;
+            pathElements[elementAndScaleOffset + i].fractionOnes = parentPath.getElement(i).fractionOnes;
+            scale[elementAndScaleOffset + i] = parentPath.getScale(i);
+        }
+    }
+
+    public ShapPath(PathElement[] elements, double[] scale) {
+        this.pathElements = elements;
+        this.scale = scale;
+        this.elementAndScaleOffset = 0;
+    }
+
+    // Update binomial coefficients to be able to compute Equation (2) from the paper.  In particular,
+    // we have in the line path.scale[i + 1] += fractionOne * path.scale[i] * (i + 1.0) / (pathDepth +
+    // 1.0) that if we're on the "one" path, i.e. if the last feature selects this path if we include that
+    // feature in S (then fractionOne is 1), and we need to consider all the additional ways we now have of
+    // constructing each S of each given cardinality i + 1. Each of these come by adding the last feature
+    // to sets of size i and we **also** need to scale by the difference in binomial coefficients as both M
+    // increases by one and i increases by one. So we get additive term 1{last feature selects path if in S}
+    // * scale(i) * (i+1)! (M+1-(i+1)-1)!/(M+1)! / (i! (M-i-1)!/ M!), whence += scale(i) * (i+1) / (M+1).
+    public int extend(double fractionZero, double fractionOne, int featureIndex, int nextIndex) {
+        setValues(nextIndex, fractionOne, fractionZero, featureIndex);
+        setScale(nextIndex, nextIndex == 0 ? 1.0 : 0.0);
+        double stepDown = fractionOne / (double)(nextIndex + 1);
+        double stepUp = fractionZero / (double)(nextIndex + 1);
+        double countDown = nextIndex * stepDown;
+        double countUp = stepUp;
+        for (int i = (nextIndex - 1); i >= 0; --i, countDown -= stepDown, countUp += stepUp) {
+            setScale(i + 1, getScale(i + 1) + getScale(i) * countDown);
+            setScale(i, getScale(i) * countUp);
+        }
+        return nextIndex + 1;
+    }
+
+    public double sumUnwoundPath(int pathIndex, int nextIndex) {
+        double total = 0.0;
+        int pathDepth = nextIndex - 1;
+        double nextFractionOne = getScale(pathDepth);
+        double fractionOne = fractionOnes(pathIndex);
+        double fractionZero = fractionZeros(pathIndex);
+        if (fractionOne != 0) {
+            double pD = pathDepth + 1;
+            double stepUp = fractionZero / pD;
+            double stepDown = fractionOne / pD;
+            double countUp = stepUp;
+            double countDown = (pD - 1.0) * stepDown;
+            for (int i = pathDepth - 1; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
+                double tmp = nextFractionOne / countDown;
+                nextFractionOne = getScale(i) - tmp * countUp;
+                total += tmp;
+            }
+        } else {
+            double pD = pathDepth;
+
+            for(int i = 0; i < pathDepth; i++) {
+                total += getScale(i) / pD--;
+            }
+            total *= (pathDepth + 1) / (fractionZero + DBL_EPSILON);
+        }
+
+        return total;
+    }
+
+    public int unwind(int pathIndex, int nextIndex) {
+        int pathDepth = nextIndex - 1;
+        double nextFractionOne = getScale(pathDepth);
+        double fractionOne = fractionOnes(pathIndex);
+        double fractionZero = fractionZeros(pathIndex);
+
+        if (fractionOne != 0) {
+            double stepUp = fractionZero / (double)(pathDepth + 1);
+            double stepDown = fractionOne / (double)nextIndex;
+            double countUp = 0.0;
+            double countDown = nextIndex * stepDown;
+            for (int i = pathDepth; i >= 0; --i, countUp += stepUp, countDown -= stepDown) {
+                double tmp = nextFractionOne / countDown;
+                nextFractionOne = getScale(i) - tmp * countUp;
+                setScale(i, tmp);
+            }
+        } else {
+            double stepDown = (fractionZero + DBL_EPSILON) / (double)(pathDepth + 1);
+            double countDown = pathDepth * stepDown;
+            for (int i = 0; i <= pathDepth; ++i, countDown -= stepDown) {
+                setScale(i, getScale(i) / countDown);
+            }
+        }
+        for (int i = pathIndex; i < pathDepth; ++i) {
+            PathElement element = getElement(i + 1);
+            setValues(i, element.fractionOnes, element.fractionZeros, element.featureIndex);
+        }
+        return nextIndex - 1;
+    }
+
+    private void setValues(int index, double fractionOnes, double fractionZeros, int featureIndex) {
+        pathElements[index + elementAndScaleOffset].fractionOnes = fractionOnes;
+        pathElements[index + elementAndScaleOffset].fractionZeros = fractionZeros;
+        pathElements[index + elementAndScaleOffset].featureIndex = featureIndex;
+    }
+
+    private double getScale(int offset) {
+        return scale[offset + elementAndScaleOffset];
+    }
+
+    private void setScale(int offset, double value) {
+        scale[offset + elementAndScaleOffset] = value;
+    }
+
+    public double fractionOnes(int pathIndex) {
+        return pathElements[pathIndex + elementAndScaleOffset].fractionOnes;
+    }
+
+    public double fractionZeros(int pathIndex) {
+        return pathElements[pathIndex + elementAndScaleOffset].fractionZeros;
+    }
+
+    public int findFeatureIndex(int splitFeature, int nextIndex) {
+        for (int i = elementAndScaleOffset; i < elementAndScaleOffset + nextIndex; i++) {
+            if (pathElements[i].featureIndex == splitFeature) {
+                return i - elementAndScaleOffset;
+            }
+        }
+        return -1;
+    }
+
+    public int featureIndex(int pathIndex) {
+        return pathElements[pathIndex + elementAndScaleOffset].featureIndex;
+    }
+
+    private PathElement getElement(int offset) {
+        return pathElements[offset + elementAndScaleOffset];
+    }
+
+    public static final class PathElement {
+        private double fractionOnes = 1.0;
+        private double fractionZeros = 1.0;
+        private int featureIndex = -1;
+    }
+}

+ 21 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.apache.lucene.util.Accountable;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
@@ -17,12 +18,16 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
     /**
      * Infer against the provided fields
      *
+     * NOTE: Must be thread safe
+     *
      * @param fields The fields and their values to infer against
      * @param config The configuration options for inference
+     * @param featureDecoderMap A map for decoding feature value names to their originating feature.
+     *                          Necessary for feature influence.
      * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
      *                              For regression this is continuous.
      */
-    InferenceResults infer(Map<String, Object> fields, InferenceConfig config);
+    InferenceResults infer(Map<String, Object> fields, InferenceConfig config, @Nullable Map<String, String> featureDecoderMap);
 
     /**
      * @return {@link TargetType} for the model.
@@ -42,4 +47,19 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
      * @return The estimated number of operations required at inference time
      */
     long estimatedNumOperations();
+
+    /**
+     * @return Does the model support feature importance
+     */
+    boolean supportsFeatureImportance();
+
+    /**
+     * Calculates the importance of each feature reference by the model for the passed in field values
+     *
+     * NOTE: Must be thread safe
+     * @param fields The fields inferring against
+     * @param featureDecoder A Map translating processed feature names to their original feature names
+     * @return A {@code Map<String, Double>} mapping each featureName to its importance
+     */
+    Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
 }

+ 46 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

@@ -37,6 +37,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -133,18 +134,25 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
     }
 
     @Override
-    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
+    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
         if (config.isTargetTypeSupported(targetType) == false) {
             throw ExceptionsHelper.badRequestException(
                 "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
         }
-        List<Double> inferenceResults = this.models.stream().map(model -> {
-            InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE);
-            assert results instanceof SingleValueInferenceResults;
-            return ((SingleValueInferenceResults)results).value();
-        }).collect(Collectors.toList());
+        List<Double> inferenceResults = new ArrayList<>(this.models.size());
+        List<Map<String, Double>> featureInfluence = new ArrayList<>();
+        NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
+        this.models.forEach(model -> {
+            InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
+            assert result instanceof SingleValueInferenceResults;
+            SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
+            inferenceResults.add(inferenceResult.value());
+            if (config.requestingImportance()) {
+                featureInfluence.add(inferenceResult.getFeatureImportance());
+            }
+        });
         List<Double> processed = outputAggregator.processValues(inferenceResults);
-        return buildResults(processed, config);
+        return buildResults(processed, featureInfluence, config, featureDecoderMap);
     }
 
     @Override
@@ -152,14 +160,20 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
         return targetType;
     }
 
-    private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
+    private InferenceResults buildResults(List<Double> processedInferences,
+                                          List<Map<String, Double>> featureInfluence,
+                                          InferenceConfig config,
+                                          Map<String, String> featureDecoderMap) {
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
-            return new RawInferenceResults(outputAggregator.aggregate(processedInferences));
+            return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
+                InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
         }
         switch(targetType) {
             case REGRESSION:
-                return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config);
+                return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
+                    config,
+                    InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
                 assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
@@ -172,6 +186,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
                 return new ClassificationInferenceResults((double)topClasses.v1(),
                     classificationLabel(topClasses.v1(), classificationLabels),
                     topClasses.v2(),
+                    InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)),
                     config);
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
@@ -293,6 +308,27 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
         return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
     }
 
+    @Override
+    public boolean supportsFeatureImportance() {
+        return models.stream().allMatch(TrainedModel::supportsFeatureImportance);
+    }
+
+    Map<String, Double> featureImportance(Map<String, Object> fields) {
+        return featureImportance(fields, Collections.emptyMap());
+    }
+
+    @Override
+    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+        Map<String, Double> collapsed = mergeFeatureImportances(models.stream()
+            .map(trainedModel -> trainedModel.featureImportance(fields, Collections.emptyMap()))
+            .collect(Collectors.toList()));
+        return InferenceHelpers.decodeFeatureImportances(featureDecoder, collapsed);
+    }
+
+    private static Map<String, Double> mergeFeatureImportances(List<Map<String, Double>> featureImportances) {
+        return featureImportances.stream().collect(HashMap::new, (a, b) -> b.forEach((k, v) -> a.merge(k, v, Double::sum)), Map::putAll);
+    }
+
     public static Builder builder() {
         return new Builder();
     }

+ 17 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -104,7 +105,11 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
     }
 
     @Override
-    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
+    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
+        if (config.requestingImportance()) {
+            throw ExceptionsHelper.badRequestException("[{}] model does not support feature importance",
+                NAME.getPreferredName());
+        }
         if (config instanceof ClassificationConfig == false) {
             throw ExceptionsHelper.badRequestException("[{}] model only supports classification",
                 NAME.getPreferredName());
@@ -138,6 +143,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
         return new ClassificationInferenceResults(topClasses.v1(),
             LANGUAGE_NAMES.get(topClasses.v1()),
             topClasses.v2(),
+            Collections.emptyMap(),
             classificationConfig);
     }
 
@@ -159,6 +165,16 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
         return numOps;
     }
 
+    @Override
+    public boolean supportsFeatureImportance() {
+        return false;
+    }
+
+    @Override
+    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+        throw new UnsupportedOperationException("[lang_ident] does not support feature importance");
+    }
+
     @Override
     public long ramBytesUsed() {
         long size = SHALLOW_SIZE;

+ 142 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -44,6 +45,7 @@ import java.util.Objects;
 import java.util.Queue;
 import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
 
@@ -86,6 +88,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     private final TargetType targetType;
     private final List<String> classificationLabels;
     private final CachedSupplier<Double> highestOrderCategory;
+    // populated lazily when feature importance is calculated
+    private double[] nodeEstimates;
+    private Integer maxDepth;
 
     Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
         this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
@@ -120,7 +125,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     }
 
     @Override
-    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
+    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
         if (config.isTargetTypeSupported(targetType) == false) {
             throw ExceptionsHelper.badRequestException(
                 "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
@@ -129,21 +134,23 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         List<Double> features = featureNames.stream()
             .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
             .collect(Collectors.toList());
-        return infer(features, config);
-    }
 
-    private InferenceResults infer(List<Double> features, InferenceConfig config) {
+        Map<String, Double> featureImportance = config.requestingImportance() ?
+            featureImportance(features, featureDecoderMap) :
+            Collections.emptyMap();
+
         TreeNode node = nodes.get(0);
         while(node.isLeaf() == false) {
             node = nodes.get(node.compare(features));
         }
-        return buildResult(node.getLeafValue(), config);
+
+        return buildResult(node.getLeafValue(), featureImportance, config);
     }
 
-    private InferenceResults buildResult(Double value, InferenceConfig config) {
+    private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
-            return new RawInferenceResults(value);
+            return new RawInferenceResults(value, featureImportance);
         }
         switch (targetType) {
             case CLASSIFICATION:
@@ -156,9 +163,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
                 return new ClassificationInferenceResults(value,
                     classificationLabel(topClasses.v1(), classificationLabels),
                     topClasses.v2(),
+                    featureImportance,
                     config);
             case REGRESSION:
-                return new RegressionInferenceResults(value, config);
+                return new RegressionInferenceResults(value, config, featureImportance);
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
         }
@@ -192,7 +200,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         // If we are classification, we should assume that the largest leaf value is whole.
         assert maxCategory == Math.rint(maxCategory);
         List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
-        // TODO, eventually have TreeNodes contain confidence levels
         list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
         return list;
     }
@@ -263,12 +270,138 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         detectCycle();
     }
 
+    @Override
+    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
+        if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
+            throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
+        }
+        List<Double> features = featureNames.stream()
+            .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
+            .collect(Collectors.toList());
+        return featureImportance(features, featureDecoder);
+    }
+
+    private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
+        calculateNodeEstimatesIfNeeded();
+        double[] featureImportance = new double[fieldValues.size()];
+        int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
+        ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
+        for (int i = 0; i < arrSize; i++) {
+            elements[i] = new ShapPath.PathElement();
+        }
+        double[] scale = new double[arrSize];
+        ShapPath initialPath = new ShapPath(elements, scale);
+        shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
+        return InferenceHelpers.decodeFeatureImportances(featureDecoder,
+            IntStream.range(0, featureImportance.length)
+                .boxed()
+                .collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
+    }
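
NOTE: the `arrSize` computation above preallocates scratch for every recursion depth at once. A split path at depth d holds at most d + 1 elements, and summing the possible path lengths 1 + 2 + ... + (maxDepth + 1) gives the triangular number (maxDepth + 1)(maxDepth + 2) / 2. A quick check of that arithmetic, assuming each depth level of the recursion owns its own slice of the shared array (which is what the `ShapPath(parentSplitPath, nextIndex)` copy suggests):

```
public class ShapScratchSize {
    public static void main(String[] args) {
        for (int maxDepth = 0; maxDepth <= 4; maxDepth++) {
            int closedForm = ((maxDepth + 1) * (maxDepth + 2)) / 2;
            int summed = 0;
            for (int depth = 0; depth <= maxDepth; depth++) {
                summed += depth + 1; // longest possible path at this depth
            }
            // closedForm == summed at every depth
            System.out.printf("maxDepth=%d -> %d (sum=%d)%n", maxDepth, closedForm, summed);
        }
    }
}
```
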
+
+    private void calculateNodeEstimatesIfNeeded() {
+        if (this.nodeEstimates != null && this.maxDepth != null) {
+            return;
+        }
+        synchronized (this) {
+            if (this.nodeEstimates != null && this.maxDepth != null) {
+                return;
+            }
+            double[] estimates = new double[nodes.size()];
+            this.maxDepth = fillNodeEstimates(estimates, 0, 0);
+            this.nodeEstimates = estimates;
+        }
+    }
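
NOTE: `calculateNodeEstimatesIfNeeded` is double-checked lazy initialization: the unsynchronized first check keeps the common path lock-free, and the repeated check inside `synchronized` stops two racing threads from both filling the estimates. One caveat worth knowing: the textbook form of this pattern publishes through a `volatile` field so the lock-free read has a happens-before edge with the write. A stripped-down sketch of that form (names are illustrative):

```
public class LazyEstimates {
    // volatile gives the lock-free fast-path read a happens-before
    // edge with the write made inside the synchronized block.
    private volatile double[] estimates;

    double[] get() {
        double[] local = estimates;          // first (lock-free) check
        if (local == null) {
            synchronized (this) {
                local = estimates;           // second check, under the lock
                if (local == null) {
                    local = compute();
                    estimates = local;       // safe publication
                }
            }
        }
        return local;
    }

    private double[] compute() {
        return new double[] { 1.0, 2.0, 3.0 }; // stand-in for fillNodeEstimates
    }
}
```
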
+
+    /**
+     * Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc
+     *
+     * If improvements in performance or accuracy have been found, it is probably best that the changes are implemented on the native
+     * side first and then ported to the Java side.
+     */
+    private void shapRecursive(List<Double> processedFeatures,
+                               double[] nodeValues,
+                               ShapPath parentSplitPath,
+                               int nodeIndex,
+                               double parentFractionZero,
+                               double parentFractionOne,
+                               int parentFeatureIndex,
+                               double[] featureImportance,
+                               int nextIndex) {
+        ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
+        TreeNode currNode = nodes.get(nodeIndex);
+        nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
+        if (currNode.isLeaf()) {
+            // TODO: handle multi-valued leaf values
+            double leafValue = nodeValues[nodeIndex];
+            for (int i = 1; i < nextIndex; ++i) {
+                double scale = splitPath.sumUnwoundPath(i, nextIndex);
+                int inputColumnIndex = splitPath.featureIndex(i);
+                featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
+            }
+        } else {
+            int hotIndex = currNode.compare(processedFeatures);
+            int coldIndex = hotIndex == currNode.getLeftChild() ? currNode.getRightChild() : currNode.getLeftChild();
+
+            double incomingFractionZero = 1.0;
+            double incomingFractionOne = 1.0;
+            int splitFeature = currNode.getSplitFeature();
+            int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
+            if (pathIndex > -1) {
+                incomingFractionZero = splitPath.fractionZeros(pathIndex);
+                incomingFractionOne = splitPath.fractionOnes(pathIndex);
+                nextIndex = splitPath.unwind(pathIndex, nextIndex);
+            }
+
+            double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
+            double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
+            shapRecursive(processedFeatures, nodeValues, splitPath,
+                hotIndex, incomingFractionZero * hotFractionZero,
+                incomingFractionOne, splitFeature, featureImportance, nextIndex);
+            shapRecursive(processedFeatures, nodeValues, splitPath,
+                coldIndex, incomingFractionZero * coldFractionZero,
+                0.0, splitFeature, featureImportance, nextIndex);
+        }
+    }
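
NOTE: for a depth-one tree (a single split), the recursion above collapses to something checkable by hand: the split feature is credited with the leaf the sample lands in minus the sample-weighted mean of the two leaves, i.e. f(x) - E[f]. A minimal sketch of that special case, with hypothetical numbers:

```
public class StumpShap {
    public static void main(String[] args) {
        // A stump: left leaf 1.0 covers 80 samples, right leaf 2.0 covers 20.
        double leftValue = 1.0, rightValue = 2.0;
        long leftSamples = 80, rightSamples = 20;

        double expected = (leftSamples * leftValue + rightSamples * rightValue)
            / (double) (leftSamples + rightSamples); // 1.2

        // SHAP value of the split feature when the sample goes left ...
        double phiLeft = leftValue - expected;   // ~ -0.2
        // ... and when it goes right:
        double phiRight = rightValue - expected; // ~  0.8
        System.out.println(phiLeft + " " + phiRight);
    }
}
```
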
+
+    /**
+     * This recursively populates the provided {@code double[]} with the node estimated values
+     *
+     * Used when calculating feature importance.
+     * @param nodeEstimates Array to update in place with the node estimated values
+     * @param nodeIndex Current node index
+     * @param depth Current depth
+     * @return The current max depth
+     */
+    private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
+        TreeNode node = nodes.get(nodeIndex);
+        if (node.isLeaf()) {
+            nodeEstimates[nodeIndex] = node.getLeafValue();
+            return 0;
+        }
+
+        int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
+        int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
+        long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
+        long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
+        long divisor = leftWeight + rightWeight;
+        double averageValue = divisor == 0 ?
+            0.0 :
+            (leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
+        nodeEstimates[nodeIndex] = averageValue;
+        return Math.max(depthLeft, depthRight) + 1;
+    }
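
NOTE: to make the weighted averaging concrete, here is the same bottom-up pass hand-rolled over plain arrays, using the seven-node tree from `TreeTests.testFeatureImportance` later in this diff (leaves 3, 8, 13, 18, one sample each). A sketch, not the production code path:

```
import java.util.Arrays;

public class NodeEstimates {
    // children[i] = {left, right}, or null for a leaf.
    static final int[][] CHILDREN = { {1, 2}, {3, 4}, {5, 6}, null, null, null, null };
    static final double[] LEAF = { 0, 0, 0, 3.0, 8.0, 13.0, 18.0 };
    static final long[] SAMPLES = { 4, 2, 2, 1, 1, 1, 1 };

    public static void main(String[] args) {
        double[] estimates = new double[CHILDREN.length];
        int maxDepth = fill(estimates, 0);
        // Prints [10.5, 5.5, 15.5, 3.0, 8.0, 13.0, 18.0] depth=2
        System.out.println(Arrays.toString(estimates) + " depth=" + maxDepth);
    }

    static int fill(double[] estimates, int node) {
        if (CHILDREN[node] == null) {
            estimates[node] = LEAF[node];
            return 0;
        }
        int l = CHILDREN[node][0], r = CHILDREN[node][1];
        int depth = Math.max(fill(estimates, l), fill(estimates, r)) + 1;
        long divisor = SAMPLES[l] + SAMPLES[r];
        // Sample-weighted average of the two child estimates.
        estimates[node] = (SAMPLES[l] * estimates[l] + SAMPLES[r] * estimates[r]) / divisor;
        return depth;
    }
}
```
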
+
     @Override
     public long estimatedNumOperations() {
         // Grabbing the features from the doc + the depth of the tree
         return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
     }
 
+    @Override
+    public boolean supportsFeatureImportance() {
+        return true;
+    }
+
     /**
      * The highest index of a feature used in any of the nodes.
      * If no nodes use a feature return -1. This can only happen

+ 3 - 7
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java

@@ -342,8 +342,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
         }};
 
         assertThat(
-            ((ClassificationInferenceResults)definition.getTrainedModel()
-                .infer(fields, ClassificationConfig.EMPTY_PARAMS))
+            ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
                 .getClassificationLabel(),
             equalTo("Iris-setosa"));
 
@@ -354,8 +353,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
             put("petal_width", 1.4);
         }};
         assertThat(
-            ((ClassificationInferenceResults)definition.getTrainedModel()
-                .infer(fields, ClassificationConfig.EMPTY_PARAMS))
+            ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
                 .getClassificationLabel(),
             equalTo("Iris-versicolor"));
 
@@ -366,10 +364,8 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
             put("petal_width", 2.0);
         }};
         assertThat(
-            ((ClassificationInferenceResults)definition.getTrainedModel()
-                .infer(fields, ClassificationConfig.EMPTY_PARAMS))
+            ((ClassificationInferenceResults)definition.infer(fields, ClassificationConfig.EMPTY_PARAMS))
                 .getClassificationLabel(),
             equalTo("Iris-virginica"));
     }
-
 }

+ 3 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java

@@ -8,10 +8,12 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 
+import java.util.Collections;
+
 public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
 
     public static RawInferenceResults createRandomResults() {
-        return new RawInferenceResults(randomDouble());
+        return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
     }
 
     @Override

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java

@@ -30,11 +30,12 @@ public class ClassificationConfigTests extends AbstractSerializingTestCase<Class
         ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
         assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
 
-        expected = new ClassificationConfig(3, "foo", "bar");
+        expected = new ClassificationConfig(3, "foo", "bar", 2);
         Map<String, Object> configMap = new HashMap<>();
         configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
         configMap.put(ClassificationConfig.RESULTS_FIELD.getPreferredName(), "foo");
         configMap.put(ClassificationConfig.TOP_CLASSES_RESULTS_FIELD.getPreferredName(), "bar");
+        configMap.put(ClassificationConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 2);
         assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
     }
 

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java

@@ -24,9 +24,10 @@ public class RegressionConfigTests extends AbstractSerializingTestCase<Regressio
     }
 
     public void testFromMap() {
-        RegressionConfig expected = new RegressionConfig("foo");
+        RegressionConfig expected = new RegressionConfig("foo", 3);
         Map<String, Object> config = new HashMap<>(){{
             put(RegressionConfig.RESULTS_FIELD.getPreferredName(), "foo");
+            put(RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), 3);
         }};
         assertThat(RegressionConfig.fromMap(config), equalTo(expected));
     }

+ 155 - 19
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
+import org.elasticsearch.xpack.core.ml.job.config.Operator;
 import org.junit.Before;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -39,6 +40,7 @@ import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 
 public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
+    private final double eps = 1.0E-8;
 
     private boolean lenient;
 
@@ -267,7 +269,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> scores   = Arrays.asList(0.230557435, 0.162032651);
         double eps = 0.000001;
         List<ClassificationInferenceResults.TopClassEntry> probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -278,7 +281,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         expected = Arrays.asList(0.310025518, 0.6899744811);
         scores   = Arrays.asList(0.217017863, 0.2069923443);
         probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -289,7 +293,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         expected = Arrays.asList(0.768524783, 0.231475216);
         scores   = Arrays.asList(0.230557435, 0.162032651);
         probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -303,7 +308,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         expected = Arrays.asList(0.6899744811, 0.3100255188);
         scores   = Arrays.asList(0.482982136, 0.0930076556);
         probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -361,24 +367,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> featureVector = Arrays.asList(0.4, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
 
         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
 
         featureVector = Arrays.asList(0.0, 1.0);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
 
         featureMap = new HashMap<>(2) {{
             put("foo", 0.3);
             put("bar", null);
         }};
         assertThat(0.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
     }
 
     public void testMultiClassClassificationInference() {
@@ -432,24 +442,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> featureVector = Arrays.asList(0.4, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         assertThat(2.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
 
         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
 
         featureVector = Arrays.asList(0.0, 1.0);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
 
         featureMap = new HashMap<>(2) {{
             put("foo", 0.6);
             put("bar", null);
         }};
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
     }
 
     public void testRegressionInference() {
@@ -489,12 +503,16 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> featureVector = Arrays.asList(0.4, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         assertThat(0.9,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
 
         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(0.5,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
 
         // Test with NO aggregator supplied, verifies default behavior of non-weighted sum
         ensemble = Ensemble.builder()
@@ -506,19 +524,25 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         featureVector = Arrays.asList(0.4, 0.0);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.8,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
 
         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
 
         featureMap = new HashMap<>(2) {{
             put("foo", 0.3);
             put("bar", null);
         }};
         assertThat(1.8,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
     }
 
     public void testInferNestedFields() {
@@ -564,7 +588,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             }});
         }};
         assertThat(0.9,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
 
         featureMap = new HashMap<>() {{
             put("foo", new HashMap<>(){{
@@ -575,7 +601,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             }});
         }};
         assertThat(0.5,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
     }
 
     public void testOperationsEstimations() {
@@ -590,6 +618,114 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
     }
 
+    public void testFeatureImportance() {
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(
+                TreeNode.builder(0)
+                    .setSplitFeature(0)
+                    .setOperator(Operator.LT)
+                    .setLeftChild(1)
+                    .setRightChild(2)
+                    .setThreshold(0.55)
+                    .setNumberSamples(10L),
+                TreeNode.builder(1)
+                    .setSplitFeature(0)
+                    .setLeftChild(3)
+                    .setRightChild(4)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.41)
+                    .setNumberSamples(6L),
+                TreeNode.builder(2)
+                    .setSplitFeature(1)
+                    .setLeftChild(5)
+                    .setRightChild(6)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.25)
+                    .setNumberSamples(4L),
+                TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L),
+                TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L),
+                TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L),
+                TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build();
+
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(
+                TreeNode.builder(0)
+                    .setSplitFeature(0)
+                    .setOperator(Operator.LT)
+                    .setLeftChild(1)
+                    .setRightChild(2)
+                    .setThreshold(0.45)
+                    .setNumberSamples(10L),
+                TreeNode.builder(1)
+                    .setSplitFeature(0)
+                    .setLeftChild(3)
+                    .setRightChild(4)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.25)
+                    .setNumberSamples(5L),
+                TreeNode.builder(2)
+                    .setSplitFeature(0)
+                    .setLeftChild(5)
+                    .setRightChild(6)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.59)
+                    .setNumberSamples(5L),
+                TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L),
+                TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L),
+                TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L),
+                TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build();
+
+        Ensemble ensemble = Ensemble.builder().setOutputAggregator(new WeightedSum())
+            .setTrainedModels(Arrays.asList(tree1, tree2))
+            .setFeatureNames(featureNames)
+            .build();
+
+        Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
+        assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
+        assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
+        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
+        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
+        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+    }
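
NOTE: a handy invariant for spot-checking expectations like these is SHAP's local accuracy property: per input, the importances sum to the model output minus the baseline (here, the sample-weighted mean prediction of the ensemble).

```
\sum_{i} \phi_i(x) = f(x) - \mathbb{E}\left[f(X)\right]
```

Working the first case above by hand: x = (0.0, 0.9) lands in the 1.18230136 and 1.04476388 leaves, so f(x) ≈ 2.22707; the sample-weighted baseline over both trees comes to ≈ 4.00472; and the difference, ≈ -1.77765, matches -1.653200025 + (-0.12444978).
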
+
     private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
         return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
     }

+ 75 - 12
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

@@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.job.config.Operator;
 import org.junit.Before;
 
 import java.io.IOException;
@@ -35,6 +36,7 @@ import static org.hamcrest.Matchers.equalTo;
 
 public class TreeTests extends AbstractSerializingTestCase<Tree> {
 
+    private final double eps = 1.0E-8;
     private boolean lenient;
 
     @Before
@@ -118,7 +120,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         List<Double> featureVector = Arrays.asList(0.6, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
         assertThat(42.0,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
     }
 
     public void testInfer() {
@@ -138,27 +141,31 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         List<Double> featureVector = Arrays.asList(0.6, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         assertThat(0.3,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
 
         // This should hit the left child of the left child of the root node
         // i.e. it takes the path left, left
         featureVector = Arrays.asList(0.3, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(0.1,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
 
         // This should hit the right child of the left child of the root node
         // i.e. it takes the path left, right
         featureVector = Arrays.asList(0.3, 0.9);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(0.2,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
 
         // This should still work if the internal values are strings
         List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
         featureMap = zipObjMap(featureNames, featureVectorStrings);
         assertThat(0.2,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
 
         // This should handle missing values and take the default_left path
         featureMap = new HashMap<>(2) {{
@@ -166,7 +173,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             put("bar", null);
         }};
         assertThat(0.1,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
     }
 
     public void testInferNestedFields() {
@@ -192,7 +200,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             }});
         }};
         assertThat(0.3,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
 
         // This should hit the left child of the left child of the root node
         // i.e. it takes the path left, left
@@ -205,7 +214,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             }});
         }};
         assertThat(0.1,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
 
         // This should hit the right child of the left child of the root node
         // i.e. it takes the path left, right
@@ -218,7 +228,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             }});
         }};
         assertThat(0.2,
-            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap())).value(),
+                0.00001));
     }
 
     public void testTreeClassificationProbability() {
@@ -241,7 +252,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         List<String> expectedFields = Arrays.asList("dog", "cat");
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         List<ClassificationInferenceResults.TopClassEntry> probabilities =
-            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expectedProbs.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
             assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@@ -252,7 +264,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         featureVector = Arrays.asList(0.3, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         probabilities =
-            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expectedProbs.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
             assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@@ -264,7 +277,8 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             put("bar", null);
         }};
         probabilities =
-            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expectedProbs.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
             assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
@@ -345,6 +359,55 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         assertThat(tree.estimatedNumOperations(), equalTo(7L));
     }
 
+    public void testFeatureImportance() {
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(
+                TreeNode.builder(0)
+                    .setSplitFeature(0)
+                    .setOperator(Operator.LT)
+                    .setLeftChild(1)
+                    .setRightChild(2)
+                    .setThreshold(0.5)
+                    .setNumberSamples(4L),
+                TreeNode.builder(1)
+                    .setSplitFeature(1)
+                    .setLeftChild(3)
+                    .setRightChild(4)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.5)
+                    .setNumberSamples(2L),
+                TreeNode.builder(2)
+                    .setSplitFeature(1)
+                    .setLeftChild(5)
+                    .setRightChild(6)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.5)
+                    .setNumberSamples(2L),
+                TreeNode.builder(3).setLeafValue(3.0).setNumberSamples(1L),
+                TreeNode.builder(4).setLeafValue(8.0).setNumberSamples(1L),
+                TreeNode.builder(5).setLeafValue(13.0).setNumberSamples(1L),
+                TreeNode.builder(6).setLeafValue(18.0).setNumberSamples(1L)).build();
+
+        Map<String, Double> featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.25)),
+            Collections.emptyMap());
+        assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
+
+        featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.25, 0.75)), Collections.emptyMap());
+        assertThat(featureImportance.get("foo"), closeTo(-5.0, eps));
+        assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
+
+        featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.25)), Collections.emptyMap());
+        assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-2.5, eps));
+
+        featureImportance = tree.featureImportance(zipObjMap(featureNames, Arrays.asList(0.75, 0.75)), Collections.emptyMap());
+        assertThat(featureImportance.get("foo"), closeTo(5.0, eps));
+        assertThat(featureImportance.get("bar"), closeTo(2.5, eps));
+    }
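
NOTE: these four cases can be sanity-checked against the node estimates: the root estimate of this tree is 10.5, so each corner should satisfy local accuracy, i.e. the two importances sum to f(x) - 10.5 (for example, f(0.25, 0.25) = 3.0 and -5.0 + (-2.5) = -7.5). A tiny check over the numbers asserted above:

```
import java.util.Arrays;

public class CornerCheck {
    public static void main(String[] args) {
        double baseline = 10.5; // sample-weighted mean of leaves 3, 8, 13, 18
        double[][] cases = {
            // f(x), phi(foo), phi(bar) -- from testFeatureImportance
            {  3.0, -5.0, -2.5 },
            {  8.0, -5.0,  2.5 },
            { 13.0,  5.0, -2.5 },
            { 18.0,  5.0,  2.5 },
        };
        for (double[] c : cases) {
            // Local accuracy: importances sum to prediction minus baseline.
            if (Math.abs((c[1] + c[2]) - (c[0] - baseline)) > 1e-8) {
                throw new AssertionError("corner failed: " + Arrays.toString(c));
            }
        }
        System.out.println("all four corners satisfy local accuracy");
    }
}
```
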
+
     public void testMaxFeatureIndex() {
 
         int numFeatures = randomIntBetween(1, 15);

+ 18 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

@@ -115,7 +115,10 @@ public class InferenceIngestIT extends ESRestTestCase {
             "        \"inference\": {\n" +
             "          \"target_field\": \"ml.classification\",\n" +
             "          \"inference_config\": {\"classification\": " +
-            "                {\"num_top_classes\":2, \"top_classes_results_field\": \"result_class_prob\"}},\n" +
+            "                {\"num_top_classes\":2, " +
+            "                \"top_classes_results_field\": \"result_class_prob\"," +
+            "                \"num_top_feature_importance_values\": 2" +
+            "          }},\n" +
             "          \"model_id\": \"test_classification\",\n" +
             "          \"field_mappings\": {\n" +
             "            \"col1\": \"col1\",\n" +
@@ -153,6 +156,8 @@ public class InferenceIngestIT extends ESRestTestCase {
         String responseString = EntityUtils.toString(response.getEntity());
         assertThat(responseString, containsString("\"predicted_value\":\"second\""));
         assertThat(responseString, containsString("\"predicted_value\":1.0"));
+        assertThat(responseString, containsString("\"col2\":0.944"));
+        assertThat(responseString, containsString("\"col1\":0.19999"));
 
         String sourceWithMissingModel = "{\n" +
             "  \"pipeline\": {\n" +
@@ -321,16 +326,19 @@ public class InferenceIngestIT extends ESRestTestCase {
         "                \"split_gain\": 12.0,\n" +
         "                \"threshold\": 10.0,\n" +
         "                \"decision_type\": \"lte\",\n" +
+        "                \"number_samples\": 300,\n" +
         "                \"default_left\": true,\n" +
         "                \"left_child\": 1,\n" +
         "                \"right_child\": 2\n" +
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 1,\n" +
+        "                \"number_samples\": 100,\n" +
         "                \"leaf_value\": 1\n" +
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 2,\n" +
+        "                \"number_samples\": 200,\n" +
         "                \"leaf_value\": 2\n" +
         "              }\n" +
         "            ],\n" +
@@ -352,15 +360,18 @@ public class InferenceIngestIT extends ESRestTestCase {
         "                \"threshold\": 10.0,\n" +
         "                \"decision_type\": \"lte\",\n" +
         "                \"default_left\": true,\n" +
+        "                \"number_samples\": 150,\n" +
         "                \"left_child\": 1,\n" +
         "                \"right_child\": 2\n" +
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 1,\n" +
+        "                \"number_samples\": 50,\n" +
         "                \"leaf_value\": 1\n" +
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 2,\n" +
+        "                \"number_samples\": 100,\n" +
         "                \"leaf_value\": 2\n" +
         "              }\n" +
         "            ],\n" +
@@ -445,6 +456,7 @@ public class InferenceIngestIT extends ESRestTestCase {
         "              {\n" +
         "                \"node_index\": 0,\n" +
         "                \"split_feature\": 0,\n" +
+        "                \"number_samples\": 100,\n" +
         "                \"split_gain\": 12.0,\n" +
         "                \"threshold\": 10.0,\n" +
         "                \"decision_type\": \"lte\",\n" +
@@ -454,10 +466,12 @@ public class InferenceIngestIT extends ESRestTestCase {
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 1,\n" +
+        "                \"number_samples\": 80,\n" +
         "                \"leaf_value\": 1\n" +
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 2,\n" +
+        "                \"number_samples\": 20,\n" +
         "                \"leaf_value\": 0\n" +
         "              }\n" +
         "            ],\n" +
@@ -476,6 +490,7 @@ public class InferenceIngestIT extends ESRestTestCase {
         "                \"node_index\": 0,\n" +
         "                \"split_feature\": 0,\n" +
         "                \"split_gain\": 12.0,\n" +
+        "                \"number_samples\": 180,\n" +
         "                \"threshold\": 10.0,\n" +
         "                \"decision_type\": \"lte\",\n" +
         "                \"default_left\": true,\n" +
@@ -484,10 +499,12 @@ public class InferenceIngestIT extends ESRestTestCase {
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 1,\n" +
+        "                \"number_samples\": 10,\n" +
         "                \"leaf_value\": 1\n" +
         "              },\n" +
         "              {\n" +
         "                \"node_index\": 2,\n" +
+        "                \"number_samples\": 170,\n" +
         "                \"leaf_value\": 0\n" +
         "              }\n" +
         "            ],\n" +

+ 65 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

@@ -102,6 +102,43 @@ public class InferenceProcessorTests extends ESTestCase {
         assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
     }
 
+    public void testMutateDocumentClassificationFeatureInfluence() {
+        ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2);
+        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
+            auditor,
+            "my_processor",
+            "ml.my_processor",
+            "classification_model",
+            classificationConfig,
+            Collections.emptyMap());
+
+        Map<String, Object> source = new HashMap<>();
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
+        classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
+        classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
+
+        Map<String, Double> featureInfluence = new HashMap<>();
+        featureInfluence.put("feature_1", 1.13);
+        featureInfluence.put("feature_2", -42.0);
+
+        InternalInferModelAction.Response response = new InternalInferModelAction.Response(
+            Collections.singletonList(new ClassificationInferenceResults(1.0,
+                "foo",
+                classes,
+                featureInfluence,
+                classificationConfig)),
+            true);
+        inferenceProcessor.mutateDocument(response, document);
+
+        assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
+        assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
+    }
+
     @SuppressWarnings("unchecked")
     public void testMutateDocumentClassificationTopNClassesWithSpecificField() {
         ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops");
@@ -154,6 +191,34 @@ public class InferenceProcessorTests extends ESTestCase {
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
     }
 
+    public void testMutateDocumentRegressionWithTopFeatures() {
+        RegressionConfig regressionConfig = new RegressionConfig("foo", 2);
+        InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
+            auditor,
+            "my_processor",
+            "ml.my_processor",
+            "regression_model",
+            regressionConfig,
+            Collections.emptyMap());
+
+        Map<String, Object> source = new HashMap<>();
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        Map<String, Double> featureInfluence = new HashMap<>();
+        featureInfluence.put("feature_1", 1.13);
+        featureInfluence.put("feature_2", -42.0);
+
+        InternalInferModelAction.Response response = new InternalInferModelAction.Response(
+            Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true);
+        inferenceProcessor.mutateDocument(response, document);
+
+        assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7));
+        assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model"));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_1", Double.class), equalTo(1.13));
+        assertThat(document.getFieldValue("ml.my_processor.feature_importance.feature_2", Double.class), equalTo(-42.0));
+    }
+
     public void testGenerateRequestWithEmptyMapping() {
         String modelId = "model";
         Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);