Browse Source

[ML] binary classification per-class feature importance for model inference (#61597)

This commit addresses two issues:

- per class feature importance is now written out for binary classification (logistic regression)
- The `class_name` in per class feature importance now matches what is written in the `top_classes` array.
Benjamin Trent 5 years ago
parent
commit
e6c3481b76

+ 12 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java

@@ -39,6 +39,12 @@ public class FeatureImportance implements Writeable, ToXContentObject {
         return new FeatureImportance(featureName, importance, null);
     }
 
+    public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
+        return new FeatureImportance(featureName,
+            importance,
+            classImportance);
+    }
+
     public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
         return new FeatureImportance(featureName,
             classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
@@ -170,27 +176,27 @@ public class FeatureImportance implements Writeable, ToXContentObject {
         }
 
         private static Map<String, Double> toMap(List<ClassImportance> importances) {
-            return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
+            return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
         }
 
         public static ClassImportance fromXContent(XContentParser parser) {
             return PARSER.apply(parser, null);
         }
 
-        private final String className;
+        private final Object className;
         private final double importance;
 
-        public ClassImportance(String className, double importance) {
+        public ClassImportance(Object className, double importance) {
             this.className = className;
             this.importance = importance;
         }
 
         public ClassImportance(StreamInput in) throws IOException {
-            this.className = in.readString();
+            this.className = in.readGenericValue();
             this.importance = in.readDouble();
         }
 
-        public String getClassName() {
+        public Object getClassName() {
             return className;
         }
 
@@ -207,7 +213,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(className);
+            out.writeGenericValue(className);
             out.writeDouble(importance);
         }
 

+ 31 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -129,21 +130,46 @@ public final class InferenceHelpers {
         return originalFeatureImportance;
     }
 
-    public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
-                                                                     @Nullable List<String> classificationLabels) {
+    public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
         List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
+        featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
+        return importances;
+    }
+
+    public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
+                                                                                   final int predictedValue,
+                                                                                   @Nullable List<String> classificationLabels,
+                                                                                   @Nullable PredictionFieldType predictionFieldType) {
+        List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
+        final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
         featureImportance.forEach((k, v) -> {
-            // This indicates regression, or logistic regression
+            // This indicates logistic regression (binary classification)
             // If the length > 1, we assume multi-class classification.
             if (v.length == 1) {
-                importances.add(FeatureImportance.forRegression(k, v[0]));
+                assert predictedValue == 1 || predictedValue == 0;
+                // If predicted value is `1`, then the other class is `0`
+                // If predicted value is `0`, then the other class is `1`
+                final int otherClass = 1 - predictedValue;
+                String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
+                String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
+                importances.add(FeatureImportance.forBinaryClassification(k,
+                    v[0],
+                    Arrays.asList(
+                        new FeatureImportance.ClassImportance(
+                            fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
+                            v[0]),
+                        new FeatureImportance.ClassImportance(
+                            fieldType.transformPredictedValue((double)otherClass, otherLabel),
+                            -v[0])
+                    )));
             } else {
                 List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
                 // If the classificationLabels exist, their length must match leaf_value length
                 assert classificationLabels == null || classificationLabels.size() == v.length;
                 for (int i = 0; i < v.length; i++) {
+                    String label = classificationLabels == null ? null : classificationLabels.get(i);
                     classImportance.add(new FeatureImportance.ClassImportance(
-                        classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
+                        fieldType.transformPredictedValue((double)i, label),
                         v[i]));
                 }
                 importances.add(FeatureImportance.forClassification(k, classImportance));

+ 21 - 20
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

@@ -43,7 +43,8 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
-import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
@@ -154,14 +155,7 @@ public class EnsembleInferenceModel implements InferenceModel {
             RawInferenceResults inferenceResult = (RawInferenceResults) result;
             inferenceResults[i++] = inferenceResult.getValue();
             if (config.requestingImportance()) {
-                double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
-                assert modelFeatureImportance.length == featureInfluence.length;
-                for (int j = 0; j < modelFeatureImportance.length; j++) {
-                    if (featureInfluence[j] == null) {
-                        featureInfluence[j] = new double[modelFeatureImportance[j].length];
-                    }
-                    featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
-                }
+                addFeatureImportance(featureInfluence, inferenceResult);
             }
         }
         double[] processed = outputAggregator.processValues(inferenceResults);
@@ -176,18 +170,22 @@ public class EnsembleInferenceModel implements InferenceModel {
             InferenceResults result = model.infer(features, subModelInferenceConfig);
             assert result instanceof RawInferenceResults;
             RawInferenceResults inferenceResult = (RawInferenceResults) result;
-            double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
-            assert modelFeatureImportance.length == featureInfluence.length;
-            for (int j = 0; j < modelFeatureImportance.length; j++) {
-                if (featureInfluence[j] == null) {
-                    featureInfluence[j] = new double[modelFeatureImportance[j].length];
-                }
-                featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
-            }
+            addFeatureImportance(featureInfluence, inferenceResult);
         }
         return featureInfluence;
     }
 
+    private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
+        double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
+        assert modelFeatureImportance.length == featureInfluence.length;
+        for (int j = 0; j < modelFeatureImportance.length; j++) {
+            if (featureInfluence[j] == null) {
+                featureInfluence[j] = new double[modelFeatureImportance[j].length];
+            }
+            featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
+        }
+    }
+
     private InferenceResults buildResults(double[] processedInferences,
                                           double[][] featureImportance,
                                           Map<String, String> featureDecoderMap,
@@ -208,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
             case REGRESSION:
                 return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
                     config,
-                    transformFeatureImportance(decodedFeatureImportance, null));
+                    transformFeatureImportanceRegression(decodedFeatureImportance));
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
                 assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@@ -220,10 +218,13 @@ public class EnsembleInferenceModel implements InferenceModel {
                     classificationConfig.getNumTopClasses(),
                     classificationConfig.getPredictionFieldType());
                 final InferenceHelpers.TopClassificationValue value = topClasses.v1();
-                return new ClassificationInferenceResults((double)value.getValue(),
+                return new ClassificationInferenceResults(value.getValue(),
                     classificationLabel(topClasses.v1().getValue(), classificationLabels),
                     topClasses.v2(),
-                    transformFeatureImportance(decodedFeatureImportance, classificationLabels),
+                    transformFeatureImportanceClassification(decodedFeatureImportance,
+                        value.getValue(),
+                        classificationLabels,
+                        classificationConfig.getPredictionFieldType()),
                     config,
                     value.getProbability(),
                     value.getScore());

+ 5 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java

@@ -188,14 +188,17 @@ public class TreeInferenceModel implements InferenceModel {
                 return new ClassificationInferenceResults(classificationValue.getValue(),
                     classificationLabel(classificationValue.getValue(), classificationLabels),
                     topClasses.v2(),
-                    InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
+                    InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
+                        classificationValue.getValue(),
+                        classificationLabels,
+                        classificationConfig.getPredictionFieldType()),
                     config,
                     classificationValue.getProbability(),
                     classificationValue.getScore());
             case REGRESSION:
                 return new RegressionInferenceResults(value[0],
                     config,
-                    InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null));
+                    InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
         }

+ 17 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java

@@ -12,8 +12,10 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
@@ -185,8 +187,17 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
         private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
             ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
                 ignoreUnknownFields,
-                a -> new ClassImportance((String)a[0], (Importance)a[1]));
-            parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME);
+                a -> new ClassImportance(a[0], (Importance)a[1]));
+            parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
+                if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                    return p.text();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
+                    return p.numberValue();
+                } else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
+                    return p.booleanValue();
+                }
+                throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
+            }, CLASS_NAME, ObjectParser.ValueType.VALUE);
             parser.declareObject(ConstructingObjectParser.constructorArg(),
                 ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
                 IMPORTANCE);
@@ -197,22 +208,22 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
         }
 
-        public final String className;
+        public final Object className;
         public final Importance importance;
 
         public ClassImportance(StreamInput in) throws IOException {
-            this.className = in.readString();
+            this.className = in.readGenericValue();
             this.importance = new Importance(in);
         }
 
-        ClassImportance(String className, Importance importance) {
+        ClassImportance(Object className, Importance importance) {
             this.className = className;
             this.importance = importance;
         }
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(className);
+            out.writeGenericValue(className);
             importance.writeTo(out);
         }
 

+ 21 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java

@@ -19,6 +19,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 
 import java.io.IOException;
@@ -158,10 +159,26 @@ public class InferenceDefinitionTests extends ESTestCase {
 
         ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
         assertThat(results.valueAsString(), equalTo("second"));
-        assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
-        assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
-        assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
-        assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
+        FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
+        assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
+        assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
+        for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
+            if (classImportance.getClassName().equals("second")) {
+                assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
+            } else {
+                assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
+            }
+        }
+        FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
+        assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
+        assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
+        for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
+            if (classImportance.getClassName().equals("second")) {
+                assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
+            } else {
+                assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001));
+            }
+        }
     }
 
     public static String getClassificationDefinition(boolean customPreprocessor) {